aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-11 20:19:50 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-09-11 20:19:50 +0800
commitb2896c3cc3a0656b838f58975338d7dd309e3e62 (patch)
tree14f25741ab43c15e945e6044833c0ff44f11d83f /tensorflow/contrib
parent38f811077dd52820eaa3d5c684f41142de01c7eb (diff)
parente18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff)
Merge remote-tracking branch 'upstream/master' into ENH/div_no_nan_treate_negative_as_zero
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/BUILD11
-rw-r--r--tensorflow/contrib/__init__.py9
-rw-r--r--tensorflow/contrib/android/asset_manager_filesystem.cc4
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions.py41
-rw-r--r--tensorflow/contrib/autograph/converters/builtin_functions_test.py9
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py23
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow_test.py46
-rw-r--r--tensorflow/contrib/autograph/converters/logical_expressions.py4
-rw-r--r--tensorflow/contrib/autograph/converters/logical_expressions_test.py9
-rw-r--r--tensorflow/contrib/autograph/docs/pyfunc_dtypes.md2
-rw-r--r--tensorflow/contrib/autograph/examples/integration_tests/BUILD1
-rw-r--r--tensorflow/contrib/autograph/impl/api.py4
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py3
-rw-r--r--tensorflow/contrib/autograph/operators/BUILD11
-rw-r--r--tensorflow/contrib/autograph/operators/__init__.py5
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py6
-rw-r--r--tensorflow/contrib/autograph/operators/py_builtins.py225
-rw-r--r--tensorflow/contrib/autograph/operators/py_builtins_test.py131
-rw-r--r--tensorflow/contrib/autograph/operators/slices.py9
-rw-r--r--tensorflow/contrib/autograph/operators/slices_test.py15
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/BUILD4
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/__init__.py0
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf.py10
-rw-r--r--tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py40
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/live_values.py7
-rw-r--r--tensorflow/contrib/autograph/pyct/templates.py6
-rw-r--r--tensorflow/contrib/autograph/pyct/templates_test.py36
-rw-r--r--tensorflow/contrib/autograph/pyct/testing/BUILD2
-rw-r--r--tensorflow/contrib/autograph/utils/BUILD23
-rw-r--r--tensorflow/contrib/autograph/utils/__init__.py3
-rw-r--r--tensorflow/contrib/autograph/utils/builtins.py143
-rw-r--r--tensorflow/contrib/autograph/utils/builtins_test.py145
-rw-r--r--tensorflow/contrib/autograph/utils/misc_test.py4
-rw-r--r--tensorflow/contrib/autograph/utils/py_func_test.py8
-rw-r--r--tensorflow/contrib/autograph/utils/tensor_list_test.py8
-rw-r--r--tensorflow/contrib/autograph/utils/tensors.py41
-rw-r--r--tensorflow/contrib/autograph/utils/tensors_test.py57
-rw-r--r--tensorflow/contrib/bigtable/README.md10
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc6
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lib.h4
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc2
-rw-r--r--tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc2
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator.py43
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/model.py8
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc213
-rw-r--r--tensorflow/contrib/boosted_trees/kernels/training_ops.cc43
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py20
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py121
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py126
-rw-r--r--tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h8
-rw-r--r--tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc22
-rw-r--r--tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc3
-rw-r--r--tensorflow/contrib/boosted_trees/proto/split_info.proto4
-rw-r--r--tensorflow/contrib/boosted_trees/proto/tree_config.proto12
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py4
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py9
-rw-r--r--tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py551
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py157
-rw-r--r--tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc2
-rw-r--r--tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py4
-rw-r--r--tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py4
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py4
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt2
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt4
-rw-r--r--tensorflow/contrib/coder/BUILD1
-rw-r--r--tensorflow/contrib/compiler/BUILD34
-rw-r--r--tensorflow/contrib/compiler/xla.py208
-rw-r--r--tensorflow/contrib/compiler/xla_test.py180
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py4
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py10
-rw-r--r--tensorflow/contrib/crf/python/kernel_tests/crf_test.py20
-rw-r--r--tensorflow/contrib/data/BUILD22
-rw-r--r--tensorflow/contrib/data/__init__.py10
-rw-r--r--tensorflow/contrib/data/kernels/BUILD38
-rw-r--r--tensorflow/contrib/data/kernels/assert_next_dataset_op.cc2
-rw-r--r--tensorflow/contrib/data/kernels/csv_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/kernels/identity_indexed_dataset.cc155
-rw-r--r--tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/kernels/indexed_dataset.cc373
-rw-r--r--tensorflow/contrib/data/kernels/indexed_dataset.h119
-rw-r--r--tensorflow/contrib/data/kernels/lmdb_dataset_op.cc217
-rw-r--r--tensorflow/contrib/data/kernels/prefetching_kernels.cc19
-rw-r--r--tensorflow/contrib/data/kernels/threadpool_dataset_op.cc2
-rw-r--r--tensorflow/contrib/data/kernels/unique_dataset_op.cc4
-rw-r--r--tensorflow/contrib/data/ops/dataset_ops.cc9
-rw-r--r--tensorflow/contrib/data/ops/indexed_dataset_ops.cc80
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD84
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py224
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py32
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py76
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py78
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py36
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py66
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py120
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD88
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py224
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py219
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py108
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py404
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py850
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py28
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py4
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py67
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py6
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD14
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py42
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py50
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py70
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py64
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py65
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py10
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py29
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py15
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py2
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py166
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py6
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD41
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py27
-rw-r--r--tensorflow/contrib/data/python/ops/indexed_dataset_ops.py173
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py74
-rw-r--r--tensorflow/contrib/data/python/ops/map_defun.py2
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py150
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py90
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py23
-rw-r--r--tensorflow/contrib/distribute/BUILD2
-rw-r--r--tensorflow/contrib/distribute/README.md305
-rw-r--r--tensorflow/contrib/distribute/__init__.py4
-rw-r--r--tensorflow/contrib/distribute/python/BUILD136
-rw-r--r--tensorflow/contrib/distribute/python/checkpoint_utils_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py227
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py116
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py31
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops.py120
-rw-r--r--tensorflow/contrib/distribute/python/cross_tower_ops_test.py99
-rw-r--r--tensorflow/contrib/distribute/python/estimator_training_test.py659
-rw-r--r--tensorflow/contrib/distribute/python/examples/BUILD15
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py125
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py (renamed from tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py)10
-rw-r--r--tensorflow/contrib/distribute/python/input_ops.py13
-rw-r--r--tensorflow/contrib/distribute/python/input_ops_test.py8
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py206
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py12
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py209
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py153
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_test.py29
-rw-r--r--tensorflow/contrib/distribute/python/monitor_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/multi_worker_test_base.py107
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py2
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py251
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py179
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py6
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py8
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py124
-rw-r--r--tensorflow/contrib/distribute/python/values.py37
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py5
-rw-r--r--tensorflow/contrib/distributions/BUILD2
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py14
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py34
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py10
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py12
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py96
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py60
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py24
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py22
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py8
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py26
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py19
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py4
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py6
-rw-r--r--tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py3
-rw-r--r--tensorflow/contrib/eager/python/BUILD14
-rw-r--r--tensorflow/contrib/eager/python/evaluator_test.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb4
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/README.md14
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb298
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb389
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb467
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb485
-rw-r--r--tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb7
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50.py8
-rw-r--r--tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/BUILD25
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py54
-rw-r--r--tensorflow/contrib/eager/python/examples/scan/scan_test.py54
-rw-r--r--tensorflow/contrib/eager/python/metrics_impl.py22
-rw-r--r--tensorflow/contrib/eager/python/metrics_test.py71
-rw-r--r--tensorflow/contrib/eager/python/remote.py73
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py13
-rw-r--r--tensorflow/contrib/eager/python/tfe.py3
-rw-r--r--tensorflow/contrib/estimator/BUILD1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/baseline_test.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/extenders.py29
-rw-r--r--tensorflow/contrib/estimator/python/estimator/extenders_test.py129
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py14
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py41
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops_test.py16
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops_test.py6
-rw-r--r--tensorflow/contrib/factorization/python/ops/kmeans_test.py2
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals.py70
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py18
-rw-r--r--tensorflow/contrib/ffmpeg/ffmpeg_ops.py4
-rw-r--r--tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py18
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util.py33
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py32
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc2
-rw-r--r--tensorflow/contrib/gan/BUILD52
-rw-r--r--tensorflow/contrib/gan/python/estimator/__init__.py5
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py10
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py2
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py (renamed from tensorflow/contrib/kfac/python/ops/optimizer_lib.py)16
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py363
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py306
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py52
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py8
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc12
-rw-r--r--tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc4
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc60
-rw-r--r--tensorflow/contrib/image/ops/image_ops.cc44
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py10
-rw-r--r--tensorflow/contrib/image/python/ops/image_ops.py7
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes_test.py34
-rw-r--r--tensorflow/contrib/kfac/BUILD26
-rw-r--r--tensorflow/contrib/kfac/README.md93
-rw-r--r--tensorflow/contrib/kfac/__init__.py46
-rw-r--r--tensorflow/contrib/kfac/examples/BUILD80
-rw-r--r--tensorflow/contrib/kfac/examples/convnet.py667
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py62
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py48
-rw-r--r--tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py39
-rw-r--r--tensorflow/contrib/kfac/examples/mlp.py354
-rw-r--r--tensorflow/contrib/kfac/examples/mlp_mnist_main.py64
-rw-r--r--tensorflow/contrib/kfac/examples/mnist.py69
-rw-r--r--tensorflow/contrib/kfac/examples/tests/BUILD52
-rw-r--r--tensorflow/contrib/kfac/examples/tests/convnet_test.py166
-rw-r--r--tensorflow/contrib/kfac/examples/tests/mlp_test.py63
-rw-r--r--tensorflow/contrib/kfac/examples/tests/mnist_test.py72
-rw-r--r--tensorflow/contrib/kfac/g3doc/autoencoder.pngbin54204 -> 0 bytes
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/BUILD160
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py310
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py1018
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py955
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py597
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py190
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py50
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py219
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/utils_test.py410
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD263
-rw-r--r--tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py183
-rw-r--r--tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py30
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator.py516
-rw-r--r--tensorflow/contrib/kfac/python/ops/estimator_lib.py31
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks.py1752
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py45
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py1830
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py38
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection.py1269
-rw-r--r--tensorflow/contrib/kfac/python/ops/layer_collection_lib.py46
-rw-r--r--tensorflow/contrib/kfac/python/ops/linear_operator.py95
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions.py754
-rw-r--r--tensorflow/contrib/kfac/python/ops/loss_functions_lib.py39
-rw-r--r--tensorflow/contrib/kfac/python/ops/op_queue.py69
-rw-r--r--tensorflow/contrib/kfac/python/ops/op_queue_lib.py30
-rw-r--r--tensorflow/contrib/kfac/python/ops/optimizer.py727
-rw-r--r--tensorflow/contrib/kfac/python/ops/placement.py114
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils.py709
-rw-r--r--tensorflow/contrib/kfac/python/ops/utils_lib.py50
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops_test.py54
-rw-r--r--tensorflow/contrib/layers/python/layers/encoders_test.py20
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_ops_test.py206
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column_test.py26
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/layers_test.py318
-rw-r--r--tensorflow/contrib/layers/python/layers/normalization_test.py8
-rw-r--r--tensorflow/contrib/layers/python/layers/optimizers_test.py14
-rw-r--r--tensorflow/contrib/layers/python/layers/regularizers_test.py14
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib.py127
-rw-r--r--tensorflow/contrib/layers/python/layers/rev_block_lib_test.py20
-rw-r--r--tensorflow/contrib/layers/python/layers/summaries_test.py12
-rw-r--r--tensorflow/contrib/layers/python/layers/utils_test.py24
-rw-r--r--tensorflow/contrib/layers/python/ops/sparse_ops_test.py46
-rw-r--r--tensorflow/contrib/learn/BUILD1
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py4
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py18
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/ops_test.py6
-rw-r--r--tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py6
-rw-r--r--tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py4
-rw-r--r--tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md40
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py51
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py14
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py6
-rw-r--r--tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py4
-rw-r--r--tensorflow/contrib/lite/BUILD46
-rw-r--r--tensorflow/contrib/lite/RELEASE.md8
-rw-r--r--tensorflow/contrib/lite/allocation.cc4
-rw-r--r--tensorflow/contrib/lite/allocation.h4
-rw-r--r--tensorflow/contrib/lite/arena_planner.h6
-rw-r--r--tensorflow/contrib/lite/build_def.bzl23
-rw-r--r--tensorflow/contrib/lite/builtin_op_data.h277
-rw-r--r--tensorflow/contrib/lite/builtin_ops.h2
-rw-r--r--tensorflow/contrib/lite/c/BUILD39
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h298
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data_test.cc83
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.c (renamed from tensorflow/contrib/lite/context.c)6
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal.h491
-rw-r--r--tensorflow/contrib/lite/c/c_api_internal_test.cc (renamed from tensorflow/contrib/lite/context_test.cc)10
-rw-r--r--tensorflow/contrib/lite/context.h474
-rw-r--r--tensorflow/contrib/lite/context_util.h2
-rw-r--r--tensorflow/contrib/lite/core/api/BUILD57
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter.cc38
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter.h45
-rw-r--r--tensorflow/contrib/lite/core/api/error_reporter_test.cc49
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc622
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.h48
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc104
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver.cc60
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver.h47
-rw-r--r--tensorflow/contrib/lite/core/api/op_resolver_test.cc197
-rw-r--r--tensorflow/contrib/lite/delegates/eager/BUILD7
-rw-r--r--tensorflow/contrib/lite/delegates/eager/buffer_map.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/delegate_test.cc20
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.cc15
-rw-r--r--tensorflow/contrib/lite/delegates/eager/kernel.h2
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.cc45
-rw-r--r--tensorflow/contrib/lite/delegates/eager/test_util.h28
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.cc36
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util.h15
-rw-r--r--tensorflow/contrib/lite/delegates/eager/util_test.cc38
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/BUILD2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc102
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h2
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc83
-rw-r--r--tensorflow/contrib/lite/error_reporter.h38
-rw-r--r--tensorflow/contrib/lite/examples/android/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/examples/ios/camera/Podfile2
-rw-r--r--tensorflow/contrib/lite/examples/ios/simple/Podfile2
-rw-r--r--tensorflow/contrib/lite/experimental/c/BUILD1
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.cc12
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api.h13
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_internal.h6
-rw-r--r--tensorflow/contrib/lite/experimental/c/c_api_test.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/BUILD3
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h18
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc4
-rw-r--r--tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc15
-rw-r--r--tensorflow/contrib/lite/experimental/writer/BUILD66
-rw-r--r--tensorflow/contrib/lite/experimental/writer/enum_mapping.h116
-rw-r--r--tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc370
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer.cc41
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.cc281
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib.h126
-rw-r--r--tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc62
-rw-r--r--tensorflow/contrib/lite/g3doc/README.md4
-rw-r--r--tensorflow/contrib/lite/g3doc/api_docs/python/index.md10
-rw-r--r--tensorflow/contrib/lite/g3doc/apis.md45
-rw-r--r--tensorflow/contrib/lite/g3doc/custom_operators.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/demo_android.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/demo_ios.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/devguide.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/ios.md8
-rw-r--r--tensorflow/contrib/lite/g3doc/models.md100
-rw-r--r--tensorflow/contrib/lite/g3doc/ops_versioning.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/overview.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/rpi.md36
-rw-r--r--tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md27
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md2
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md2
-rw-r--r--tensorflow/contrib/lite/graph_info.h2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc8
-rw-r--r--tensorflow/contrib/lite/interpreter.h10
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc2
-rw-r--r--tensorflow/contrib/lite/java/demo/README.md6
-rw-r--r--tensorflow/contrib/lite/java/demo/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD3
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/build.gradle6
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h8
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h2
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD125
-rw-r--r--tensorflow/contrib/lite/kernels/activation_functor.h2
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/activations_test.cc70
-rw-r--r--tensorflow/contrib/lite/kernels/add.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/arg_min_max.cc15
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn.cc33
-rw-r--r--tensorflow/contrib/lite/kernels/basic_rnn_test.cc21
-rw-r--r--tensorflow/contrib/lite/kernels/batch_to_space_nd.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc996
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc430
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc467
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc40
-rw-r--r--tensorflow/contrib/lite/kernels/cast.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/concatenation.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc162
-rw-r--r--tensorflow/contrib/lite/kernels/conv_test.cc284
-rw-r--r--tensorflow/contrib/lite/kernels/depthwise_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/dequantize.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/div.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.h4
-rw-r--r--tensorflow/contrib/lite/kernels/elementwise.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/exp.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/expand_dims_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/fake_quant.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/floor.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/floor_div.cc146
-rw-r--r--tensorflow/contrib/lite/kernels/floor_div_test.cc90
-rw-r--r--tensorflow/contrib/lite/kernels/fully_connected.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/gather.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/gather_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h2
-rw-r--r--tensorflow/contrib/lite/kernels/hashtable_lookup.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD48
-rw-r--r--tensorflow/contrib/lite/kernels/internal/common.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc683
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h91
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h331
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc31
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h21
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h822
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h18
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.cc210
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util.h38
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc133
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h426
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc48
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h36
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h1871
-rw-r--r--tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h92
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor.h111
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h135
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils.h17
-rw-r--r--tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc206
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h57
-rw-r--r--tensorflow/contrib/lite/kernels/kernel_util.h10
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm.cc1316
-rw-r--r--tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc664
-rw-r--r--tensorflow/contrib/lite/kernels/local_response_norm.cc18
-rw-r--r--tensorflow/contrib/lite/kernels/logical.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/lsh_projection.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc80
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_test.cc49
-rw-r--r--tensorflow/contrib/lite/kernels/maximum_minimum.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/mfcc_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/mul.cc79
-rw-r--r--tensorflow/contrib/lite/kernels/neg.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/one_hot.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/optional_tensor_test.cc24
-rw-r--r--tensorflow/contrib/lite/kernels/pack.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pad.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/pad_test.cc13
-rw-r--r--tensorflow/contrib/lite/kernels/padding.h2
-rw-r--r--tensorflow/contrib/lite/kernels/pooling.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/pow.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/reduce.cc344
-rw-r--r--tensorflow/contrib/lite/kernels/reduce_test.cc120
-rw-r--r--tensorflow/contrib/lite/kernels/register.cc10
-rw-r--r--tensorflow/contrib/lite/kernels/register.h3
-rw-r--r--tensorflow/contrib/lite/kernels/relu1.cc59
-rw-r--r--tensorflow/contrib/lite/kernels/relu1_test.cc79
-rw-r--r--tensorflow/contrib/lite/kernels/reshape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/resize_bilinear.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/select.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/shape.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/skip_gram.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/slice.cc30
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_batch_nd.cc18
-rw-r--r--tensorflow/contrib/lite/kernels/space_to_depth.cc14
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_to_dense.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/split.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/squeeze.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/sub.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/svdf.cc61
-rw-r--r--tensorflow/contrib/lite/kernels/svdf_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/tile.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/tile_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/topk_v2_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/transpose.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_conv.cc4
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc112
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc73
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc35
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc20
-rw-r--r--tensorflow/contrib/lite/kernels/unpack.cc130
-rw-r--r--tensorflow/contrib/lite/kernels/unpack_test.cc225
-rwxr-xr-xtensorflow/contrib/lite/lib_package/create_ios_frameworks.sh7
-rw-r--r--tensorflow/contrib/lite/memory_planner.h2
-rw-r--r--tensorflow/contrib/lite/mmap_allocation.cc2
-rw-r--r--tensorflow/contrib/lite/model.cc630
-rw-r--r--tensorflow/contrib/lite/model.h5
-rw-r--r--tensorflow/contrib/lite/model_test.cc2
-rw-r--r--tensorflow/contrib/lite/models/speech_test.cc14
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.cc (renamed from tensorflow/contrib/lite/op_resolver.cc)3
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver.h79
-rw-r--r--tensorflow/contrib/lite/mutable_op_resolver_test.cc (renamed from tensorflow/contrib/lite/op_resolver_test.cc)2
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc78
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.h4
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate_disabled.cc8
-rw-r--r--tensorflow/contrib/lite/op_resolver.h78
-rw-r--r--tensorflow/contrib/lite/python/BUILD3
-rw-r--r--tensorflow/contrib/lite/python/convert.py68
-rw-r--r--tensorflow/contrib/lite/python/convert_test.py101
-rw-r--r--tensorflow/contrib/lite/python/lite.py200
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py321
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py52
-rw-r--r--tensorflow/contrib/lite/schema/BUILD16
-rw-r--r--tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc2
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs10
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h127
-rw-r--r--tensorflow/contrib/lite/simple_memory_arena.h2
-rw-r--r--tensorflow/contrib/lite/stderr_reporter.cc (renamed from tensorflow/contrib/lite/error_reporter.cc)22
-rw-r--r--tensorflow/contrib/lite/stderr_reporter.h34
-rw-r--r--tensorflow/contrib/lite/string_util.cc2
-rw-r--r--tensorflow/contrib/lite/string_util.h2
-rw-r--r--tensorflow/contrib/lite/string_util_test.cc2
-rw-r--r--tensorflow/contrib/lite/testing/BUILD3
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py93
-rw-r--r--tensorflow/contrib/lite/testing/generated_examples_zip_test.cc15
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc22
-rw-r--r--tensorflow/contrib/lite/testing/util.h2
-rw-r--r--tensorflow/contrib/lite/toco/BUILD2
-rw-r--r--tensorflow/contrib/lite/toco/args.h7
-rw-r--r--tensorflow/contrib/lite/toco/export_tensorflow.cc45
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md22
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md24
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/python_api.md1
-rw-r--r--tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h12
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc29
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc9
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc151
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc3
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc106
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc19
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc2
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc167
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc51
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.h5
-rw-r--r--tensorflow/contrib/lite/toco/model.h20
-rw-r--r--tensorflow/contrib/lite/toco/python/toco_from_protos_test.py2
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc80
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h54
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc63
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc88
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h8
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator_test.cc38
-rw-r--r--tensorflow/contrib/lite/toco/toco_cmdline_flags.cc43
-rw-r--r--tensorflow/contrib/lite/toco/toco_flags.proto21
-rw-r--r--tensorflow/contrib/lite/toco/toco_tooling.cc28
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc1
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.h5
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/BUILD328
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/README.md38
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h49
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc27
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/csv_writer.h79
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc39
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h87
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc100
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h99
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc229
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc133
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc29
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h37
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc110
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD182
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md146
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt1762
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py105
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc165
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc351
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h124
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc114
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h83
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc151
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc80
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h75
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc123
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpgbin0 -> 73746 bytes
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc158
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc200
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc45
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h53
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/stage.h56
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils.cc102
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils.h46
-rw-r--r--tensorflow/contrib/lite/tools/accuracy/utils_test.cc76
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/README.md22
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/ios/README.md4
-rw-r--r--tensorflow/contrib/lite/tools/make/Makefile108
-rw-r--r--tensorflow/contrib/lite/tools/optimize/BUILD25
-rw-r--r--tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md70
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.cc432
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights.h57
-rw-r--r--tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc226
-rw-r--r--tensorflow/contrib/lite/tutorials/BUILD20
-rw-r--r--tensorflow/contrib/lite/tutorials/dataset.py122
-rw-r--r--tensorflow/contrib/lite/tutorials/mnist_tflite.py87
-rw-r--r--tensorflow/contrib/lite/util.h2
-rw-r--r--tensorflow/contrib/lite/util_test.cc2
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py206
-rw-r--r--tensorflow/contrib/losses/python/losses/loss_ops_test.py214
-rw-r--r--tensorflow/contrib/makefile/proto_text_cc_files.txt113
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_cc_files.txt74
-rw-r--r--tensorflow/contrib/makefile/proto_text_pb_h_files.txt73
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt523
-rw-r--r--tensorflow/contrib/makefile/tf_pb_text_files.txt56
-rw-r--r--tensorflow/contrib/makefile/tf_proto_files.txt76
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py51
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py2
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops_test.py508
-rw-r--r--tensorflow/contrib/model_pruning/BUILD2
-rw-r--r--tensorflow/contrib/model_pruning/README.md2
-rw-r--r--tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py4
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils.py9
-rw-r--r--tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py2
-rw-r--r--tensorflow/contrib/opt/BUILD18
-rw-r--r--tensorflow/contrib/opt/__init__.py2
-rw-r--r--tensorflow/contrib/opt/python/training/adamax_test.py10
-rw-r--r--tensorflow/contrib/opt/python/training/elastic_average_optimizer.py14
-rw-r--r--tensorflow/contrib/opt/python/training/external_optimizer_test.py18
-rw-r--r--tensorflow/contrib/opt/python/training/ggt_test.py2
-rw-r--r--tensorflow/contrib/opt/python/training/lars_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py34
-rw-r--r--tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py247
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions.py155
-rw-r--r--tensorflow/contrib/opt/python/training/matrix_functions_test.py63
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer.py8
-rw-r--r--tensorflow/contrib/opt/python/training/model_average_optimizer_test.py40
-rw-r--r--tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/nadam_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py22
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo.py98
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py214
-rw-r--r--tensorflow/contrib/opt/python/training/sign_decay_test.py6
-rw-r--r--tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py4
-rw-r--r--tensorflow/contrib/opt/python/training/weight_decay_optimizers.py77
-rw-r--r--tensorflow/contrib/optimizer_v2/adam.py9
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py6
-rw-r--r--tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py14
-rw-r--r--tensorflow/contrib/quantize/BUILD3
-rw-r--r--tensorflow/contrib/quantize/python/common.py26
-rw-r--r--tensorflow/contrib/quantize/python/common_test.py25
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py25
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py5
-rw-r--r--tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py4
-rw-r--r--tensorflow/contrib/recurrent/python/ops/functional_rnn.py8
-rw-r--r--tensorflow/contrib/rnn/BUILD8
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py73
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py4
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py56
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py20
-rw-r--r--tensorflow/contrib/saved_model/BUILD17
-rw-r--r--tensorflow/contrib/saved_model/__init__.py7
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/BUILD2
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc81
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h3
-rw-r--r--tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc123
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py260
-rw-r--r--tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py303
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.cc11
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py11
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py316
-rw-r--r--tensorflow/contrib/specs/python/specs_test.py22
-rw-r--r--tensorflow/contrib/specs/python/summaries_test.py8
-rw-r--r--tensorflow/contrib/tensor_forest/BUILD8
-rw-r--r--tensorflow/contrib/tensor_forest/client/random_forest.py6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc6
-rw-r--r--tensorflow/contrib/tensor_forest/kernels/v4/input_data.h3
-rw-r--r--tensorflow/contrib/tensorrt/BUILD1
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc1
-rw-r--r--tensorflow/contrib/timeseries/examples/BUILD1
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly.py8
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly_test.py4
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py6
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py23
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py2
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/state_management_test.py6
-rw-r--r--tensorflow/contrib/tpu/BUILD2
-rw-r--r--tensorflow/contrib/tpu/__init__.py1
-rw-r--r--tensorflow/contrib/tpu/ops/cross_replica_ops.cc103
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc3
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto5
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto8
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py79
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py937
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py287
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py170
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py9
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py27
-rw-r--r--tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py4
-rw-r--r--tensorflow/contrib/training/python/training/bucket_ops_test.py10
-rw-r--r--tensorflow/contrib/training/python/training/evaluation.py9
-rw-r--r--tensorflow/contrib/training/python/training/evaluation_test.py33
-rw-r--r--tensorflow/contrib/training/python/training/resample_test.py8
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops_test.py18
-rw-r--r--tensorflow/contrib/training/python/training/sampling_ops_threading_test.py2
-rw-r--r--tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py18
-rw-r--r--tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py8
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py18
-rw-r--r--tensorflow/contrib/training/python/training/training_test.py14
-rw-r--r--tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc2
758 files changed, 37533 insertions, 26495 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 222e66cebe..798f499870 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -20,13 +20,7 @@ py_library(
),
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
- deps = if_not_windows([
- # TODO(aaroey): tensorrt dependency has to appear before tflite so the
- # build can resolve its flatbuffers symbols within the tensorrt library.
- # This is an issue with the tensorrt static library and will be fixed by
- # the next tensorrt release, so fix the order here after that.
- "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
- ]) + [
+ deps = [
"//tensorflow/contrib/all_reduce",
"//tensorflow/contrib/batching:batch_py",
"//tensorflow/contrib/bayesflow:bayesflow_py",
@@ -61,7 +55,6 @@ py_library(
"//tensorflow/contrib/integrate:integrate_py",
"//tensorflow/contrib/keras",
"//tensorflow/contrib/kernel_methods",
- "//tensorflow/contrib/kfac",
"//tensorflow/contrib/labeled_tensor",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
@@ -136,6 +129,7 @@ py_library(
]) + if_not_windows([
"//tensorflow/contrib/bigtable", # depends on bigtable
"//tensorflow/contrib/cloud:cloud_py", # doesn't compile on Windows
+ "//tensorflow/contrib/tensorrt:init_py", # doesn't compile on windows
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
]),
)
@@ -182,6 +176,7 @@ cc_library(
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
"//tensorflow/contrib/data:dataset_ops_op_lib",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 45a7680160..9478e42b46 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -21,6 +21,14 @@ from __future__ import print_function
import os
+from tensorflow.python.tools import component_api_helper
+component_api_helper.package_hook(
+ parent_package_str=(
+ "tensorflow.contrib"),
+ child_package_str=(
+ "tensorflow_estimator.contrib.estimator"))
+del component_api_helper
+
# Add projects here, they will show up under tf.contrib.
from tensorflow.contrib import autograph
from tensorflow.contrib import batching
@@ -51,7 +59,6 @@ from tensorflow.contrib import input_pipeline
from tensorflow.contrib import integrate
from tensorflow.contrib import keras
from tensorflow.contrib import kernel_methods
-from tensorflow.contrib import kfac
from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
diff --git a/tensorflow/contrib/android/asset_manager_filesystem.cc b/tensorflow/contrib/android/asset_manager_filesystem.cc
index 513d519eab..d14b2126a0 100644
--- a/tensorflow/contrib/android/asset_manager_filesystem.cc
+++ b/tensorflow/contrib/android/asset_manager_filesystem.cc
@@ -28,7 +28,7 @@ string RemoveSuffix(const string& name, const string& suffix) {
string output(name);
StringPiece piece(output);
str_util::ConsumeSuffix(&piece, suffix);
- return piece.ToString();
+ return string(piece);
}
// Closes the given AAsset when variable is destructed.
@@ -231,7 +231,7 @@ string AssetManagerFileSystem::NormalizeDirectoryPath(const string& fname) {
string AssetManagerFileSystem::RemoveAssetPrefix(const string& name) {
StringPiece piece(name);
str_util::ConsumePrefix(&piece, prefix_);
- return piece.ToString();
+ return string(piece);
}
bool AssetManagerFileSystem::DirectoryExists(const std::string& fname) {
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions.py b/tensorflow/contrib/autograph/converters/builtin_functions.py
index b26c52294c..29dce13999 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions.py
@@ -21,6 +21,8 @@ from __future__ import print_function
import gast
from tensorflow.contrib.autograph.core import converter
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import templates
@@ -31,41 +33,32 @@ class BuiltinFunctionTransformer(converter.Base):
TF equivalent, like `len`.
"""
- def _convert_builtin(self, node):
+ def _convert_builtin(self, f, args, as_expression):
template = """
- ag__.utils.dynamic_builtin(func, args)
+ ag__.func(args)
"""
- return templates.replace(template, func=node.func, args=node.args)[0].value
-
- def _convert_print(self, node):
- template = """
- ag__.utils.dynamic_print(args)
- """
- return templates.replace(template, args=node.args)[0].value
+ if as_expression:
+ return templates.replace_as_expression(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
+ else:
+ return templates.replace(
+ template, func=py_builtins.overload_of(f).__name__, args=args)
def visit_Call(self, node):
- self.generic_visit(node)
- # TODO(mdan): This won't work if the function was hidden.
- # TODO(mdan): Rely on the live_val and use inspect_utils.is_builtin instead.
- if (isinstance(node.func, gast.Name) and
- node.func.id in ('len', 'range', 'xrange', 'float', 'int')):
- return self._convert_builtin(node)
- # Print needs to be handled separately because it can be read as statement.
- if isinstance(node.func, gast.Name) and node.func.id == 'print':
- return self._convert_print(node)
+ node = self.generic_visit(node)
+ if anno.hasanno(node.func, 'live_val'):
+ live_val = anno.getanno(node.func, 'live_val')
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
return node
def visit_Print(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
args = node.values
# Following is the case when calling print(a, b)
if len(args) == 1 and isinstance(args[0], gast.Tuple):
args = args[0].elts
- template = """
- fname(args)
- """
- function_call = templates.replace(template, fname='print', args=args)[0]
- return self.visit(function_call)
+ return self._convert_builtin(print, args, as_expression=False)
def transform(node, ctx):
diff --git a/tensorflow/contrib/autograph/converters/builtin_functions_test.py b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
index d0a0cbbeb6..3e3a04f38b 100644
--- a/tensorflow/contrib/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/contrib/autograph/converters/builtin_functions_test.py
@@ -23,6 +23,7 @@ import six
from tensorflow.contrib.autograph.converters import builtin_functions
from tensorflow.contrib.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -34,11 +35,11 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
def test_fn(a):
return len(a)
- with self.converted(test_fn, builtin_functions, {'len': len},
- array_ops.shape) as result:
+ with self.converted(test_fn, builtin_functions, {'len': len}) as result:
with self.cached_session() as sess:
- ops = result.test_fn(constant_op.constant([0, 0, 0]))
- self.assertEqual(sess.run(ops), 3)
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ ops = result.test_fn(p)
+ self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
def test_print(self):
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 5a5a6ad63a..3530fbb2ec 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -95,6 +95,18 @@ class ControlFlowTransformer(converter.Base):
return 'no variables'
return ', '.join(map(str, symbol_set))
+ def _validate_no_live_vars_created(self, node):
+ body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
+ live_vars_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
+ live_vars_created_in_body = live_vars_out & body_scope.created
+ if live_vars_created_in_body:
+ raise ValueError(
+ 'The following variables are created inside the loop and used later:'
+ '\n%s\n'
+ 'Variables must be declared outside loops because loops may not'
+ ' necessarily execute.' % self._fmt_symbol_list(
+ live_vars_created_in_body))
+
def visit_If(self, node):
node = self.generic_visit(node)
@@ -197,13 +209,15 @@ class ControlFlowTransformer(converter.Base):
def visit_While(self, node):
self.generic_visit(node)
+ self._validate_no_live_vars_created(node)
+
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
cond_closure = set()
- for s in cond_scope.referenced:
+ for s in cond_scope.used:
for root in s.support_set:
if root not in body_scope.created:
cond_closure.add(root)
@@ -236,6 +250,7 @@ class ControlFlowTransformer(converter.Base):
node_body = ast_util.rename_symbols(node.body, ssf_map)
test = ast_util.rename_symbols(node.test, ssf_map)
+ # TODO(b/113118541) investigate the need-for and correctness-of extra_deps
template = """
def test_name(state_ssf):
return test
@@ -262,6 +277,8 @@ class ControlFlowTransformer(converter.Base):
def visit_For(self, node):
self.generic_visit(node)
+ self._validate_no_live_vars_created(node)
+
body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
body_closure = body_scope.modified - body_scope.created
all_referenced = body_scope.referenced
@@ -294,7 +311,9 @@ class ControlFlowTransformer(converter.Base):
template = """
def extra_test_name(state_ssf):
return extra_test_expr
- def body_name(iterate, state_ssf):
+ def body_name(loop_vars, state_ssf):
+ # Workaround for PEP-3113
+ iterate = loop_vars
body
return state_ssf,
state_ast_tuple = ag__.for_stmt(
diff --git a/tensorflow/contrib/autograph/converters/control_flow_test.py b/tensorflow/contrib/autograph/converters/control_flow_test.py
index 6cb907f69a..1d04ba3ba6 100644
--- a/tensorflow/contrib/autograph/converters/control_flow_test.py
+++ b/tensorflow/contrib/autograph/converters/control_flow_test.py
@@ -48,6 +48,24 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertTransformedResult(test_fn, constant_op.constant(5), (10, 5, 5))
+ def test_while_nested(self):
+
+ def test_fn(n):
+ i = 0
+ j = 0
+ s = 0
+ while i < n:
+ while j < i:
+ j += 3
+ u = i + j # 'u' is not defined within the inner loop
+ s += u
+ i += 1
+ j = 0
+ return s, i, j, n
+
+ self.assertTransformedResult(test_fn, constant_op.constant(5),
+ (25, 5, 0, 5))
+
def test_while_single_output(self):
def test_fn(n):
@@ -57,6 +75,17 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertTransformedResult(test_fn, constant_op.constant(5), 0)
+ def test_while_variable_defined_in_body(self):
+ def bad_while_loop(n):
+ while n > 0:
+ n -= 1
+ s = n
+ return s
+
+ node, ctx = self.prepare(bad_while_loop, {})
+ with self.assertRaises(transformer.AutographParseError):
+ control_flow.transform(node, ctx)
+
def test_if_basic(self):
def test_fn(n):
@@ -196,6 +225,23 @@ class ControlFlowTest(converter_testing.TestCase):
self.assertEqual(result.test_fn(5), 10)
self.assertEqual(eval_count[0], 1)
+ def test_for_variable_defined_in_body(self):
+ def bad_for_loop(n):
+ for i in range(n):
+ s = i
+ return s
+
+ node, ctx = self.prepare(bad_for_loop, {})
+ with self.assertRaises(transformer.AutographParseError):
+ control_flow.transform(node, ctx)
+
+ def test_for_tuple_unpacking(self):
+ def test_fn(x_list):
+ z = tf.constant(0) # pylint:disable=undefined-variable
+ for i, x in enumerate(x_list):
+ z = z + x + i
+ return z
+ self.assertTransformedResult(test_fn, [3, 3], 7)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py
index 16eb1f0e3f..41c3424fa3 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions.py
@@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'autograph_utils.dynamic_is',
- gast.IsNot: 'autograph_utils.dynamic_is_not'
+ gast.Is: 'ag__.utils.dynamic_is',
+ gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
index 8f9eee7081..409a73afba 100644
--- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py
@@ -47,6 +47,15 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
+ def test_ag_utils_lookup(self):
+ def test_fn(a, b):
+ return a is b or a is not b
+
+ with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
+ ) as result:
+ with self.cached_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(True, False)))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
index bcbb920cc5..c2427f5f4f 100644
--- a/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
+++ b/tensorflow/contrib/autograph/docs/pyfunc_dtypes.md
@@ -4,7 +4,7 @@ The `py_func` op requires specifying a
[data type](https://www.tensorflow.org/guide/tensors#data_types).
When wrapping a function with `py_func`, for instance using
-`@autograph.do_not_convert(run_mode=autograph.RunMode.PY_FUNC)`, you have two
+`@autograph.do_not_convert(run_as=autograph.RunMode.PY_FUNC)`, you have two
options to specify the returned data type:
* explicitly, with a specified `tf.DType` value
diff --git a/tensorflow/contrib/autograph/examples/integration_tests/BUILD b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
index 6c281485b4..3630b41fc8 100644
--- a/tensorflow/contrib/autograph/examples/integration_tests/BUILD
+++ b/tensorflow/contrib/autograph/examples/integration_tests/BUILD
@@ -23,7 +23,6 @@ py_test(
],
srcs_version = "PY2AND3",
tags = ["no_windows"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow:tensorflow_py",
],
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 276a387180..8b38d5d080 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -29,9 +29,9 @@ import six
from tensorflow.contrib.autograph.core import config
from tensorflow.contrib.autograph.core import converter
from tensorflow.contrib.autograph.impl import conversion
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.contrib.autograph.pyct import compiler
from tensorflow.contrib.autograph.pyct import inspect_utils
-from tensorflow.contrib.autograph.utils import builtins
from tensorflow.contrib.autograph.utils import py_func
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator
@@ -150,7 +150,7 @@ def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
unknown_arg_value = object() # Sentinel for arguments of unknown value
if inspect_utils.isbuiltin(f):
- return builtins.dynamic_builtin(f, *args, **kwargs)
+ return py_builtins.overload_of(f)(*args, **kwargs)
if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
# Regular functions
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 803fde9089..a4c6fed265 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -38,9 +38,6 @@ class ApiTest(test.TestCase):
def setUp(self):
config.COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
- 'from tensorflow.contrib.autograph import utils'
- ' as autograph_utils',
- 'tf = autograph_utils.fake_tf()',
)
def test_decorator_recurses(self):
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 332d5dab19..29759bad79 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -22,6 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
+ "py_builtins.py",
"slices.py",
],
srcs_version = "PY2AND3",
@@ -62,6 +63,16 @@ py_test(
)
py_test(
+ name = "py_builtins_test",
+ srcs = ["py_builtins_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "slices_test",
srcs = ["slices_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 392cb60bcc..c4fbc260a2 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -45,6 +45,11 @@ from tensorflow.contrib.autograph.operators.data_structures import list_stack
from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.py_builtins import float_
+from tensorflow.contrib.autograph.operators.py_builtins import int_
+from tensorflow.contrib.autograph.operators.py_builtins import len_
+from tensorflow.contrib.autograph.operators.py_builtins import print_
+from tensorflow.contrib.autograph.operators.py_builtins import range_
from tensorflow.contrib.autograph.operators.slices import get_item
from tensorflow.contrib.autograph.operators.slices import GetItemOpts
from tensorflow.contrib.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index 9909e52164..9a66a6bb60 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.contrib.autograph.operators import py_builtins
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@@ -82,8 +82,8 @@ def _py_for_stmt(iter_, extra_test, body, init_state):
def _known_len_for_stmt(iter_, extra_test, body, init_state):
- """Overload of for_stmt that iterates over objects that define a length."""
- n = builtins.dynamic_len(iter_)
+ """Overload of for_stmt that iterates over objects that admit a length."""
+ n = py_builtins.len_(iter_)
def while_body(iterate_index, *state):
iterate = iter_[iterate_index]
diff --git a/tensorflow/contrib/autograph/operators/py_builtins.py b/tensorflow/contrib/autograph/operators/py_builtins.py
new file mode 100644
index 0000000000..c5730934e7
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins.py
@@ -0,0 +1,225 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators corresponding to Python builtin functions.
+
+List of built-in functions: https://docs.python.org/3/library/functions.html
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six
+
+from tensorflow.contrib.autograph.utils import py_func
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_parsing_ops
+from tensorflow.python.ops import gen_string_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import math_ops
+
+
+UNDEFINED = object()
+
+
+def overload_of(f):
+ if f in SUPPORTED_BUILTINS:
+ return BUILTIN_FUINCTIONS_MAP[f.__name__]
+ return f
+
+
+def abs_(x):
+ if tensor_util.is_tensor(x):
+ return _tf_abs(x)
+ return _py_abs(x)
+
+
+def _tf_abs(x):
+ return math_ops.abs(x)
+
+
+def _py_abs(x):
+ return abs(x)
+
+
+def float_(x=0):
+ if tensor_util.is_tensor(x):
+ return _tf_float(x)
+ return _py_float(x)
+
+
+def _tf_float(x):
+ # TODO(mdan): We shouldn't assume float32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.float32)
+ return math_ops.cast(x, dtype=dtypes.float32)
+
+
+def _py_float(x):
+ return float(x)
+
+
+def int_(x=0, base=UNDEFINED):
+ if tensor_util.is_tensor(x):
+ return _tf_int(x, base)
+ return _py_int(x, base)
+
+
+def _tf_int(x, base):
+ if base not in (10, UNDEFINED):
+ raise NotImplementedError('base {} not supported for int'.format(base))
+
+ # TODO(mdan): We shouldn't assume int32.
+ if x.dtype == dtypes.string:
+ return gen_parsing_ops.string_to_number(x, out_type=dtypes.int32)
+ return math_ops.cast(x, dtype=dtypes.int32)
+
+
+def _py_int(x, base):
+ if base is UNDEFINED:
+ return int(x)
+ return int(x, base)
+
+
+def len_(s):
+ if tensors.is_tensor_array(s):
+ return _tf_tensor_array_len(s)
+ elif tensors.is_tensor_list(s):
+ return _tf_tensor_list_len(s)
+ elif tensor_util.is_tensor(s):
+ return _tf_tensor_len(s)
+ return _py_len(s)
+
+
+def _tf_tensor_array_len(s):
+ return s.size()
+
+
+def _tf_tensor_list_len(s):
+ return list_ops.tensor_list_length(s)
+
+
+def _tf_tensor_len(s):
+ """Overload of len_ for Tensor arguments."""
+ # Statically shaped tensors: length is known ahead of time.
+ if s.shape.ndims and s.shape[0].value is not None:
+ return s.shape[0].value
+
+ # Static shape of unknown dimensions: use dynamic shape but statically
+ # chech that it's a scalar.
+ shape = array_ops.shape(s)
+
+ assert shape.shape, 'shape tensor of zero size? {}'.format(shape)
+
+ if shape.shape[0] == 0:
+ raise ValueError(
+ 'len requires a non-scalar tensor, got one of shape {}'.format(shape))
+
+ if shape.shape[0].value is not None:
+ return array_ops.shape(s)[0]
+
+ # Fully dynamic shape: use ops.
+ rank = array_ops.rank(s)
+
+ def raise_zero_rank_error():
+ msg = gen_string_ops.string_join(
+ ['len requires non-zero rank, got ',
+ gen_string_ops.as_string(rank)])
+ with ops.control_dependencies([control_flow_ops.Assert(False, [msg])]):
+ return constant_op.constant(0, dtype=dtypes.int32)
+
+ return control_flow_ops.cond(rank > 0, lambda: array_ops.shape(s)[0],
+ raise_zero_rank_error)
+
+
+def _py_len(s):
+ return len(s)
+
+
+def print_(*objects, **kwargs):
+ # Note: Python 2.6 doesn't support explicit keywords after starargs.
+ unknown_kwargs = tuple(
+ set(kwargs.keys()) - set(('sep', 'end', 'file', 'flush')))
+ if unknown_kwargs:
+ raise ValueError('invalid keyword arguments: {}'.format(unknown_kwargs))
+
+ # TODO(mdan): use logging_ops.Print when py_func is not supported.
+ return _tf_py_func_print(objects, kwargs)
+
+
+def _tf_py_func_print(objects, kwargs):
+ """Overload of print_ as a py_func implementation."""
+ override_kwargs = {k: v for k, v in kwargs.items() if v is not UNDEFINED}
+ if 'flush' not in override_kwargs:
+ # Defaulting to flushing the console in graph mode, which helps reduce
+ # garbled output in IPython.
+ override_kwargs['flush'] = True
+
+ def print_wrapper(*vals):
+ if six.PY3:
+ # TensorFlow doesn't seem to generate Unicode when passing strings to
+ # py_func. This causes the print to add a "b'" wrapper to the output,
+ # which is probably never what you want.
+ vals = tuple(
+ v.decode('utf-8') if isinstance(v, bytes) else v for v in vals)
+ six.print_(*vals, **override_kwargs)
+
+ return py_func.wrap_py_func(
+ print_wrapper, None, objects, use_dummy_return=True)
+
+
+def range_(start_or_stop, stop=UNDEFINED, step=UNDEFINED):
+ if any(tensor_util.is_tensor(s) for s in (start_or_stop, stop, step)):
+ return _tf_range(start_or_stop, stop, step)
+ return _py_range(start_or_stop, stop, step)
+
+
+def _tf_range(start_or_stop, stop, step):
+ # TODO(mdan): We should optimize this when a full tensor is not required.
+ if step is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return math_ops.range(start_or_stop, stop)
+ return math_ops.range(start_or_stop)
+
+
+def _py_range(start_or_stop, stop, step):
+ if step is not UNDEFINED:
+ return range(start_or_stop, stop, step)
+ if stop is not UNDEFINED:
+ return range(start_or_stop, stop)
+ return range(start_or_stop)
+
+
+SUPPORTED_BUILTINS = set((abs, float, int, len, print, range))
+
+if six.PY2:
+ SUPPORTED_BUILTINS.add(xrange)
+
+BUILTIN_FUINCTIONS_MAP = {
+ 'abs': abs_,
+ 'float': float_,
+ 'int': int_,
+ 'len': len_,
+ 'print': print_,
+ 'range': range_,
+ 'xrange': range_,
+}
diff --git a/tensorflow/contrib/autograph/operators/py_builtins_test.py b/tensorflow/contrib/autograph/operators/py_builtins_test.py
new file mode 100644
index 0000000000..4073c51785
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/py_builtins_test.py
@@ -0,0 +1,131 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for py_builtins module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.contrib.autograph.operators import py_builtins
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class PyBuiltinsTest(test.TestCase):
+
+ def test_abs(self):
+ self.assertEqual(py_builtins.abs_(-1), 1)
+ with self.test_session() as sess:
+ t = py_builtins.abs_(constant_op.constant(-1))
+ self.assertEqual(sess.run(t), 1)
+ t = py_builtins.abs_(constant_op.constant([-1, 2, -3]))
+ self.assertAllEqual(sess.run(t), [1, 2, 3])
+
+ def test_float(self):
+ self.assertEqual(py_builtins.float_(10), 10.0)
+ self.assertEqual(py_builtins.float_('10.0'), 10.0)
+ with self.test_session() as sess:
+ t = py_builtins.float_(constant_op.constant(1, dtype=dtypes.int64))
+ self.assertEqual(sess.run(t), 1.0)
+ st = py_builtins.float_(constant_op.constant('1.0'))
+ self.assertEqual(sess.run(st), 1.0)
+
+ def test_int(self):
+ self.assertEqual(py_builtins.int_(10.0), 10)
+ self.assertEqual(py_builtins.int_('11', 2), 3)
+ with self.test_session() as sess:
+ t = py_builtins.int_(constant_op.constant(1, dtype=dtypes.float64))
+ self.assertEqual(sess.run(t), 1)
+ st = py_builtins.int_(constant_op.constant('1'))
+ self.assertEqual(sess.run(st), 1)
+ st = py_builtins.int_(constant_op.constant('1'), 10)
+ self.assertEqual(sess.run(st), 1)
+
+ def test_int_unsupported_base(self):
+ t = constant_op.constant(1, dtype=dtypes.float64)
+ with self.assertRaises(NotImplementedError):
+ py_builtins.int_(t, 2)
+
+ def test_len(self):
+ self.assertEqual(py_builtins.len_([1, 2, 3]), 3)
+ with self.test_session() as sess:
+ t = py_builtins.len_(constant_op.constant([[1], [2], [3]]))
+ self.assertEqual(t, 3)
+ ta = py_builtins.len_(tensor_array_ops.TensorArray(dtypes.int32, size=5))
+ self.assertEqual(sess.run(ta), 5)
+ tl = py_builtins.len_(data_structures.tf_tensor_list_new([3, 4, 5]))
+ self.assertEqual(sess.run(tl), 3)
+
+ def test_len_scalar(self):
+ with self.assertRaises(ValueError):
+ py_builtins.len_(constant_op.constant(1))
+
+ def test_len_dynamic_shape(self):
+ with self.test_session() as sess:
+ p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
+ t = py_builtins.len_(p)
+ self.assertEqual(sess.run(t, {p: [1, 2, 3]}), 3)
+
+ with self.assertRaises(errors_impl.InvalidArgumentError):
+ t = py_builtins.len_(p)
+ sess.run(t, {p: 1})
+
+ def test_print_tensors(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(py_builtins.print_(constant_op.constant('test message'), 1))
+ self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_complex(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(
+ py_builtins.print_(constant_op.constant('test message'), [1, 2]))
+ self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_range(self):
+ self.assertListEqual(list(py_builtins.range_(3)), [0, 1, 2])
+ self.assertListEqual(list(py_builtins.range_(1, 3)), [1, 2])
+ self.assertListEqual(list(py_builtins.range_(2, 0, -1)), [2, 1])
+
+ def test_range_tensor(self):
+ with self.test_session() as sess:
+ r = py_builtins.range_(constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [0, 1, 2])
+ r = py_builtins.range_(1, constant_op.constant(3))
+ self.assertAllEqual(sess.run(r), [1, 2])
+ r = py_builtins.range_(2, 0, constant_op.constant(-1))
+ self.assertAllEqual(sess.run(r), [2, 1])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py
index 04fbeb2f6e..2b7f5ad922 100644
--- a/tensorflow/contrib/autograph/operators/slices.py
+++ b/tensorflow/contrib/autograph/operators/slices.py
@@ -22,6 +22,7 @@ import collections
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
@@ -57,6 +58,8 @@ def get_item(target, i, opts):
elif tensor_util.is_tensor(target):
if target.dtype == dtypes.variant:
return _tf_tensor_list_get_item(target, i, opts)
+ elif target.dtype == dtypes.string and target.shape.ndims == 0:
+ return _tf_tensor_string_get_item(target, i)
else:
return _tf_tensor_get_item(target, i)
else:
@@ -82,6 +85,12 @@ def _tf_tensor_get_item(target, i):
return target[i]
+def _tf_tensor_string_get_item(target, i):
+ """Overload of get_item that stages a Tensor string read."""
+ x = gen_string_ops.substr(target, i, 1)
+ return x
+
+
def _py_get_item(target, i):
"""Overload of get_item that executes a Python list modification."""
return target[i]
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py
index 56aafe07c8..5255b7e2b6 100644
--- a/tensorflow/contrib/autograph/operators/slices_test.py
+++ b/tensorflow/contrib/autograph/operators/slices_test.py
@@ -46,6 +46,21 @@ class SlicesTest(test.TestCase):
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4])
+ def test_get_item_tensor_string(self):
+ initial_str = constant_op.constant('abcd')
+ t = slices.get_item(initial_str, 1,
+ slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(t), b'b')
+
+ initial_list_str = constant_op.constant(['abcd', 'bcde'])
+ t = slices.get_item(initial_list_str, 1,
+ slices.GetItemOpts(element_dtype=initial_str.dtype))
+
+ with self.test_session() as sess:
+ self.assertEqual(sess.run(t), b'bcde')
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
index a0938b3e5f..fe630ef852 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/BUILD
@@ -22,9 +22,11 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/autograph/pyct",
"@gast_archive//:gast",
"@six_archive//:six",
+ # TODO(aqj) Revisit this dependency direction when pyct is more
+ # modularized
+ "//tensorflow/contrib/autograph/pyct",
],
)
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py b/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/__init__.py
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
index e42f679cfe..d77c15915b 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf.py
@@ -394,10 +394,16 @@ class AnfTransformer(transformer.Base):
# just recur.
def visit_List(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def visit_Tuple(self, node):
- return self._visit_strict_expression(node)
+ node = self.generic_visit(node)
+ if not isinstance(node.ctx, gast.Store):
+ self._ensure_fields_trivial(node)
+ return node
def transform(node, entity_info, gensym_source=None):
diff --git a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
index 951974820c..1ffd4bbe55 100644
--- a/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
+++ b/tensorflow/contrib/autograph/pyct/common_transformers/anf_test.py
@@ -165,6 +165,46 @@ class AnfTransformerTest(test.TestCase):
self.assert_body_anfs_as_expected(expected_result, test_function)
+ def test_nested_multi_value_assign(self):
+
+ def test_function(a, b, c):
+ x, y = a, a + b
+ (z, y), x = (c, y + b), x + a
+ return z, (y, x)
+
+ def expected_result(a, b, c):
+ tmp_1001 = a + b
+ x, y = a, tmp_1001
+ tmp_1002 = y + b
+ tmp_1003 = (c, tmp_1002)
+ tmp_1004 = x + a
+ (z, y), x = tmp_1003, tmp_1004
+ tmp_1005 = y, x
+ tmp_1006 = z, tmp_1005
+ return tmp_1006
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
+ def test_deeply_nested_multi_value_assign(self):
+
+ def test_function(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ return [([(b, c), [d, e]], (f, g)), [(h, i, j), k]]
+
+ def expected_result(a):
+ [([(b, c), [d, e]], (f, g)), [(h, i, j), k]] = a
+ tmp_1001 = b, c
+ tmp_1002 = [d, e]
+ tmp_1003 = [tmp_1001, tmp_1002]
+ tmp_1004 = f, g
+ tmp_1005 = h, i, j
+ tmp_1006 = tmp_1003, tmp_1004
+ tmp_1007 = [tmp_1005, k]
+ tmp_1008 = [tmp_1006, tmp_1007]
+ return tmp_1008
+
+ self.assert_body_anfs_as_expected(expected_result, test_function)
+
def test_local_definition_and_binary_compare(self):
def test_function():
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
index 2d8f922a45..e7baa244b2 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
@@ -29,6 +29,11 @@ from tensorflow.contrib.autograph.pyct import anno
from tensorflow.contrib.autograph.pyct import transformer
from tensorflow.contrib.autograph.pyct.static_analysis.annos import NodeAnno
+# TODO(aqj): Do we need this? Do other builtins fail in similar ways
+# See b/114389775 for a related bug in pyct
+# These symbols are legal in Python, but don't appear in the namespace.
+_special_symbols = {'range': range}
+
class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
@@ -66,6 +71,8 @@ class LiveValueResolver(transformer.Base):
# If the symbol value is for example a primitive, then it will not
# have a name.
pass
+ elif node.id in _special_symbols:
+ anno.setanno(node, 'live_val', _special_symbols[node.id])
else:
pass
# TODO(mdan): Should we raise an error here?
diff --git a/tensorflow/contrib/autograph/pyct/templates.py b/tensorflow/contrib/autograph/pyct/templates.py
index 5831d57ceb..d81c50f524 100644
--- a/tensorflow/contrib/autograph/pyct/templates.py
+++ b/tensorflow/contrib/autograph/pyct/templates.py
@@ -113,7 +113,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._check_inner_children_have_context(node.value)
self._check_has_context(node)
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._check_inner_children_have_context(e)
self._check_has_context(node)
@@ -142,7 +142,7 @@ class ReplaceTransformer(gast.NodeTransformer):
if isinstance(node, gast.Attribute):
self._set_inner_child_context(node.value, gast.Load())
node.ctx = ctx
- elif isinstance(node, gast.Tuple):
+ elif isinstance(node, (gast.Tuple, gast.List)):
for e in node.elts:
self._set_inner_child_context(e, ctx)
node.ctx = ctx
@@ -191,7 +191,7 @@ class ReplaceTransformer(gast.NodeTransformer):
# Preserve the target context.
for n in new_nodes:
- if isinstance(n, gast.Tuple):
+ if isinstance(n, (gast.Tuple, gast.List)):
for e in n.elts:
self._set_inner_child_context(e, node.ctx)
if isinstance(n, gast.Attribute):
diff --git a/tensorflow/contrib/autograph/pyct/templates_test.py b/tensorflow/contrib/autograph/pyct/templates_test.py
index 77e8ff62fd..074105ea50 100644
--- a/tensorflow/contrib/autograph/pyct/templates_test.py
+++ b/tensorflow/contrib/autograph/pyct/templates_test.py
@@ -110,6 +110,42 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(node.body[0].targets[0].value.ctx, gast.Load)
self.assertIsInstance(node.body[0].targets[0].value.value.ctx, gast.Load)
+ def test_replace_list_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('[a, b]'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_tuple_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(template, foo=parser.parse_expression('(a, b)'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[0].ctx, gast.Store)
+ self.assertIsInstance(node.body[0].targets[0].elts[1].ctx, gast.Store)
+
+ def test_replace_complex_context(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('bar(([a, b],)).baz'))[0]
+ self.assertIsInstance(node.body[0].targets[0].ctx, gast.Store)
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():
diff --git a/tensorflow/contrib/autograph/pyct/testing/BUILD b/tensorflow/contrib/autograph/pyct/testing/BUILD
index 9ef1ac9663..29a92444bb 100644
--- a/tensorflow/contrib/autograph/pyct/testing/BUILD
+++ b/tensorflow/contrib/autograph/pyct/testing/BUILD
@@ -34,8 +34,10 @@ py_test(
srcs = ["codegen_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "manual",
"no_windows",
"nomsan",
+ "notap",
],
deps = [
":testing",
diff --git a/tensorflow/contrib/autograph/utils/BUILD b/tensorflow/contrib/autograph/utils/BUILD
index d2b399f19b..4504a5c7a3 100644
--- a/tensorflow/contrib/autograph/utils/BUILD
+++ b/tensorflow/contrib/autograph/utils/BUILD
@@ -20,12 +20,12 @@ py_library(
name = "utils",
srcs = [
"__init__.py",
- "builtins.py",
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
"py_func.py",
"tensor_list.py",
+ "tensors.py",
"testing.py",
"type_check.py",
],
@@ -42,17 +42,6 @@ py_library(
)
py_test(
- name = "builtins_test",
- srcs = ["builtins_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"],
- deps = [
- ":utils",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
name = "context_managers_test",
srcs = ["context_managers_test.py"],
srcs_version = "PY2AND3",
@@ -113,3 +102,13 @@ py_test(
"//tensorflow/python:list_ops",
],
)
+
+py_test(
+ name = "tensors_test",
+ srcs = ["tensors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/utils/__init__.py b/tensorflow/contrib/autograph/utils/__init__.py
index 57b5f74741..38e0a0a8f0 100644
--- a/tensorflow/contrib/autograph/utils/__init__.py
+++ b/tensorflow/contrib/autograph/utils/__init__.py
@@ -18,9 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.autograph.utils.builtins import dynamic_builtin
-from tensorflow.contrib.autograph.utils.builtins import dynamic_print
-from tensorflow.contrib.autograph.utils.builtins import dynamic_range
from tensorflow.contrib.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.autograph.utils.misc import alias_tensors
from tensorflow.contrib.autograph.utils.multiple_dispatch import dynamic_is
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
deleted file mode 100644
index 4dd440ef19..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ /dev/null
@@ -1,143 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Builtin conversion utilities."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import py_func
-from tensorflow.contrib.autograph.utils import type_check
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import list_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-
-
-def dynamic_builtin(f, *args, **kwargs):
- """Converts a builtin function call inline."""
- if f is len:
- return dynamic_len(*args, **kwargs)
- if six.PY2 and f is xrange:
- return dynamic_range(*args, **kwargs)
- if f is range:
- return dynamic_range(*args, **kwargs)
- if f is int:
- return dynamic_int(*args, **kwargs)
- if f is float:
- return dynamic_float(*args, **kwargs)
- if f is abs:
- return dynamic_abs(*args, **kwargs)
-
- raise NotImplementedError(
- 'The "%s" builtin is not yet supported.' % f.__name__)
-
-
-def dynamic_len(list_or_tensor):
- """Implementation of len using dynamic dispatch."""
- if _is_tensor_list(list_or_tensor):
- return list_ops.tensor_list_length(list_or_tensor)
- elif tensor_util.is_tensor(list_or_tensor):
- shape = list_or_tensor.shape
- if not shape.ndims:
- raise ValueError(
- 'len requires non-zero rank for tensor "%s"' % list_or_tensor)
- return array_ops.shape(list_or_tensor)[0]
- return len(list_or_tensor)
-
-
-def _is_tensor_list(list_or_tensor):
- return (tensor_util.is_tensor(list_or_tensor)
- and list_or_tensor.dtype == dtypes.variant)
-
-
-def dynamic_int(num_or_tensor, **kwargs):
- """Implementation of int() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.int32, **kwargs)
- return int(num_or_tensor)
-
-
-def dynamic_float(num_or_tensor, **kwargs):
- """Implementation of float() using dynamic dispatch."""
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.cast(num_or_tensor, dtype=dtypes.float32, **kwargs)
- return float(num_or_tensor)
-
-
-def dynamic_abs(num_or_tensor, **kwargs):
- if tensor_util.is_tensor(num_or_tensor):
- return math_ops.abs(num_or_tensor, **kwargs)
- else:
- return abs(num_or_tensor, **kwargs)
-
-
-def dynamic_range(start_or_stop, stop=None, step=None):
- """Implementation of range using dynamic dispatch."""
- if type_check.is_tensor(start_or_stop, stop, step):
- if step is not None:
- return math_ops.range(start_or_stop, stop, step)
- if stop is not None:
- return math_ops.range(start_or_stop, stop)
- return math_ops.range(start_or_stop)
-
- if step is not None:
- return range(start_or_stop, stop, step)
- elif stop is not None:
- return range(start_or_stop, stop)
- return range(start_or_stop)
-
-
-def is_tf_print_compatible(value):
- # TODO(mdan): Enable once we can reliably test this.
- # This is currently disabled because we can't capture the output of
- # op kernels from Python.
- del value
- return False
-
-
-def dynamic_print(*values):
- """Implementation of print using dynamic dispatch.
-
- The function attempts to use tf.Print if all the values are compatible.
- Otherwise, it will fall back to py_func.
-
- Args:
- *values: values to print
- Returns:
- A dummy value indicating the print completed. If tf.
- """
-
- if all(map(is_tf_print_compatible, values)):
- return logging_ops.Print(1, values)
-
- def print_wrapper(*vals):
- if six.PY3:
- # TensorFlow doesn't seem to generate Unicode when passing strings to
- # py_func. This causes the print to add a "b'" wrapper to the output,
- # which is probably never what you want.
- vals = tuple(v.decode() if isinstance(v, bytes) else v for v in vals)
- print(*vals)
- # The flush helps avoid garbled output in IPython.
- sys.stdout.flush()
-
- return py_func.wrap_py_func(
- print_wrapper, None, values, use_dummy_return=True)
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
deleted file mode 100644
index b1cd5253bc..0000000000
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for builtins module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import sys
-
-import six
-
-from tensorflow.contrib.autograph.utils import builtins
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.platform import test
-
-
-class BuiltinsTest(test.TestCase):
-
- def test_dynamic_len_tf_scalar(self):
- a = constant_op.constant(1)
-
- with self.assertRaisesRegexp(ValueError,
- 'len requires non-zero rank for tensor.*'):
- with self.test_session() as sess:
- sess.run(builtins.dynamic_builtin(len, a))
-
- def test_dynamic_len_tf_array(self):
- a = constant_op.constant([1, 2, 3])
-
- with self.test_session() as sess:
- self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_abs_tf_scalar(self):
- a = constant_op.constant(-1)
-
- with self.test_session() as sess:
- self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a)))
-
- def test_dynamic_abs_tf_array(self):
- a = constant_op.constant([-1, 2, -3])
-
- with self.test_session() as sess:
- self.assertListEqual([1, 2, 3],
- list(sess.run(builtins.dynamic_builtin(abs, a))))
-
- def test_dynamic_abs_py_scalar(self):
- a = -1
- self.assertEqual(1, builtins.dynamic_builtin(abs, a))
-
- def test_dynamic_len_tf_matrix(self):
- a = constant_op.constant([[1, 2], [3, 4]])
-
- with self.test_session() as sess:
- self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
-
- def test_dynamic_len_py_list(self):
- a = [3] * 5
-
- self.assertEqual(5, builtins.dynamic_builtin(len, a))
-
- def test_dynamic_range_all_python(self):
- self.assertListEqual(list(builtins.dynamic_builtin(range, 3)), [0, 1, 2])
- self.assertListEqual(list(builtins.dynamic_builtin(range, 1, 3)), [1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(range, 2, 0, -1)), [2, 1])
-
- def test_dynamic_range_tf(self):
- with self.test_session() as sess:
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, constant_op.constant(3))),
- [0, 1, 2])
- self.assertAllEqual(
- sess.run(builtins.dynamic_builtin(range, 1, constant_op.constant(3))),
- [1, 2])
- self.assertAllEqual(
- sess.run(
- builtins.dynamic_builtin(range, 2, 0, constant_op.constant(-1))),
- [2, 1])
-
- def test_dynamic_range_detection(self):
- def range(x): # pylint:disable=redefined-builtin
- return x
-
- # Functions that just have the names of builtins are rejected.
- with self.assertRaises(NotImplementedError):
- self.assertEqual(builtins.dynamic_builtin(range, 1), 1)
- if six.PY2:
- self.assertListEqual(
- list(builtins.dynamic_builtin(xrange, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.range, 3)), [0, 1, 2])
- self.assertListEqual(
- list(builtins.dynamic_builtin(six.moves.xrange, 3)), [0, 1, 2])
-
- def test_casts(self):
- i = constant_op.constant(2, dtype=dtypes.int32)
- f = constant_op.constant(1.0, dtype=dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, i).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(int, f).dtype, dtypes.int32)
- self.assertEqual(builtins.dynamic_builtin(float, i).dtype, dtypes.float32)
- self.assertEqual(builtins.dynamic_builtin(float, f).dtype, dtypes.float32)
-
- self.assertEqual(builtins.dynamic_builtin(int, True), 1)
- self.assertEqual(builtins.dynamic_builtin(int, False), 0)
- self.assertEqual(builtins.dynamic_builtin(float, True), 1.0)
- self.assertEqual(builtins.dynamic_builtin(float, False), 0.0)
-
- def test_dynamic_print_tf(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', 1))
- self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
- finally:
- sys.stdout = sys.__stdout__
-
- def test_dynamic_print_complex(self):
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- with self.test_session() as sess:
- sess.run(builtins.dynamic_print('test message', [1, 2]))
- self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
- finally:
- sys.stdout = sys.__stdout__
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/autograph/utils/misc_test.py b/tensorflow/contrib/autograph/utils/misc_test.py
index 71e358c33e..968ea03df6 100644
--- a/tensorflow/contrib/autograph/utils/misc_test.py
+++ b/tensorflow/contrib/autograph/utils/misc_test.py
@@ -31,7 +31,7 @@ class MiscTest(test.TestCase):
new_a = alias_tensors(a)
self.assertFalse(new_a is a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
def test_alias_tensors(self):
@@ -46,7 +46,7 @@ class MiscTest(test.TestCase):
self.assertTrue(new_v is v)
self.assertTrue(new_s is s)
self.assertTrue(new_l is l)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(1, sess.run(new_a))
diff --git a/tensorflow/contrib/autograph/utils/py_func_test.py b/tensorflow/contrib/autograph/utils/py_func_test.py
index 2468263142..f60b57bcce 100644
--- a/tensorflow/contrib/autograph/utils/py_func_test.py
+++ b/tensorflow/contrib/autograph/utils/py_func_test.py
@@ -31,7 +31,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c):
return a + b + c
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64,
(1, constant_op.constant(1), 1))
self.assertEqual(3, sess.run(result))
@@ -52,7 +52,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b):
return a * b.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass()))
self.assertEqual(35, sess.run(result))
result = py_func.wrap_py_func(test_fn, dtypes.int64,
@@ -69,7 +69,7 @@ class PyFuncTest(test.TestCase):
def test_fn(a, b, c, d):
return a * b.foo + c * d.foo
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, dtypes.int64, (7, TestClass(5)), {
'c': 11,
'd': TestClass(13)
@@ -89,7 +89,7 @@ class PyFuncTest(test.TestCase):
def test_fn(_):
side_counter[0] += 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
result = py_func.wrap_py_func(test_fn, None, (5,), use_dummy_return=True)
self.assertEqual(1, sess.run(result))
self.assertEqual([1], side_counter)
diff --git a/tensorflow/contrib/autograph/utils/tensor_list_test.py b/tensorflow/contrib/autograph/utils/tensor_list_test.py
index d58489eb68..faaf7b7877 100644
--- a/tensorflow/contrib/autograph/utils/tensor_list_test.py
+++ b/tensorflow/contrib/autograph/utils/tensor_list_test.py
@@ -42,18 +42,18 @@ class TensorListTest(test.TestCase):
l = list_ops.empty_tensor_list(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
s = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
l = tl.dynamic_list_append(l, 1)
s = l.stack()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(s), [1])
l = tl.TensorList(self._shape(()), dtypes.int32)
l = tl.dynamic_list_append(l, 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual(sess.run(l[0]), 1)
def test_list_append_python(self):
@@ -107,7 +107,7 @@ class TensorListTest(test.TestCase):
l0 = l[0]
l[0] = b
l1 = l[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
l0, l1, a, b = sess.run([l0, l1, a, b])
self.assertEqual(l0, a)
self.assertEqual(l1, b)
diff --git a/tensorflow/contrib/autograph/utils/tensors.py b/tensorflow/contrib/autograph/utils/tensors.py
new file mode 100644
index 0000000000..fa5db81a71
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""This module defines tensor utilities not found in TensorFlow.
+
+The reason these utilities are not defined in TensorFlow is because they may
+not be not fully robust, although they work in the vast majority of cases. So
+we define them here in order for their behavior to be consistently verified.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import tensor_array_ops
+
+
+def is_tensor_array(t):
+ return isinstance(t, tensor_array_ops.TensorArray)
+
+
+def is_tensor_list(t):
+ # TODO(mdan): This is just a heuristic.
+ # With TF lacking support for templated types, this is unfortunately the
+ # closest we can get right now. A dedicated op ought to be possible to
+ # construct.
+ return (tensor_util.is_tensor(t) and t.dtype == dtypes.variant and
+ not t.shape.ndims)
diff --git a/tensorflow/contrib/autograph/utils/tensors_test.py b/tensorflow/contrib/autograph/utils/tensors_test.py
new file mode 100644
index 0000000000..e855e0b6cb
--- /dev/null
+++ b/tensorflow/contrib/autograph/utils/tensors_test.py
@@ -0,0 +1,57 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensors module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import tensors
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.platform import test
+
+
+class TensorsTest(test.TestCase):
+
+ def _simple_tensor_array(self):
+ return tensor_array_ops.TensorArray(dtypes.int32, size=3)
+
+ def _simple_tensor_list(self):
+ return list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([1]), element_dtype=dtypes.int32)
+
+ def _simple_list_of_tensors(self):
+ return [constant_op.constant(1), constant_op.constant(2)]
+
+ def test_is_tensor_array(self):
+ self.assertTrue(tensors.is_tensor_array(self._simple_tensor_array()))
+ self.assertFalse(tensors.is_tensor_array(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_array(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_array(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_array(None))
+
+ def test_is_tensor_list(self):
+ self.assertFalse(tensors.is_tensor_list(self._simple_tensor_array()))
+ self.assertTrue(tensors.is_tensor_list(self._simple_tensor_list()))
+ self.assertFalse(tensors.is_tensor_list(constant_op.constant(1)))
+ self.assertFalse(tensors.is_tensor_list(self._simple_list_of_tensors()))
+ self.assertFalse(tensors.is_tensor_list(None))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
index b9abfa8295..f33eaf7e3d 100644
--- a/tensorflow/contrib/bigtable/README.md
+++ b/tensorflow/contrib/bigtable/README.md
@@ -324,8 +324,14 @@ If you encounter a log line that includes the following:
"filename":"/usr/share/grpc/roots.pem"
```
-you likely need to copy the [gRPC `roots.pem` file][grpcPem] to
-`/usr/share/grpc/roots.pem` on your local machine.
+you can solve it via either of the following approaches:
+
+* copy the [gRPC `roots.pem` file][grpcPem] to
+ `/usr/share/grpc/roots.pem` on your local machine, which is the default
+ location where gRPC will look for this file
+* export the environment variable `GRPC_DEFAULT_SSL_ROOTS_FILE_PATH` to point to
+ the full path of the gRPC `roots.pem` file on your file system if it's in a
+ different location
[grpcPem]: https://github.com/grpc/grpc/blob/master/etc/roots.pem
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
index a25a641cdb..6138d79126 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_kernels.cc
@@ -172,6 +172,11 @@ class BigtableTableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtableTable").Device(DEVICE_CPU),
BigtableTableOp);
+} // namespace
+
+namespace data {
+namespace {
+
class ToBigtableOp : public AsyncOpKernel {
public:
explicit ToBigtableOp(OpKernelConstruction* ctx)
@@ -354,5 +359,6 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToBigtable").Device(DEVICE_CPU),
ToBigtableOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
index a2a5df1037..4652021fec 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lib.h
@@ -79,6 +79,8 @@ class BigtableTableResource : public ResourceBase {
::google::cloud::bigtable::noex::Table table_;
};
+namespace data {
+
// BigtableReaderDatasetIterator is an abstract class for iterators from
// datasets that are "readers" (source datasets, not transformation datasets)
// that read from Bigtable.
@@ -138,6 +140,8 @@ class BigtableReaderDatasetIterator : public DatasetIterator<Dataset> {
::google::cloud::bigtable::RowReader::iterator iterator_ GUARDED_BY(mu_);
};
+} // namespace data
+
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
index bd32672aa9..11f530e82a 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_lookup_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
@@ -226,4 +227,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableLookupDataset").Device(DEVICE_CPU),
BigtableLookupDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
index a803fdcb49..5cab729d9c 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_prefix_key_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtablePrefixKeyDatasetOp : public DatasetOpKernel {
@@ -111,4 +112,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtablePrefixKeyDataset").Device(DEVICE_CPU),
BigtablePrefixKeyDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
index 5cd0371c79..4dc4647bd2 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_range_key_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
@@ -117,4 +118,5 @@ class BigtableRangeKeyDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("BigtableRangeKeyDataset").Device(DEVICE_CPU),
BigtableRangeKeyDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
index 6928d9423c..736775bdac 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_key_pairs_dataset_op.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
@@ -205,4 +206,5 @@ REGISTER_KERNEL_BUILDER(
BigtableSampleKeyPairsDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
index a759fb5063..208b7b3e08 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_sample_keys_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
@@ -118,4 +119,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableSampleKeysDataset").Device(DEVICE_CPU),
BigtableSampleKeysDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
index 78a920b077..9407855fe8 100644
--- a/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
+++ b/tensorflow/contrib/bigtable/kernels/bigtable_scan_dataset_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
+namespace data {
namespace {
class BigtableScanDatasetOp : public DatasetOpKernel {
@@ -224,4 +225,5 @@ REGISTER_KERNEL_BUILDER(Name("BigtableScanDataset").Device(DEVICE_CPU),
BigtableScanDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
index 870ce2442b..4c7a538b38 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py
@@ -52,7 +52,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeClassifier estimator instance.
Args:
@@ -94,6 +95,7 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -134,7 +136,8 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -159,7 +162,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRegressor estimator instance.
Args:
@@ -201,6 +205,7 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
head = head_lib.regression_head(
label_name=label_name,
@@ -224,7 +229,8 @@ class GradientBoostedDecisionTreeRegressor(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -251,7 +257,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
center_bias=True,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeEstimator estimator instance.
Args:
@@ -289,6 +296,7 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
super(GradientBoostedDecisionTreeEstimator, self).__init__(
model_fn=model.model_builder,
@@ -303,7 +311,8 @@ class GradientBoostedDecisionTreeEstimator(estimator.Estimator):
'center_bias': center_bias,
'use_core_libs': use_core_libs,
'output_leaf_index': False,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -329,7 +338,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
center_bias=False,
use_core_libs=False,
output_leaf_index=False,
- override_global_step_value=None):
+ override_global_step_value=None,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -377,6 +387,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
trees were trained), this parameter can be used to set the global step
to a large value, making it look like that number of training steps ran.
If None, no override of global step will happen.
+ num_quantiles: Number of quantiles to build for numeric feature values.
+
Raises:
ValueError: If learner_config is not valid.
"""
@@ -395,7 +407,8 @@ class GradientBoostedDecisionTreeRanker(estimator.Estimator):
'use_core_libs': use_core_libs,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': override_global_step_value
+ 'override_global_step_value': override_global_step_value,
+ 'num_quantiles': num_quantiles,
},
model_dir=model_dir,
config=config,
@@ -444,7 +457,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
feature_engineering_fn=None,
logits_modifier_function=None,
center_bias=True,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a core version of GradientBoostedDecisionTreeEstimator.
Args:
@@ -474,6 +488,7 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
for example_prediction_result in result_dict:
# access leaf index list by example_prediction_result["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
"""
def _model_fn(features, labels, mode, config):
@@ -493,7 +508,8 @@ class CoreGradientBoostedDecisionTreeEstimator(core_estimator.Estimator):
'logits_modifier_function': logits_modifier_function,
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
@@ -517,7 +533,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
label_keys=None,
logits_modifier_function=None,
center_bias=False,
- output_leaf_index=False):
+ output_leaf_index=False,
+ num_quantiles=100):
"""Initializes a GradientBoostedDecisionTreeRanker instance.
This is an estimator that can be trained off the pairwise data and can be
@@ -552,6 +569,7 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
for result_dict in result_iter:
# access leaf index list by result_dict["leaf_index"]
# which contains one leaf index per tree
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: If learner_config is not valid.
@@ -576,7 +594,8 @@ class CoreGradientBoostedDecisionTreeRanker(core_estimator.Estimator):
'use_core_libs': True,
'output_leaf_index': output_leaf_index,
'ranking_model_pair_keys': ranking_model_pair_keys,
- 'override_global_step_value': None
+ 'override_global_step_value': None,
+ 'num_quantiles': num_quantiles,
},
output_type=model.ModelBuilderOutputType.ESTIMATOR_SPEC)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
index 04b46c3483..a6e422847d 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py
@@ -81,6 +81,7 @@ def model_builder(features,
logits_modifier_function = params["logits_modifier_function"]
output_leaf_index = params["output_leaf_index"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -116,7 +117,8 @@ def model_builder(features,
logits_dimension=head.logits_dimension,
features=training_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
predictions_dict = gbdt_model.predict(mode)
logits = predictions_dict["predictions"]
@@ -237,6 +239,7 @@ def ranking_model_builder(features,
output_leaf_index = params["output_leaf_index"]
ranking_model_pair_keys = params["ranking_model_pair_keys"]
override_global_step_value = params.get("override_global_step_value", None)
+ num_quantiles = params["num_quantiles"]
if features is None:
raise ValueError("At least one feature must be specified.")
@@ -299,7 +302,8 @@ def ranking_model_builder(features,
logits_dimension=head.logits_dimension,
features=main_features,
use_core_columns=use_core_libs,
- output_leaf_index=output_leaf_index)
+ output_leaf_index=output_leaf_index,
+ num_quantiles=num_quantiles)
with ops.name_scope("gbdt", "gbdt_optimizer"):
# Logits for inference.
diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
index 64349cfca3..3b28ed77f3 100644
--- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <limits>
#include <memory>
#include <string>
#include <vector>
@@ -325,13 +326,21 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
}
float best_gain = std::numeric_limits<float>::lowest();
- int64 best_bucket_idx = 0;
+ int64 best_bucket_id = 0;
std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
- int64 current_bucket_id = 0;
+ int64 current_bucket_id = std::numeric_limits<int64>::max();
int64 last_bucket_id = -1;
+ // Find the lowest bucket id, this is going to be the first bucket id to
+ // try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (bucket_ids(start_index, 0) < current_bucket_id) {
+ current_bucket_id = bucket_ids(start_index, 0);
+ }
+ }
// Indexes offsets for each of the partitions that can be used to access
// gradients of a partition for a current bucket we consider.
std::vector<int> current_layer_offsets(num_elements, 0);
@@ -373,6 +382,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
best_gain = gain_of_split;
best_left_node_stats = current_left_node_stats;
best_right_node_stats = current_right_node_stats;
+ best_bucket_id = current_bucket_id;
}
current_bucket_id = next_bucket_id;
}
@@ -387,8 +397,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
oblivious_split_info.mutable_split_node()
->mutable_oblivious_dense_float_binary_split();
oblivious_dense_split->set_feature_column(state->feature_column_group_id());
- oblivious_dense_split->set_threshold(
- bucket_boundaries(bucket_ids(best_bucket_idx, 0)));
+ oblivious_dense_split->set_threshold(bucket_boundaries(best_bucket_id));
(*gains)(0) = best_gain;
for (int root_idx = 0; root_idx < num_elements; root_idx++) {
@@ -400,6 +409,7 @@ class BuildDenseInequalitySplitsOp : public OpKernel {
const int start_index = partition_boundaries[root_idx];
(*output_partition_ids)(root_idx) = partition_ids(start_index);
+ oblivious_split_info.add_children_parent_id(partition_ids(start_index));
}
oblivious_split_info.SerializeToString(&(*output_splits)(0));
}
@@ -729,6 +739,11 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
context->input("bias_feature_id", &bias_feature_id_t));
int64 bias_feature_id = bias_feature_id_t->scalar<int64>()();
+ const Tensor* weak_learner_type_t;
+ OP_REQUIRES_OK(context,
+ context->input("weak_learner_type", &weak_learner_type_t));
+ const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
+
// Find the number of unique partitions before we allocate the output.
std::vector<int32> partition_boundaries;
std::vector<int32> non_empty_partitions;
@@ -757,20 +772,63 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
tensorflow::TTypes<int32>::Vec output_partition_ids =
output_partition_ids_t->vec<int32>();
+ // For a normal tree, we output a split per partition. For an oblivious
+ // tree, we output one split for all partitions of the layer.
+ int size_output = num_elements;
+ if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE &&
+ num_elements > 0) {
+ size_output = 1;
+ }
+
Tensor* gains_t = nullptr;
- OP_REQUIRES_OK(
- context, context->allocate_output("gains", TensorShape({num_elements}),
- &gains_t));
+ OP_REQUIRES_OK(context, context->allocate_output(
+ "gains", TensorShape({size_output}), &gains_t));
tensorflow::TTypes<float>::Vec gains = gains_t->vec<float>();
Tensor* output_splits_t = nullptr;
- OP_REQUIRES_OK(context, context->allocate_output(
- "split_infos", TensorShape({num_elements}),
- &output_splits_t));
+ OP_REQUIRES_OK(context, context->allocate_output("split_infos",
+ TensorShape({size_output}),
+ &output_splits_t));
tensorflow::TTypes<string>::Vec output_splits =
output_splits_t->vec<string>();
+ if (num_elements == 0) {
+ return;
+ }
SplitBuilderState state(context);
+ switch (weak_learner_type) {
+ case LearnerConfig::NORMAL_DECISION_TREE: {
+ ComputeNormalDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
+ ComputeObliviousDecisionTree(
+ context, &state, normalizer_ratio, num_elements,
+ partition_boundaries, non_empty_partitions, bias_feature_id,
+ partition_ids, feature_ids, gradients_t, hessians_t,
+ &output_partition_ids, &gains, &output_splits);
+ break;
+ }
+ }
+ }
+
+ private:
+ void ComputeNormalDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
for (int root_idx = 0; root_idx < num_elements; ++root_idx) {
float best_gain = std::numeric_limits<float>::lowest();
int start_index = partition_boundaries[non_empty_partitions[root_idx]];
@@ -780,7 +838,7 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
errors::InvalidArgument("Bias feature ID missing."));
GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
root_gradient_stats *= normalizer_ratio;
- NodeStats root_stats = state.ComputeNodeStats(root_gradient_stats);
+ NodeStats root_stats = state->ComputeNodeStats(root_gradient_stats);
int32 best_feature_idx = 0;
NodeStats best_right_node_stats(0);
NodeStats best_left_node_stats(0);
@@ -791,8 +849,8 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
left_gradient_stats *= normalizer_ratio;
GradientStats right_gradient_stats =
root_gradient_stats - left_gradient_stats;
- NodeStats left_stats = state.ComputeNodeStats(left_gradient_stats);
- NodeStats right_stats = state.ComputeNodeStats(right_gradient_stats);
+ NodeStats left_stats = state->ComputeNodeStats(left_gradient_stats);
+ NodeStats right_stats = state->ComputeNodeStats(right_gradient_stats);
if (left_stats.gain + right_stats.gain > best_gain) {
best_gain = left_stats.gain + right_stats.gain;
best_left_node_stats = left_stats;
@@ -803,18 +861,133 @@ class BuildCategoricalEqualitySplitsOp : public OpKernel {
SplitInfo split_info;
auto* equality_split = split_info.mutable_split_node()
->mutable_categorical_id_binary_split();
- equality_split->set_feature_column(state.feature_column_group_id());
+ equality_split->set_feature_column(state->feature_column_group_id());
equality_split->set_feature_id(feature_ids(best_feature_idx, 0));
auto* left_child = split_info.mutable_left_child();
auto* right_child = split_info.mutable_right_child();
- state.FillLeaf(best_left_node_stats, left_child);
- state.FillLeaf(best_right_node_stats, right_child);
- split_info.SerializeToString(&output_splits(root_idx));
- gains(root_idx) =
- best_gain - root_stats.gain - state.tree_complexity_regularization();
- output_partition_ids(root_idx) = partition_ids(start_index);
+ state->FillLeaf(best_left_node_stats, left_child);
+ state->FillLeaf(best_right_node_stats, right_child);
+ split_info.SerializeToString(&(*output_splits)(root_idx));
+ (*gains)(root_idx) =
+ best_gain - root_stats.gain - state->tree_complexity_regularization();
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
}
}
+
+ void ComputeObliviousDecisionTree(
+ OpKernelContext* const context, SplitBuilderState* state,
+ const float normalizer_ratio, const int num_elements,
+ const std::vector<int32>& partition_boundaries,
+ const std::vector<int32>& non_empty_partitions,
+ const int64 bias_feature_id,
+ const tensorflow::TTypes<int32>::ConstVec& partition_ids,
+ const tensorflow::TTypes<int64>::ConstMatrix& feature_ids,
+ const Tensor* gradients_t, const Tensor* hessians_t,
+ tensorflow::TTypes<int32>::Vec* output_partition_ids,
+ tensorflow::TTypes<float>::Vec* gains,
+ tensorflow::TTypes<string>::Vec* output_splits) {
+ // Holds the root stats per each node to be split.
+ std::vector<GradientStats> current_layer_stats;
+ current_layer_stats.reserve(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ // First feature ID in each partition should be the bias feature.
+ OP_REQUIRES(context, feature_ids(start_index, 0) == bias_feature_id,
+ errors::InvalidArgument("Bias feature ID missing."));
+ GradientStats root_gradient_stats(*gradients_t, *hessians_t, start_index);
+ root_gradient_stats *= normalizer_ratio;
+ current_layer_stats.push_back(root_gradient_stats);
+ }
+ float best_gain = std::numeric_limits<float>::lowest();
+ int64 best_feature_id = 0;
+ std::vector<NodeStats> best_right_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> best_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_left_node_stats(num_elements, NodeStats(0));
+ std::vector<NodeStats> current_right_node_stats(num_elements, NodeStats(0));
+ int64 current_feature_id = std::numeric_limits<int64>::max();
+ int64 last_feature_id = -1;
+ // Find the lowest feature id, this is going to be the first feature id to
+ // try.
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ const int start_index = partition_boundaries[root_idx];
+ if (feature_ids(start_index + 1, 0) < current_feature_id) {
+ current_feature_id = feature_ids(start_index + 1, 0);
+ }
+ }
+ // Indexes offsets for each of the partitions that can be used to access
+ // gradients of a partition for a current feature we consider. Start at one
+ // beacuse the zero index is for the bias.
+ std::vector<int> current_layer_offsets(num_elements, 1);
+ // The idea is to try every feature id in increasing order. In each
+ // iteration we calculate the gain of the layer using the current feature id
+ // as split value, and we also obtain the following feature id to try.
+ while (current_feature_id > last_feature_id) {
+ last_feature_id = current_feature_id;
+ int64 next_feature_id = -1;
+ // Left gradient stats per node.
+ std::vector<GradientStats> left_gradient_stats(num_elements);
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ int idx =
+ current_layer_offsets[root_idx] + partition_boundaries[root_idx];
+ const int end_index = partition_boundaries[root_idx + 1];
+ if (idx < end_index && feature_ids(idx, 0) == current_feature_id) {
+ GradientStats g(*gradients_t, *hessians_t, idx);
+ g *= normalizer_ratio;
+ left_gradient_stats[root_idx] = g;
+ current_layer_offsets[root_idx]++;
+ idx++;
+ }
+ if (idx < end_index &&
+ (feature_ids(idx, 0) < next_feature_id || next_feature_id == -1)) {
+ next_feature_id = feature_ids(idx, 0);
+ }
+ }
+ float gain_of_split = 0.0;
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ GradientStats right_gradient_stats =
+ current_layer_stats[root_idx] - left_gradient_stats[root_idx];
+ NodeStats left_stat =
+ state->ComputeNodeStats(left_gradient_stats[root_idx]);
+ NodeStats right_stat = state->ComputeNodeStats(right_gradient_stats);
+ gain_of_split += left_stat.gain + right_stat.gain;
+ current_left_node_stats[root_idx] = left_stat;
+ current_right_node_stats[root_idx] = right_stat;
+ }
+ if (gain_of_split > best_gain) {
+ best_gain = gain_of_split;
+ best_left_node_stats = current_left_node_stats;
+ best_right_node_stats = current_right_node_stats;
+ best_feature_id = current_feature_id;
+ }
+ current_feature_id = next_feature_id;
+ }
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ best_gain -= state->ComputeNodeStats(current_layer_stats[root_idx]).gain;
+ }
+ best_gain -= num_elements * state->tree_complexity_regularization();
+
+ ObliviousSplitInfo oblivious_split_info;
+ auto* equality_split =
+ oblivious_split_info.mutable_split_node()
+ ->mutable_oblivious_categorical_id_binary_split();
+ equality_split->set_feature_column(state->feature_column_group_id());
+ equality_split->set_feature_id(best_feature_id);
+ (*gains)(0) = best_gain;
+
+ for (int root_idx = 0; root_idx < num_elements; root_idx++) {
+ auto* left_child = oblivious_split_info.add_children();
+ auto* right_child = oblivious_split_info.add_children();
+
+ state->FillLeaf(best_left_node_stats[root_idx], left_child);
+ state->FillLeaf(best_right_node_stats[root_idx], right_child);
+
+ const int start_index = partition_boundaries[root_idx];
+ (*output_partition_ids)(root_idx) = partition_ids(start_index);
+ oblivious_split_info.add_children_parent_id(partition_ids(start_index));
+ }
+ oblivious_split_info.SerializeToString(&(*output_splits)(0));
+ }
};
REGISTER_KERNEL_BUILDER(
diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
index bb5ae78d9b..ab2853352a 100644
--- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
+++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
+#include <vector>
+
#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
#include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
#include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
@@ -772,20 +774,32 @@ class GrowTreeEnsembleOp : public OpKernel {
// The number of new children.
int num_children = 1 << (depth + 1);
auto split_info = split->oblivious_split_info;
- CHECK(num_children == split_info.children_size())
- << "Wrong number of new children: " << num_children
- << " != " << split_info.children_size();
- for (int idx = 0; idx < num_children; idx += 2) {
- // Old leaf is at position depth + idx / 2.
+ CHECK(num_children >= split_info.children_size())
+ << "Too many new children, expected <= " << num_children << " and got "
+ << split_info.children_size();
+ std::vector<trees::Leaf> new_leaves;
+ new_leaves.reserve(num_children);
+ int next_id = 0;
+ for (int idx = 0; idx < num_children / 2; idx++) {
trees::Leaf old_leaf =
- *tree_config->mutable_nodes(depth + idx / 2)->mutable_leaf();
- // Update left leaf.
- *split_info.mutable_children(idx) =
- *MergeLeafWeights(old_leaf, split_info.mutable_children(idx));
- // Update right leaf.
- *split_info.mutable_children(idx + 1) =
- *MergeLeafWeights(old_leaf, split_info.mutable_children(idx + 1));
+ *tree_config->mutable_nodes(depth + idx)->mutable_leaf();
+ // Check if a split was made for this leaf.
+ if (next_id < split_info.children_parent_id_size() &&
+ depth + idx == split_info.children_parent_id(next_id)) {
+ // Add left leaf.
+ new_leaves.push_back(*MergeLeafWeights(
+ old_leaf, split_info.mutable_children(2 * next_id)));
+ // Add right leaf.
+ new_leaves.push_back(*MergeLeafWeights(
+ old_leaf, split_info.mutable_children(2 * next_id + 1)));
+ next_id++;
+ } else {
+ // If there is no split for this leaf, just duplicate it.
+ new_leaves.push_back(old_leaf);
+ new_leaves.push_back(old_leaf);
+ }
}
+ CHECK(next_id == split_info.children_parent_id_size());
TreeNodeMetadata* split_metadata =
split_info.mutable_split_node()->mutable_node_metadata();
split_metadata->set_gain(split->gain);
@@ -804,11 +818,10 @@ class GrowTreeEnsembleOp : public OpKernel {
if (idx + depth + 1 < nodes_size) {
// Update leaves that were already there.
*tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
- *split_info.mutable_children(idx);
+ new_leaves[idx];
} else {
// Add new leaves.
- *tree_config->add_nodes()->mutable_leaf() =
- *split_info.mutable_children(idx);
+ *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
}
}
}
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
index efe29216c2..35d727482b 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.boosted_trees.lib.learner.batch import base_split_handler
+from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.python.ops import split_handler_ops
from tensorflow.contrib.boosted_trees.python.ops import stats_accumulator_ops
from tensorflow.python.framework import constant_op
@@ -46,6 +47,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
multiclass_strategy,
init_stamp_token=0,
loss_uses_sum_reduction=False,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE,
name=None):
"""Initialize the internal state for this split handler.
@@ -66,6 +68,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
stamped objects.
loss_uses_sum_reduction: A scalar boolean tensor that specifies whether
SUM or MEAN reduction was used for the loss.
+ weak_learner_type: Specifies the type of weak learner to use.
name: An optional handler name.
"""
super(EqualitySplitHandler, self).__init__(
@@ -85,6 +88,7 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
hessian_shape,
name="StatsAccumulator/{}".format(self._name))
self._sparse_int_column = sparse_int_column
+ self._weak_learner_type = weak_learner_type
def update_stats(self, stamp_token, example_partition_ids, gradients,
hessians, empty_gradients, empty_hessians, weights,
@@ -137,11 +141,18 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
# The bias is computed on gradients and hessians (and not
# filtered_gradients) which have exactly one value per example, so we
# don't double count a gradient in multivalent columns.
+ # Since unsorted_segment_sum can be numerically unstable, use 64bit
+ # operation.
+ gradients64 = math_ops.cast(gradients, dtypes.float64)
+ hessians64 = math_ops.cast(hessians, dtypes.float64)
per_partition_gradients = math_ops.unsorted_segment_sum(
- gradients, mapped_partitions, array_ops.size(unique_partitions))
+ gradients64, mapped_partitions, array_ops.size(unique_partitions))
per_partition_hessians = math_ops.unsorted_segment_sum(
- hessians, mapped_partitions, array_ops.size(unique_partitions))
-
+ hessians64, mapped_partitions, array_ops.size(unique_partitions))
+ per_partition_gradients = math_ops.cast(per_partition_gradients,
+ dtypes.float32)
+ per_partition_hessians = math_ops.cast(per_partition_hessians,
+ dtypes.float32)
# Prepend a bias feature per partition that accumulates the stats for all
# examples in that partition.
# Bias is added to the stats even if there are no examples with values in
@@ -197,7 +208,8 @@ class EqualitySplitHandler(base_split_handler.BaseSplitHandler):
tree_complexity_regularization=self._tree_complexity_regularization,
min_node_weight=self._min_node_weight,
bias_feature_id=_BIAS_FEATURE_ID,
- multiclass_strategy=self._multiclass_strategy))
+ multiclass_strategy=self._multiclass_strategy,
+ weak_learner_type=self._weak_learner_type))
# There are no warm-up rounds needed in the equality column handler. So we
# always return ready.
are_splits_ready = constant_op.constant(True)
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
index ef253e7cec..94ea7bc2eb 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/categorical_split_handler_test.py
@@ -47,7 +47,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
@@ -169,10 +169,121 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
- def testGenerateFeatureSplitCandidatesSumReduction(self):
+ def testObliviousFeatureSplitGeneration(self):
with self.test_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Feature ID |
+ # i0 | (0.2, 0.12) | 1 | 1 |
+ # i1 | (-0.5, 0.07) | 1 | 2 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
+ hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
+ partition_ids = [1, 1, 1, 2]
+ indices = [[0, 0], [1, 0], [2, 0], [3, 0]]
+ values = array_ops.constant([1, 2, 1, 2], dtype=dtypes.int64)
+
+ gradient_shape = tensor_shape.scalar()
+ hessian_shape = tensor_shape.scalar()
+ class_id = -1
+
+ split_handler = categorical_split_handler.EqualitySplitHandler(
+ l1_regularization=0.1,
+ l2_regularization=1,
+ tree_complexity_regularization=0,
+ min_node_weight=0,
+ sparse_int_column=sparse_tensor.SparseTensor(indices, values, [4, 1]),
+ feature_column_group_id=0,
+ gradient_shape=gradient_shape,
+ hessian_shape=hessian_shape,
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ init_stamp_token=0,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ empty_gradients, empty_hessians = get_empty_tensors(
+ gradient_shape, hessian_shape)
+ example_weights = array_ops.ones([4, 1], dtypes.float32)
+
+ update_1 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 0,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
+
+ with ops.control_dependencies([update_1, update_2]):
+ are_splits_ready, partitions, gains, splits = (
+ split_handler.make_splits(0, 1, class_id))
+ are_splits_ready, partitions, gains, splits = (
+ sess.run([are_splits_ready, partitions, gains, splits]))
+ self.assertTrue(are_splits_ready)
+ self.assertAllEqual([1, 2], partitions)
+
+ # For partition 1.
+ # -(0.2 + 1.2 - 0.1) / (0.12 + 0.2 + 1)
+ expected_left_weight1 = -0.9848484848484846
+ # (0.2 + 1.2 - 0.1) ** 2 / (0.12 + 0.2 + 1)
+ expected_left_gain1 = 1.2803030303030298
+
+ # -(-0.5 + 0.1) / (0.07 + 1)
+ expected_right_weight1 = 0.37383177570093457
+
+ # (-0.5 + 0.1) ** 2 / (0.07 + 1)
+ expected_right_gain1 = 0.14953271028037385
+
+ # (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
+ expected_bias_gain1 = 0.46043165467625885
+
+ split_info = split_info_pb2.ObliviousSplitInfo()
+ split_info.ParseFromString(splits[0])
+ # Children of partition 1.
+ left_child = split_info.children[0].vector
+ right_child = split_info.children[1].vector
+ split_node = split_info.split_node.oblivious_categorical_id_binary_split
+
+ self.assertEqual(0, split_node.feature_column)
+ self.assertEqual(1, split_node.feature_id)
+ self.assertAllClose([expected_left_weight1], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight1], right_child.value, 0.00001)
+
+ # For partition2.
+ expected_left_weight2 = 0
+ expected_left_gain2 = 0
+ # -(4 - 0.1) / (0.13 + 1)
+ expected_right_weight2 = -3.4513274336283186
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_right_gain2 = 13.460176991150442
+ # (4 - 0.1) ** 2 / (0.13 + 1)
+ expected_bias_gain2 = 13.460176991150442
+
+ # Children of partition 2.
+ left_child = split_info.children[2].vector
+ right_child = split_info.children[3].vector
+ self.assertAllClose([expected_left_weight2], left_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight2], right_child.value, 0.00001)
+
+ self.assertAllClose(
+ expected_left_gain1 + expected_right_gain1 - expected_bias_gain1 +
+ expected_left_gain2 + expected_right_gain2 - expected_bias_gain2,
+ gains[0], 0.00001)
+
+ def testGenerateFeatureSplitCandidatesSumReduction(self):
+ with self.cached_session() as sess:
+ # The data looks like the following:
+ # Example | Gradients | Partition | Feature ID |
# i0 | (0.2, 0.12) | 0 | 1,2 |
# i1 | (-0.5, 0.07) | 0 | |
# i2 | (1.2, 0.2) | 0 | 2 |
@@ -293,7 +404,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testGenerateFeatureSplitCandidatesMulticlass(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
[[0.2, 0.1], [-0.5, 0.2], [1.2, 3.4], [4.0, -3.5]], shape=[4, 2])
@@ -371,7 +482,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(1, split_node.feature_id)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
@@ -419,7 +530,7 @@ class EqualitySplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
partition_ids = [0, 0, 0, 1]
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
index d9caebb645..74b0ea6989 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py
@@ -50,7 +50,7 @@ def get_empty_tensors(gradient_shape, hessian_shape):
class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -183,17 +183,18 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testObliviousFeatureSplitGeneration(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
- # i0 | (0.2, 0.12) | 0 | 2 |
- # i1 | (-0.5, 0.07) | 0 | 2 |
- # i2 | (1.2, 0.2) | 0 | 0 |
- # i3 | (4.0, 0.13) | 1 | 1 |
- dense_column = array_ops.constant([0.62, 0.62, 0.3, 0.52])
+ # i0 | (0.2, 0.12) | 1 | 3 |
+ # i1 | (-0.5, 0.07) | 1 | 3 |
+ # i2 | (1.2, 0.2) | 1 | 1 |
+ # i3 | (4.0, 0.13) | 2 | 2 |
+ dense_column = array_ops.placeholder(
+ dtypes.float32, shape=(4, 1), name="dense_column")
gradients = array_ops.constant([0.2, -0.5, 1.2, 4.0])
hessians = array_ops.constant([0.12, 0.07, 0.2, 0.13])
- partition_ids = array_ops.constant([0, 0, 0, 1], dtype=dtypes.int32)
+ partition_ids = array_ops.constant([1, 1, 1, 2], dtype=dtypes.int32)
class_id = -1
gradient_shape = tensor_shape.scalar()
@@ -230,31 +231,35 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
with ops.control_dependencies([update_1]):
are_splits_ready = split_handler.make_splits(
np.int64(0), np.int64(1), class_id)[0]
+ # Forcing the creation of four buckets.
+ are_splits_ready = sess.run(
+ [are_splits_ready],
+ feed_dict={dense_column: [[0.2], [0.62], [0.3], [0.52]]})[0]
- with ops.control_dependencies([are_splits_ready]):
- update_2 = split_handler.update_stats_sync(
- 1,
- partition_ids,
- gradients,
- hessians,
- empty_gradients,
- empty_hessians,
- example_weights,
- is_active=array_ops.constant([True, True]))
+ update_2 = split_handler.update_stats_sync(
+ 1,
+ partition_ids,
+ gradients,
+ hessians,
+ empty_gradients,
+ empty_hessians,
+ example_weights,
+ is_active=array_ops.constant([True, True]))
with ops.control_dependencies([update_2]):
are_splits_ready2, partitions, gains, splits = (
split_handler.make_splits(np.int64(1), np.int64(2), class_id))
- are_splits_ready, are_splits_ready2, partitions, gains, splits = (
- sess.run([
- are_splits_ready, are_splits_ready2, partitions, gains, splits
- ]))
+ # Only using the last three buckets.
+ are_splits_ready2, partitions, gains, splits = (
+ sess.run(
+ [are_splits_ready2, partitions, gains, splits],
+ feed_dict={dense_column: [[0.62], [0.62], [0.3], [0.52]]}))
# During the first iteration, inequality split handlers are not going to
# have any splits. Make sure that we return not_ready in that case.
self.assertFalse(are_splits_ready)
self.assertTrue(are_splits_ready2)
- self.assertAllEqual([0, 1], partitions)
+ self.assertAllEqual([1, 2], partitions)
oblivious_split_info = split_info_pb2.ObliviousSplitInfo()
oblivious_split_info.ParseFromString(splits[0])
@@ -263,54 +268,59 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 0.00001)
self.assertEqual(0, split_node.feature_column)
- # Check the split on partition 0.
+ # Check the split on partition 1.
# -(1.2 - 0.1) / (0.2 + 1)
- expected_left_weight_0 = -0.9166666666666666
+ expected_left_weight_1 = -0.9166666666666666
- # expected_left_weight_0 * -(1.2 - 0.1)
- expected_left_gain_0 = 1.008333333333333
+ # expected_left_weight_1 * -(1.2 - 0.1)
+ expected_left_gain_1 = 1.008333333333333
# (-0.5 + 0.2 + 0.1) / (0.19 + 1)
- expected_right_weight_0 = 0.1680672
+ expected_right_weight_1 = 0.1680672
- # expected_right_weight_0 * -(-0.5 + 0.2 + 0.1))
- expected_right_gain_0 = 0.033613445378151252
+ # expected_right_weight_1 * -(-0.5 + 0.2 + 0.1))
+ expected_right_gain_1 = 0.033613445378151252
# (0.2 + -0.5 + 1.2 - 0.1) ** 2 / (0.12 + 0.07 + 0.2 + 1)
- expected_bias_gain_0 = 0.46043165467625896
+ expected_bias_gain_1 = 0.46043165467625896
left_child = oblivious_split_info.children[0].vector
right_child = oblivious_split_info.children[1].vector
- self.assertAllClose([expected_left_weight_0], left_child.value, 0.00001)
+ self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
- self.assertAllClose([expected_right_weight_0], right_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
- # Check the split on partition 1.
- expected_left_weight_1 = 0
- expected_left_gain_1 = 0
+ # Check the split on partition 2.
+ expected_left_weight_2 = 0
+ expected_left_gain_2 = 0
# -(4 - 0.1) / (0.13 + 1)
- expected_right_weight_1 = -3.4513274336283186
- # expected_right_weight_1 * -(4 - 0.1)
- expected_right_gain_1 = 13.460176991150442
+ expected_right_weight_2 = -3.4513274336283186
+ # expected_right_weight_2 * -(4 - 0.1)
+ expected_right_gain_2 = 13.460176991150442
# (-4 + 0.1) ** 2 / (0.13 + 1)
- expected_bias_gain_1 = 13.460176991150442
+ expected_bias_gain_2 = 13.460176991150442
left_child = oblivious_split_info.children[2].vector
right_child = oblivious_split_info.children[3].vector
- self.assertAllClose([expected_left_weight_1], left_child.value, 0.00001)
+ self.assertAllClose([expected_left_weight_2], left_child.value, 0.00001)
- self.assertAllClose([expected_right_weight_1], right_child.value, 0.00001)
+ self.assertAllClose([expected_right_weight_2], right_child.value, 0.00001)
# The layer gain is the sum of the gains of each partition
layer_gain = (
- expected_left_gain_0 + expected_right_gain_0 - expected_bias_gain_0) + (
- expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1)
+ expected_left_gain_1 + expected_right_gain_1 - expected_bias_gain_1) + (
+ expected_left_gain_2 + expected_right_gain_2 - expected_bias_gain_2)
self.assertAllClose(layer_gain, gains[0], 0.00001)
+ # We have examples in both partitions, then we get both ids.
+ self.assertEqual(2, len(oblivious_split_info.children_parent_id))
+ self.assertEqual(1, oblivious_split_info.children_parent_id[0])
+ self.assertEqual(2, oblivious_split_info.children_parent_id[1])
+
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -448,7 +458,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
@@ -536,7 +546,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 1e-6)
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
dense_column = array_ops.constant([0.52, 0.52, 0.3, 0.52])
# Batch size is 4, 2 gradients per each instance.
gradients = array_ops.constant(
@@ -623,7 +633,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.3, split_node.threshold, 1e-6)
def testGenerateFeatureSplitCandidatesInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -698,7 +708,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testGenerateFeatureSplitCandidatesWithTreeComplexity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -832,7 +842,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.threshold, 0.00001)
def testGenerateFeatureSplitCandidatesWithMinNodeWeight(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Dense Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -941,7 +951,7 @@ class DenseSplitHandlerTest(test_util.TensorFlowTestCase):
class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testGenerateFeatureSplitCandidates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1064,7 +1074,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesLossUsesSumReduction(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1197,7 +1207,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesMulticlassFullHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch is 4, 2 classes
gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
[4.0, -3]])
@@ -1292,7 +1302,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesMulticlassDiagonalHessian(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch is 4, 2 classes
gradients = array_ops.constant([[0.2, 1.4], [-0.5, 0.1], [1.2, 3],
[4.0, -3]])
@@ -1387,7 +1397,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertAllClose(0.52, split_node.split.threshold)
def testGenerateFeatureSplitCandidatesInactive(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# The data looks like the following:
# Example | Gradients | Partition | Sparse Quantile |
# i0 | (0.2, 0.12) | 0 | 1 |
@@ -1465,7 +1475,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testEmpty(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
indices = array_ops.constant([], dtype=dtypes.int64, shape=[0, 2])
# No values in this feature column in this mini-batch.
values = array_ops.constant([], dtype=dtypes.float32)
@@ -1535,7 +1545,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
def testEmptyBuckets(self):
"""Test that reproduces the case when quantile buckets were empty."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_column = array_ops.sparse_placeholder(dtypes.float32)
# We have two batches - at first, a sparse feature is empty.
@@ -1628,7 +1638,7 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase):
self.assertEqual(len(splits), 0)
def testDegenerativeCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# One data example only, one leaf and thus one quantile bucket.The same
# situation is when all examples have the same values. This case was
# causing before a failure.
diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
index 69bb8fd4ad..8d71a6cdbc 100644
--- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
+++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h
@@ -36,12 +36,6 @@ class WeightedQuantilesSummary {
struct SummaryEntry {
SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
const WeightType& max) {
- // Explicitly initialize all of memory (including padding from memory
- // alignment) to allow the struct to be msan-resistant "plain old data".
- //
- // POD = http://en.cppreference.com/w/cpp/concept/PODType
- memset(this, 0, sizeof(*this));
-
value = v;
weight = w;
min_rank = min;
@@ -49,8 +43,6 @@ class WeightedQuantilesSummary {
}
SummaryEntry() {
- memset(this, 0, sizeof(*this));
-
value = ValueType();
weight = 0;
min_rank = 0;
diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
index 3ed6c5c04d..64921faf81 100644
--- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
+++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc
@@ -111,6 +111,18 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config,
node_id++;
break;
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ const auto& split =
+ current_node.oblivious_categorical_id_binary_split();
+ oblivious_leaf_idx <<= 1;
+ const auto& features =
+ example.sparse_int_features[split.feature_column()];
+ if (features.find(split.feature_id()) == features.end()) {
+ oblivious_leaf_idx++;
+ }
+ node_id++;
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString();
break;
@@ -181,6 +193,11 @@ void DecisionTree::LinkChildren(const std::vector<int32>& children,
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
break;
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
LOG(QFATAL) << "A non-set node cannot have children.";
break;
@@ -220,6 +237,11 @@ std::vector<int32> DecisionTree::GetChildren(const TreeNode& node) {
<< "Not implemented for the ObliviousDenseFloatBinarySplit case.";
return {};
}
+ case TreeNode::kObliviousCategoricalIdBinarySplit: {
+ LOG(QFATAL)
+ << "Not implemented for the ObliviousCategoricalIdBinarySplit case.";
+ break;
+ }
case TreeNode::NODE_NOT_SET: {
return {};
}
diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
index 9b68a9de96..f1e12a028a 100644
--- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
+++ b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
@@ -179,6 +179,7 @@ REGISTER_OP("BuildCategoricalEqualitySplits")
.Input("tree_complexity_regularization: float")
.Input("min_node_weight: float")
.Input("multiclass_strategy: int32")
+ .Input("weak_learner_type: int32")
.Output("output_partition_ids: int32")
.Output("gains: float32")
.Output("split_infos: string")
@@ -224,6 +225,8 @@ min_node_weight: A scalar, minimum sum of example hessian needed in a child.
be considered.
multiclass_strategy: A scalar, specifying the multiclass handling strategy.
See LearnerConfig.MultiClassStrategy for valid values.
+weak_learner_type: A scalar, specifying the weak learner type to use.
+ See LearnerConfig.WeakLearnerType for valid values.
output_partition_ids: A rank 1 tensor, the partition IDs that we created splits
for.
gains: A rank 1 tensor, for the computed gain for the created splits.
diff --git a/tensorflow/contrib/boosted_trees/proto/split_info.proto b/tensorflow/contrib/boosted_trees/proto/split_info.proto
index 65448996bf..784977af39 100644
--- a/tensorflow/contrib/boosted_trees/proto/split_info.proto
+++ b/tensorflow/contrib/boosted_trees/proto/split_info.proto
@@ -21,4 +21,8 @@ message SplitInfo {
message ObliviousSplitInfo {
tensorflow.boosted_trees.trees.TreeNode split_node = 1;
repeated tensorflow.boosted_trees.trees.Leaf children = 2;
+ // For each child, children_parent_id stores the node_id of its parent when it
+ // was a leaf. For the idx-th child it corresponds the idx/2-th
+ // children_parent_id.
+ repeated int32 children_parent_id = 3;
}
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
index 500909bf2a..520b4f8b11 100644
--- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto
+++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto
@@ -16,6 +16,7 @@ message TreeNode {
CategoricalIdSetMembershipBinarySplit
categorical_id_set_membership_binary_split = 6;
ObliviousDenseFloatBinarySplit oblivious_dense_float_binary_split = 7;
+ ObliviousCategoricalIdBinarySplit oblivious_categorical_id_binary_split = 8;
}
TreeNodeMetadata node_metadata = 777;
}
@@ -116,6 +117,17 @@ message ObliviousDenseFloatBinarySplit {
// leaves.
}
+// Split rule for categorical features with a single feature Id in the oblivious
+// case.
+message ObliviousCategoricalIdBinarySplit {
+ // Categorical feature column and Id describing the rule feature == Id.
+ int32 feature_column = 1;
+ int64 feature_id = 2;
+ // We don't store children ids, because either the next node represents the
+ // whole next layer of the tree or starting with the next node we only have
+ // leaves.
+}
+
// DecisionTreeConfig describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
index 4278a30ba9..46dfbdefeb 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py
@@ -331,7 +331,7 @@ class PredictionOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([[], []], dropout_info.eval())
def testObliviousEnsemble(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Bias tree.
tree1 = tree_ensemble_config.trees.add()
@@ -1399,7 +1399,7 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase):
self.assertAllEqual([0, 0], result.eval())
def testObliviousTreeNonFinalized(self):
- with self.test_session():
+ with self.cached_session():
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
# Depth 3 tree.
tree1 = tree_ensemble_config.trees.add()
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
index 5e62bad672..74917f7cde 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/split_handler_ops_test.py
@@ -541,7 +541,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -637,7 +638,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN))
+ multiclass_strategy=learner_pb2.LearnerConfig.FULL_HESSIAN,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = sess.run([partitions, gains, splits])
self.assertAllEqual([0, 1], partitions)
@@ -674,7 +676,8 @@ class SplitHandlerOpsTest(test_util.TensorFlowTestCase):
feature_column_group_id=0,
bias_feature_id=-1,
class_id=-1,
- multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS))
+ multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS,
+ weak_learner_type=learner_pb2.LearnerConfig.NORMAL_DECISION_TREE))
partitions, gains, splits = (sess.run([partitions, gains, splits]))
self.assertEqual(0, len(partitions))
self.assertEqual(0, len(gains))
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
index 278dc1f756..86fd5770a0 100644
--- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
+++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py
@@ -91,7 +91,8 @@ def _gen_dense_split_info(fc, threshold, left_weight, right_weight):
return split.SerializeToString()
-def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
+def _gen_dense_oblivious_split_info(fc, threshold, leave_weights,
+ children_parent_id):
split_str = """
split_node {
oblivious_dense_float_binary_split {
@@ -107,6 +108,9 @@ def _gen_dense_oblivious_split_info(fc, threshold, leave_weights):
}
}""" % (
weight)
+ for x in children_parent_id:
+ split_str += """
+ children_parent_id: %d""" % (x)
split = split_info_pb2.ObliviousSplitInfo()
text_format.Merge(split_str, split)
return split.SerializeToString()
@@ -407,7 +411,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEmptyEnsembleObliviousCase(self):
"""Test growing an empty ensemble in the oblivious case."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create empty ensemble.
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
tree_ensemble_handle = model_ops.tree_ensemble_variable(
@@ -432,14 +436,18 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([7.62], dtype=np.float32)
handler1_split = [
- _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143])
+ _gen_dense_oblivious_split_info(0, 0.52, [-4.375, 7.143], [0])
]
handler2_partitions = np.array([0], dtype=np.int32)
handler2_gains = np.array([0.63], dtype=np.float32)
- handler2_split = [_gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24])]
+ handler2_split = [
+ _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24], [0])
+ ]
handler3_partitions = np.array([0], dtype=np.int32)
handler3_gains = np.array([7.62], dtype=np.float32)
- handler3_split = [_gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143])]
+ handler3_split = [
+ _gen_dense_oblivious_split_info(0, 7, [-4.375, 7.143], [0])
+ ]
# Grow tree ensemble.
grow_op = training_ops.grow_tree_ensemble(
@@ -1612,7 +1620,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
def testGrowEnsembleTreeLayerByLayerObliviousCase(self):
"""Test growing an existing ensemble with the last tree not finalized."""
- with self.test_session() as session:
+ with self.cached_session() as session:
# Create existing ensemble with one root split
tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
text_format.Merge(
@@ -1675,17 +1683,20 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
handler1_partitions = np.array([0], dtype=np.int32)
handler1_gains = np.array([1.4], dtype=np.float32)
handler1_split = [
- _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5])
+ _gen_dense_oblivious_split_info(0, 0.21, [-6.0, 1.65, 1.0, -0.5],
+ [1, 2])
]
handler2_partitions = np.array([0], dtype=np.int32)
handler2_gains = np.array([2.7], dtype=np.float32)
handler2_split = [
- _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4]),
+ _gen_dense_oblivious_split_info(0, 0.23, [-0.6, 0.24, 0.3, 0.4],
+ [1, 2])
]
handler3_partitions = np.array([0], dtype=np.int32)
handler3_gains = np.array([1.7], dtype=np.float32)
handler3_split = [
- _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1])
+ _gen_dense_oblivious_split_info(0, 3, [-0.75, 1.93, 0.2, -0.1],
+ [1, 2])
]
# Grow tree ensemble layer by layer.
@@ -1797,6 +1808,528 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase):
self.assertEqual(stats.attempted_layers, 2)
self.assertProtoEquals(expected_result, tree_ensemble_config)
+ def testGrowEnsembleWithEmptyNodesMiddleCase(self):
+ """Test case: The middle existing leaves don't have examples."""
+ with self.cached_session() as session:
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=6,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
+
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([1.8], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [2, 5])
+ ]
+ # The tree currently has depth 2, so the ids for the four leaves are in
+ # the range [2, 6). In this test case we are assuming that our examples
+ # only fall in leaves 2 and 5.
+
+ # Grow tree ensemble layer by layer.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[handler1_partitions],
+ gains=[handler1_gains],
+ splits=[handler1_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 0.9
+ }
+ node_metadata {
+ gain: 1.8
+ original_oblivious_leaves {
+ vector {
+ value: 6.543
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 7.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.075
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 8.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -0.975
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 0.025
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 3
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 3
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 0)
+ self.assertEqual(stats.num_layers, 3)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 3)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 3)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
+ def testGrowEnsembleWithEmptyNodesBorderCase(self):
+ """Test case: The first and last existing leaves don't have examples."""
+ with self.cached_session() as session:
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig()
+ text_format.Merge(
+ """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 7.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -4.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 2
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 2
+ }
+ """, tree_ensemble_config)
+ tree_ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0,
+ tree_ensemble_config=tree_ensemble_config.SerializeToString(),
+ name="tree_ensemble")
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # Prepare learner config.
+ learner_config = _gen_learner_config(
+ num_classes=2,
+ l1_reg=0,
+ l2_reg=0,
+ tree_complexity=0,
+ max_depth=6,
+ min_node_weight=0,
+ pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE,
+ growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER)
+
+ # Prepare handler inputs.
+ handler1_partitions = np.array([0], dtype=np.int32)
+ handler1_gains = np.array([1.8], dtype=np.float32)
+ handler1_split = [
+ _gen_dense_oblivious_split_info(0, 0.9, [1.0, 2.0, 3.0, 4.0], [3, 4])
+ ]
+ # The tree currently has depth 2, so the ids for the four leaves are in
+ # the range [2, 6). In this test case we are assuming that our examples
+ # only fall in leaves 3 and 4.
+
+ # Grow tree ensemble layer by layer.
+ grow_op = training_ops.grow_tree_ensemble(
+ tree_ensemble_handle,
+ stamp_token=0,
+ next_stamp_token=1,
+ learning_rate=0.1,
+ partition_ids=[handler1_partitions],
+ gains=[handler1_gains],
+ splits=[handler1_split],
+ learner_config=learner_config.SerializeToString(),
+ dropout_seed=123,
+ center_bias=True,
+ max_tree_depth=learner_config.constraints.max_tree_depth,
+ weak_learner_type=learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ session.run(grow_op)
+
+ new_stamp, serialized = session.run(
+ model_ops.tree_ensemble_serialize(tree_ensemble_handle))
+ stats = session.run(
+ training_ops.tree_ensemble_stats(tree_ensemble_handle, stamp_token=1))
+ tree_ensemble_config.ParseFromString(serialized)
+ expected_result = """
+ trees {
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 4
+ threshold: 7
+ }
+ node_metadata {
+ gain: 7.62
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 1
+ threshold: 0.23
+ }
+ node_metadata {
+ gain: 2.7
+ original_oblivious_leaves {
+ vector {
+ value: 7.143
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.375
+ }
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ feature_column: 0
+ threshold: 0.9
+ }
+ node_metadata {
+ gain: 1.8
+ original_oblivious_leaves {
+ vector {
+ value: 6.543
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 7.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -4.075
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 6.543
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 8.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 9.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -0.075
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -3.975
+ }
+ }
+ }
+ }
+ tree_weights: 0.1
+ tree_metadata {
+ num_tree_weight_updates: 1
+ num_layers_grown: 3
+ }
+ growing_metadata {
+ num_trees_attempted: 1
+ num_layers_attempted: 3
+ }
+ """
+ self.assertEqual(new_stamp, 1)
+ self.assertEqual(stats.num_trees, 0)
+ self.assertEqual(stats.num_layers, 3)
+ self.assertEqual(stats.active_tree, 1)
+ self.assertEqual(stats.active_layer, 3)
+ self.assertEqual(stats.attempted_trees, 1)
+ self.assertEqual(stats.attempted_layers, 3)
+ self.assertProtoEquals(expected_result, tree_ensemble_config)
+
def testGrowExistingEnsembleTreeFinalizedWithDropout(self):
"""Test growing an existing ensemble with the last tree finalized."""
with self.cached_session() as session:
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 97743ba255..c7eb2493a8 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -304,7 +304,8 @@ class GradientBoostedDecisionTreeModel(object):
feature_columns=None,
use_core_columns=False,
output_leaf_index=False,
- output_leaf_index_modes=None):
+ output_leaf_index_modes=None,
+ num_quantiles=100):
"""Construct a new GradientBoostedDecisionTreeModel function.
Args:
@@ -327,6 +328,7 @@ class GradientBoostedDecisionTreeModel(object):
output_leaf_index_modes: A list of modes from (TRAIN, EVAL, INFER) which
dictates when leaf indices will be outputted. By default, leaf indices
are only outputted in INFER mode.
+ num_quantiles: Number of quantiles to build for numeric feature values.
Raises:
ValueError: if inputs are not valid.
@@ -399,6 +401,7 @@ class GradientBoostedDecisionTreeModel(object):
self._learner_config = learner_config
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
+ self._num_quantiles = num_quantiles
self._max_tree_depth = variables.Variable(
initial_value=self._learner_config.constraints.max_tree_depth)
self._attempted_trees = variables.Variable(
@@ -689,8 +692,8 @@ class GradientBoostedDecisionTreeModel(object):
loss_uses_sum_reduction = constant_op.constant(loss_uses_sum_reduction)
weak_learner_type = constant_op.constant(
self._learner_config.weak_learner_type)
- epsilon = 0.01
- num_quantiles = 100
+ num_quantiles = self._num_quantiles
+ epsilon = 1.0 / num_quantiles
strategy_tensor = constant_op.constant(strategy)
with ops.device(self._get_replica_device_setter(worker_device)):
# Create handlers for dense float columns
@@ -762,7 +765,8 @@ class GradientBoostedDecisionTreeModel(object):
hessian_shape=self._hessian_shape,
multiclass_strategy=strategy_tensor,
init_stamp_token=init_stamp_token,
- loss_uses_sum_reduction=loss_uses_sum_reduction))
+ loss_uses_sum_reduction=loss_uses_sum_reduction,
+ weak_learner_type=weak_learner_type))
fc_name_idx += 1
# Create ensemble stats variables.
@@ -1063,6 +1067,12 @@ class GradientBoostedDecisionTreeModel(object):
# Grow the ensemble given the current candidates.
sizes = array_ops.unstack(split_sizes)
partition_ids_list = list(array_ops.split(partition_ids, sizes, axis=0))
+ # When using the oblivious decision tree as weak learner, it produces
+ # one gain and one split per handler and not number of partitions.
+ if self._learner_config.weak_learner_type == (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE):
+ sizes = len(training_state.handlers)
+
gains_list = list(array_ops.split(gains, sizes, axis=0))
split_info_list = list(array_ops.split(split_infos, sizes, axis=0))
return training_ops.grow_tree_ensemble(
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index f7867d882d..73e41bc457 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from google.protobuf import text_format
from tensorflow.contrib import layers
+from tensorflow.contrib import learn
from tensorflow.contrib.boosted_trees.proto import learner_pb2
from tensorflow.contrib.boosted_trees.proto import tree_config_pb2
from tensorflow.contrib.boosted_trees.python.ops import model_ops
@@ -314,6 +315,162 @@ class GbdtTest(test_util.TensorFlowTestCase):
}"""
self.assertProtoEquals(expected_tree, output.trees[0])
+ def testObliviousDecisionTreeAsWeakLearner(self):
+ with self.test_session():
+ ensemble_handle = model_ops.tree_ensemble_variable(
+ stamp_token=0, tree_ensemble_config="", name="tree_ensemble")
+ learner_config = learner_pb2.LearnerConfig()
+ learner_config.num_classes = 2
+ learner_config.learning_rate_tuner.fixed.learning_rate = 1
+ learner_config.regularization.l1 = 0
+ learner_config.regularization.l2 = 0
+ learner_config.constraints.max_tree_depth = 2
+ learner_config.constraints.min_node_weight = 0
+ learner_config.weak_learner_type = (
+ learner_pb2.LearnerConfig.OBLIVIOUS_DECISION_TREE)
+ learner_config.pruning_mode = learner_pb2.LearnerConfig.PRE_PRUNE
+ learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER
+ features = {}
+ features["dense_float"] = array_ops.constant([[-2], [-1], [1], [2]],
+ dtypes.float32)
+
+ gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel(
+ is_chief=True,
+ num_ps_replicas=0,
+ center_bias=False,
+ ensemble_handle=ensemble_handle,
+ examples_per_layer=1,
+ learner_config=learner_config,
+ logits_dimension=1,
+ features=features)
+
+ predictions_dict = gbdt_model.predict(learn.ModeKeys.TRAIN)
+ predictions = predictions_dict["predictions"]
+ labels = array_ops.constant([[-2], [-1], [1], [2]], dtypes.float32)
+ weights = array_ops.ones([4, 1], dtypes.float32)
+
+ train_op = gbdt_model.train(
+ loss=math_ops.reduce_mean(
+ _squared_loss(labels, weights, predictions)),
+ predictions_dict=predictions_dict,
+ labels=labels)
+ variables.global_variables_initializer().run()
+ resources.initialize_resources(resources.shared_resources()).run()
+
+ # On first run, expect no splits to be chosen because the quantile
+ # buckets will not be ready.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 0)
+ self.assertEquals(len(output.tree_weights), 0)
+ self.assertEquals(stamp_token.eval(), 1)
+
+ # Second run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 2)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+ # Third run.
+ train_op.run()
+ stamp_token, serialized = model_ops.tree_ensemble_serialize(
+ ensemble_handle)
+ output = tree_config_pb2.DecisionTreeEnsembleConfig()
+ output.ParseFromString(serialized.eval())
+ self.assertEquals(len(output.trees), 1)
+ self.assertAllClose(output.tree_weights, [1])
+ self.assertEquals(stamp_token.eval(), 3)
+ expected_tree = """
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -1.0
+ }
+ node_metadata {
+ gain: 4.5
+ original_oblivious_leaves {
+ }
+ }
+ }
+ nodes {
+ oblivious_dense_float_binary_split {
+ threshold: -2.0
+ }
+ node_metadata {
+ gain: 0.25
+ original_oblivious_leaves {
+ vector {
+ value: -1.5
+ }
+ }
+ original_oblivious_leaves {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -2.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: -1.0
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }
+ nodes {
+ leaf {
+ vector {
+ value: 1.5
+ }
+ }
+ }"""
+ self.assertProtoEquals(expected_tree, output.trees[0])
+
def testTrainFnChiefSparseAndDense(self):
"""Tests the train function with sparse and dense features."""
with self.test_session() as sess:
diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
index 58fadffce3..e57a66b99f 100644
--- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
+++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc
@@ -33,7 +33,7 @@ bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
Status ParseJson(StringPiece json, Json::Value* result) {
Json::Reader reader;
- if (!reader.parse(json.ToString(), *result)) {
+ if (!reader.parse(string(json), *result)) {
return errors::Internal("Couldn't parse JSON response from BigQuery.");
}
return Status::OK();
diff --git a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
index 493b3c6f1b..11e177cd0c 100644
--- a/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/bigquery_reader_ops_test.py
@@ -197,7 +197,7 @@ class BigQueryReaderOpsTest(test.TestCase):
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
self.server.handler.num_rows = num_rows
- with self.test_session() as sess:
+ with self.cached_session() as sess:
feature_configs = {
"int64_col":
parsing_ops.FixedLenFeature(
@@ -254,7 +254,7 @@ class BigQueryReaderOpsTest(test.TestCase):
num_rows = 10
self.server.handler.num_rows = num_rows
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = cloud.BigQueryReader(
project_id=_PROJECT,
dataset_id=_DATASET,
diff --git a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
index 9b6c056d6c..4f2ecbcb17 100644
--- a/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
+++ b/tensorflow/contrib/cloud/python/ops/gcs_config_ops_test.py
@@ -26,7 +26,7 @@ class GcsConfigOpsTest(test.TestCase):
def testSetBlockCache(self):
cfg = gcs_config_ops.BlockCacheParams(max_bytes=1024*1024*1024)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
gcs_config_ops.configure_gcs(sess, block_cache=cfg)
def testConfigureGcsHook(self):
@@ -36,7 +36,7 @@ class GcsConfigOpsTest(test.TestCase):
'type': 'authorized_user'}
hook = gcs_config_ops.ConfigureGcsHook(credentials=creds)
hook.begin()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run = lambda _, feed_dict=None, options=None, run_metadata=None: None
hook.after_create_session(sess, None)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1ab150d74a..1056894f18 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver):
def get_master(self):
return self.master()
+ def get_job_name(self):
+ if self._shouldResolve():
+ return self._job_name
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index f6c928e2be..ebcabb4223 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -364,7 +364,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
else (tensorflow_ENABLE_MKLDNN_SUPPORT)
- add_definitions(-DINTEL_MKL_ML)
+ add_definitions(-DINTEL_MKL_ML_ONLY)
endif()
endif (tensorflow_ENABLE_MKL_SUPPORT)
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 07934ef324..fb871acae9 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -247,10 +247,6 @@ tensorflow/contrib/kernel_methods/python
tensorflow/contrib/kernel_methods/python/mappers
tensorflow/contrib/kinesis/python
tensorflow/contrib/kinesis/python/ops
-tensorflow/contrib/kfac
-tensorflow/contrib/kfac/examples
-tensorflow/contrib/kfac/python
-tensorflow/contrib/kfac/python/ops
tensorflow/contrib/labeled_tensor
tensorflow/contrib/labeled_tensor/python
tensorflow/contrib/labeled_tensor/python/ops
diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD
index 855c824ead..4bfd753bb1 100644
--- a/tensorflow/contrib/coder/BUILD
+++ b/tensorflow/contrib/coder/BUILD
@@ -3,6 +3,7 @@
package(default_visibility = [
"//learning/brain:__subpackages__",
+ "//research/vision/piedpiper:__subpackages__",
"//tensorflow:__subpackages__",
])
diff --git a/tensorflow/contrib/compiler/BUILD b/tensorflow/contrib/compiler/BUILD
index bcee0b04c8..d7583be6d8 100644
--- a/tensorflow/contrib/compiler/BUILD
+++ b/tensorflow/contrib/compiler/BUILD
@@ -8,6 +8,7 @@ package_group(
packages = ["//tensorflow/..."],
)
+load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
py_library(
@@ -46,3 +47,36 @@ cuda_py_test(
],
xla_enabled = True,
)
+
+py_library(
+ name = "xla",
+ srcs = ["xla.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ "//tensorflow/python/estimator:model_fn",
+ ],
+)
+
+tf_py_test(
+ name = "xla_test",
+ srcs = ["xla_test.py"],
+ additional_deps = [
+ ":xla",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:control_flow_util",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ ],
+ tags = ["no_pip"],
+)
diff --git a/tensorflow/contrib/compiler/xla.py b/tensorflow/contrib/compiler/xla.py
new file mode 100644
index 0000000000..60f5af1662
--- /dev/null
+++ b/tensorflow/contrib/compiler/xla.py
@@ -0,0 +1,208 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""xla provides experimental xla support API."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+
+_XLA_COMPILE_ATTR = '_xla_compile_id'
+_MAX_WARNING_LINES = 5
+
+# Operations that indicate some error in the users graph. For example, XLA
+# computation should not have any Placeholder op.
+_BLACKLISTED_OPS = set([
+ 'Placeholder',
+])
+
+# XLA doesn't currently support reading of intermediate tensors, thus some ops
+# are not supported.
+_UNSUPPORTED_OPS = set([
+ 'AudioSummary',
+ 'AudioSummaryV2',
+ 'HistogramSummary',
+ 'ImageSummary',
+ 'MergeSummary',
+ 'Print',
+ 'ScalarSummary',
+ 'TensorSummary',
+ 'TensorSummaryV2',
+])
+
+
+class XLACompileContext(control_flow_ops.XLAControlFlowContext):
+ """A `ControlFlowContext` for nodes inside an XLA computation cluster.
+
+ THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY.
+
+ The primary role of `XLACompileContext` is to mark operators inside a
+ xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is
+ a unique name.
+
+ `ControlFlowContext` is used to perform the annotation since it integrates
+ with Tensorflow constructs like ResourceVariables. For example, if a
+ `ResourceVariable` is constructed inside a xla.compile() block, the
+ `ResourceVariable` implementation can use
+ `with ops.control_dependencies(None)` to build the variable's definition
+ outside the compiled computation.
+ """
+
+ def __init__(self, name, pivot):
+ """Builds a new XLACompileContext.
+
+ Args:
+ name: a unique name for the context, used to populate the
+ `_xla_compile_id` attribute.
+ pivot: a pivot node. Nodes in the XLACompileContext that do not have any
+ inputs will have a control dependency on the pivot node. This ensures
+ that nodes are correctly included in any enclosing control flow
+ contexts.
+ """
+ super(XLACompileContext, self).__init__()
+ self._name = name
+ self._name_as_bytes = compat.as_bytes(name)
+ self._unsupported_ops = []
+ self._pivot = pivot
+
+ def report_unsupported_operations(self):
+ if self._unsupported_ops:
+ op_str = '\n'.join([
+ ' %s (%s)' % (op.type, op.name)
+ for op in self._unsupported_ops[:_MAX_WARNING_LINES]
+ ])
+ logging.warning('%d unsupported operations found: \n%s',
+ len(self._unsupported_ops), op_str)
+ if len(self._unsupported_ops) > _MAX_WARNING_LINES:
+ logging.warning('... and %d more',
+ len(self._unsupported_ops) - _MAX_WARNING_LINES)
+
+ def AddOp(self, op):
+ """Create op in XLACompileContext and notifies outer context recursively."""
+ # pylint: disable=protected-access
+ if op.type in _BLACKLISTED_OPS:
+ logging.error(
+ 'Operation of type %s (%s) is not supported in XLA. Execution will '
+ 'fail if this op is used in the graph. ', op.type, op.name)
+
+ # TODO(ycao): Automatically disable summaries instead of reporting them.
+ if op.type in _UNSUPPORTED_OPS:
+ self._unsupported_ops.append(op)
+
+ if any(x.dtype._is_ref_dtype for x in op.inputs):
+ raise NotImplementedError(
+ 'Non-resource Variables are not supported inside XLA computations '
+ '(operator name: %s)' % op.name)
+
+ if _XLA_COMPILE_ATTR in op.node_def.attr:
+ raise ValueError('XLA compiled computations cannot be nested, (operator '
+ 'name: %s)' % op.name)
+
+ op._set_attr(
+ _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))
+
+ op.graph.prevent_feeding(op)
+ op.graph.prevent_fetching(op)
+
+ # Remove any control edges from outer control flow contexts. These may cause
+ # mismatched frame errors. An example is when one of op's inputs is
+ # generated in a different While control flow context.
+ (internal_control_inputs,
+ external_control_inputs) = self._RemoveExternalControlEdges(op)
+
+ if not op.inputs:
+ # Add a control edge from the control pivot to this op.
+ if not internal_control_inputs:
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot)
+ # pylint: enable=protected-access
+ else:
+ for index in xrange(len(op.inputs)):
+ x = op.inputs[index]
+ real_x = self.AddValue(x)
+ if real_x != x:
+ op._update_input(index, real_x) # pylint: disable=protected-access
+
+ if external_control_inputs:
+ # Use an identity to pull control inputs as data inputs. Note that we
+ # ignore ops which don't have outputs. TODO(phawkins): fix that.
+ with ops.control_dependencies(None):
+ self.Enter()
+ external_control_inputs = [
+ array_ops.identity(x.outputs[0]).op
+ for x in external_control_inputs
+ if x.outputs
+ ]
+ self.Exit()
+ # pylint: disable=protected-access
+ op._add_control_inputs(external_control_inputs)
+ # pylint: enable=protected-access
+
+ # Mark op's outputs as seen by this context and any outer contexts.
+ output_names = [x.name for x in op.outputs]
+ context = self
+ while context is not None:
+ # pylint: disable=protected-access
+ context._values.update(output_names)
+ context = context._outer_context
+ # pylint: enable=protected-access
+
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ def AddValue(self, val):
+ """Add `val` to the current context and its outer context recursively."""
+ if val.name in self._values:
+ # Use the real value if it comes from outer context.
+ result = self._external_values.get(val.name)
+ return val if result is None else result
+
+ result = val
+ self._values.add(val.name)
+ if self._outer_context:
+ result = self._outer_context.AddValue(val)
+ self._values.add(result.name)
+
+ self._external_values[val.name] = result
+
+ return result
+
+ def AddInnerOp(self, op):
+ self.AddOp(op)
+ if self._outer_context:
+ self._outer_context.AddInnerOp(op)
+
+ @property
+ def grad_state(self):
+ # Define the gradient loop state associated with the XLACompileContext to
+ # be None as the XLACompileContext does not get nested nor does the
+ # grad_state outside the XLACompileContext affect the graph inside so the
+ # grad_state should be as if this is the top-level gradient state.
+ return None
+
+ @property
+ def back_prop(self):
+ """Forwards to the enclosing while context, if any."""
+ if self.GetWhileContext():
+ return self.GetWhileContext().back_prop
+ return False
diff --git a/tensorflow/contrib/compiler/xla_test.py b/tensorflow/contrib/compiler/xla_test.py
new file mode 100644
index 0000000000..a306b56f63
--- /dev/null
+++ b/tensorflow/contrib/compiler/xla_test.py
@@ -0,0 +1,180 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Tests for contrib.compiler.xla."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.compiler import xla
+from tensorflow.python import summary
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import summary_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class XLACompileContextTest(test.TestCase):
+
+ def create_test_xla_compile_context(self):
+ computation_name = ops.get_default_graph().unique_name('computation')
+ pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
+ return xla.XLACompileContext(name=computation_name, pivot=pivot)
+
+ def test_report_unsupported_operations(self):
+ """Tests that unsupported operations are detected."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ dummy_tensor = constant_op.constant(1.1)
+ audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
+ histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
+ image_summary = summary.image('image_summary', dummy_tensor)
+ scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
+ tensor_summary = summary_ops.tensor_summary('tensor_summary', dummy_tensor)
+ summary.merge(
+ [
+ audio_summary, histogram_summary, image_summary, scalar_summary,
+ tensor_summary
+ ],
+ name='merge_summary')
+ logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
+ context.Exit()
+
+ unsupported_ops_names = [op.name for op in context._unsupported_ops]
+ self.assertEqual(unsupported_ops_names, [
+ u'audio_summary', u'histogram_summary', u'image_summary',
+ u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
+ u'print_op'
+ ])
+
+ def test_resource_variable(self):
+ """Tests that resource variable usage is allowed."""
+ a = variable_scope.get_variable(
+ name='variable_a', shape=(1), use_resource=True)
+
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ state_ops.assign(a, a + 1)
+ context.Exit()
+
+ def test_non_resource_variable_error(self):
+ """Tests that non-resource variable usage is disallowed."""
+ a = variable_scope.get_variable(
+ name='variable_a', shape=(1), use_resource=False)
+
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ with self.assertRaisesRegexp(
+ NotImplementedError, 'Non-resource Variables are not supported inside '
+ r'XLA computations \(operator name: Assign\)'):
+ state_ops.assign(a, a + 1)
+ context.Exit()
+
+ def test_nested_xla_compile_error(self):
+ """Tests that nested XLA computation leads to fatal error."""
+ context1 = self.create_test_xla_compile_context()
+ context1.Enter()
+
+ context2 = self.create_test_xla_compile_context()
+ context2.Enter()
+ with self.assertRaisesRegexp(ValueError,
+ 'XLA compiled computations cannot be nested'):
+ constant_op.constant(1)
+ context2.Exit()
+ context1.Exit()
+
+ def test_xla_compile_attr(self):
+ """Tests that ops are tagged with XLA compile ID attribute."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertIn('_xla_compile_id', op.op.node_def.attr)
+
+ def test_op_without_input(self):
+ """Tests that ops without inputs depend on pivot correctly."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+
+ self.assertIn(context._pivot, op.op.control_inputs)
+
+ def test_external_control_edges(self):
+ """Tests that external control edges are handled correctly."""
+ i = constant_op.constant(1)
+ op1 = constant_op.constant(1)
+
+ with ops.control_dependencies([op1]):
+ op2 = constant_op.constant(1)
+ self.assertIn(op1.op, op2.op.control_inputs)
+
+ def while_body(i):
+ del i # unused
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ with ops.control_dependencies([op1]):
+ op3 = constant_op.constant(1)
+ context.Exit()
+ self.assertNotIn(op1.op, op3.op.control_inputs)
+ return op3
+
+ control_flow_ops.while_loop(
+ cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
+
+ def test_op_output_marked_as_seen(self):
+ """Tests that any op output is marked as seen in context."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+
+ self.assertIn(op.name, context._values)
+
+ def testOpIsInContext(self):
+ """Tests that XLACompileContext is recognized as an XLA context."""
+ op1 = constant_op.constant(1)
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op2 = constant_op.constant(2)
+ context.Exit()
+ self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
+ self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
+
+ def testOpPreventFeeding(self):
+ """Tests that ops created inside XLACompileContext can not be fed."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertFalse(op.graph.is_feedable(op.op))
+
+ def testOpPreventFetching(self):
+ """Tests that ops created inside XLACompileContext can not be fetched."""
+ context = self.create_test_xla_compile_context()
+ context.Enter()
+ op = constant_op.constant(1)
+ context.Exit()
+ self.assertFalse(op.graph.is_fetchable(op.op))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
index 9b4bf62710..3e25079e02 100644
--- a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
@@ -75,7 +75,7 @@ class ExternalRegretOptimizerTest(test.TestCase):
multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1])
expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_multipliers1 = session.run(
external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
multipliers1, 1.0))
@@ -122,7 +122,7 @@ class ExternalRegretOptimizerTest(test.TestCase):
]
multipliers = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(multipliers) < len(expected_multipliers):
multipliers.append(session.run(optimizer.lagrange_multipliers))
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
index 34c4543dca..df0eced631 100644
--- a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
@@ -97,7 +97,7 @@ class SwapRegretOptimizerTest(test.TestCase):
matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]])
matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]])
- with self.test_session() as session:
+ with self.cached_session() as session:
eigenvector1 = session.run(
swap_regret_optimizer._maximal_eigenvector_power_method(
standard_ops.constant(matrix1)))
@@ -119,7 +119,7 @@ class SwapRegretOptimizerTest(test.TestCase):
expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9],
[0.4, 0.3, 0.0]])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_matrix = session.run(
swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm(
matrix))
@@ -134,7 +134,7 @@ class SwapRegretOptimizerTest(test.TestCase):
expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5],
[0.4, 0.5, 0.3]])
- with self.test_session() as session:
+ with self.cached_session() as session:
projected_matrix = session.run(
standard_ops.exp(
swap_regret_optimizer.
@@ -165,7 +165,7 @@ class SwapRegretOptimizerTest(test.TestCase):
]
matrices = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(matrices) < len(expected_matrices):
matrices.append(session.run(optimizer.stochastic_matrix))
@@ -198,7 +198,7 @@ class SwapRegretOptimizerTest(test.TestCase):
]
matrices = []
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(standard_ops.global_variables_initializer())
while len(matrices) < len(expected_matrices):
matrices.append(session.run(optimizer.stochastic_matrix))
diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
index 8cfe142059..556d731840 100644
--- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
+++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py
@@ -61,7 +61,7 @@ class CrfTest(test.TestCase):
for sequence_lengths, inputs, tag_indices in zip(sequence_lengths_list,
inputs_list,
tag_indices_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_indices=array_ops.expand_dims(tag_indices, 0),
@@ -96,7 +96,7 @@ class CrfTest(test.TestCase):
]
for sequence_lengths, inputs, tag_bitmap in zip(
sequence_lengths_list, inputs_list, tag_bitmap_list):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sequence_score = crf.crf_multitag_sequence_score(
inputs=array_ops.expand_dims(inputs, 0),
tag_bitmap=array_ops.expand_dims(tag_bitmap, 0),
@@ -124,7 +124,7 @@ class CrfTest(test.TestCase):
for dtype in (np.int32, np.int64):
tag_indices = np.array([1, 2, 1, 0], dtype=dtype)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
unary_score = crf.crf_unary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -140,7 +140,7 @@ class CrfTest(test.TestCase):
transition_params = np.array(
[[-3, 5, -2], [3, 4, 1], [1, 2, 1]], dtype=np.float32)
sequence_lengths = np.array(3, dtype=np.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
binary_score = crf.crf_binary_score(
tag_indices=array_ops.expand_dims(tag_indices, 0),
sequence_lengths=array_ops.expand_dims(sequence_lengths, 0),
@@ -176,7 +176,7 @@ class CrfTest(test.TestCase):
tag_indices_list):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
# Compare the dynamic program with brute force computation.
@@ -206,7 +206,7 @@ class CrfTest(test.TestCase):
"""
Test `crf_log_norm` when `sequence_lengths` contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
@@ -226,7 +226,7 @@ class CrfTest(test.TestCase):
sequence_lengths = np.array(3, dtype=np.int32)
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_log_likelihoods = []
# Make sure all probabilities sum to 1.
@@ -254,7 +254,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -310,7 +310,7 @@ class CrfTest(test.TestCase):
num_words = inputs.shape[0]
num_tags = inputs.shape[1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
all_sequence_scores = []
all_sequences = []
@@ -351,7 +351,7 @@ class CrfTest(test.TestCase):
"""
Test that crf_decode works when sequence_length contains one or more zeros.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = constant_op.constant(np.ones([2, 10, 5],
dtype=np.float32))
transition_params = constant_op.constant(np.ones([5, 5],
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 8bdbba83ef..9f710613dd 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -33,14 +33,22 @@ cc_library(
tf_custom_op_library(
name = "_dataset_ops.so",
- srcs = ["ops/dataset_ops.cc"],
- deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"] +
- if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
+ srcs = [
+ "ops/dataset_ops.cc",
+ "ops/indexed_dataset_ops.cc",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/kernels:dataset_kernels",
+ "//tensorflow/contrib/data/kernels:indexed_dataset",
+ ] + if_static(
+ extra_deps = [":lib_proto_parsing_for_dataset_ops"],
+ otherwise = [],
+ ),
)
tf_gen_op_libs(
- op_lib_names = ["dataset_ops"],
+ op_lib_names = [
+ "dataset_ops",
+ "indexed_dataset_ops",
+ ],
)
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 5821d51bca..baec238c62 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -25,6 +25,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@Counter
@@CheckpointInputPipelineHook
@@CsvDataset
+@@LMDBDataset
+@@Optional
@@RandomDataset
@@Reducer
@@SqlDataset
@@ -37,7 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@copy_to_device
@@dense_to_sparse_batch
@@enumerate_dataset
-
+@@get_next_as_optional
@@get_single_element
@@group_by_reducer
@@group_by_window
@@ -45,10 +47,10 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@make_batched_features_dataset
@@make_csv_dataset
@@make_saveable_from_iterator
-
@@map_and_batch
@@padded_batch_and_drop_remainder
@@parallel_interleave
+@@parse_example_dataset
@@prefetch_to_device
@@read_batch_features
@@rejection_resample
@@ -89,10 +91,12 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
+from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
from tensorflow.contrib.data.python.ops.random_ops import RandomDataset
from tensorflow.contrib.data.python.ops.readers import CsvDataset
+from tensorflow.contrib.data.python.ops.readers import LMDBDataset
from tensorflow.contrib.data.python.ops.readers import make_batched_features_dataset
from tensorflow.contrib.data.python.ops.readers import make_csv_dataset
from tensorflow.contrib.data.python.ops.readers import read_batch_features
@@ -103,6 +107,8 @@ from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD
index 2e249f5c14..ec6cb37193 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/contrib/data/kernels/BUILD
@@ -7,6 +7,31 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
cc_library(
+ name = "indexed_dataset_headers",
+ hdrs = ["indexed_dataset.h"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
+ name = "indexed_dataset",
+ srcs = [
+ "identity_indexed_dataset.cc",
+ "indexed_dataset.cc",
+ ],
+ deps = [
+ ":indexed_dataset_headers",
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "prefetching_kernels",
srcs = ["prefetching_kernels.cc"],
deps = [
@@ -52,6 +77,17 @@ cc_library(
)
cc_library(
+ name = "lmdb_dataset_op",
+ srcs = ["lmdb_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@lmdb",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+)
+
+cc_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
@@ -91,6 +127,8 @@ cc_library(
":csv_dataset_op",
":directed_interleave_dataset_op",
":ignore_errors_dataset_op",
+ ":indexed_dataset",
+ ":lmdb_dataset_op",
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
index e36c9c0634..c19a609780 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -150,4 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
AssertNextDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
index d242cfdf49..74107d5242 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/csv_dataset_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
+namespace data {
namespace {
class CSVDatasetOp : public DatasetOpKernel {
@@ -713,7 +714,7 @@ class CSVDatasetOp : public DatasetOpKernel {
component.scalar<string>()() =
dataset()->record_defaults_[output_idx].flat<string>()(0);
} else {
- component.scalar<string>()() = field.ToString();
+ component.scalar<string>()() = string(field);
}
break;
}
@@ -851,4 +852,5 @@ class CSVDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
index ccf7ec1f84..a5321620bf 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -276,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
new file mode 100644
index 0000000000..c3cb45dbf7
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
@@ -0,0 +1,155 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
+ public:
+ using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
+
+ void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) override {
+ uint64 size = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
+ OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
+ *output = new Dataset(ctx, size);
+ }
+
+ class Dataset : public IndexedDataset {
+ public:
+ Dataset(OpKernelContext* ctx, uint64 size)
+ : IndexedDataset(DatasetContext(ctx)), size_(size) {}
+
+ Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+ materialized->reset(new Materialized(this));
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
+ }
+
+ string DebugString() const override {
+ return "IdentityIndexedDataset::Dataset";
+ }
+
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** node) const override {
+ return errors::Unimplemented(
+ "identity_indexed_dataset.AsGraphDefInternal");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (cur_ < dataset()->size_) {
+ Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = cur_++;
+ out_tensors->emplace_back(std::move(result_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ uint64 cur_ GUARDED_BY(mu_);
+ };
+
+ class Materialized : public MaterializedIndexedDataset {
+ public:
+ explicit Materialized(Dataset* dataset) : dataset_(dataset) {
+ dataset->Ref();
+ }
+
+ ~Materialized() override {
+ // TODO(saeta): Pull this into MaterializedIndexedDataset
+ dataset_->Unref();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return dataset_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return dataset_->output_shapes();
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const override {
+ LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
+ << ")";
+ if (index >= dataset_->size_) {
+ // Note: use InvalidArgument instead of OutOfRange error because many
+ // things consider OutOfRange to be a "clean termination" error.
+ return errors::InvalidArgument(
+ "Index ", index,
+ " is out of range for this dataset. (Size is: ", dataset_->size_,
+ ".)");
+ }
+ Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = index;
+ out_tensors->emplace_back(std::move(result_tensor));
+ return Status::OK();
+ }
+
+ Status Size(uint64* size) const override {
+ *size = dataset_->size_;
+ return Status::OK();
+ }
+
+ private:
+ const Dataset* const dataset_; // Not owned.
+ };
+
+ const uint64 size_;
+ std::shared_ptr<Materialized> materialized_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
index db24e60846..beec344534 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -137,5 +137,5 @@ REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
IgnoreErrorsDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/contrib/data/kernels/indexed_dataset.cc
new file mode 100644
index 0000000000..ced8ab0d60
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/indexed_dataset.cc
@@ -0,0 +1,373 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
+class MaterializedDatasetResource : public ResourceBase {
+ public:
+ MaterializedDatasetResource(
+ const DataTypeVector& output_dtypes,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
+
+ string DebugString() override {
+ return "Materialized IndexedDataset resource";
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) {
+ std::shared_ptr<MaterializedIndexedDataset> captured(materialized_);
+ if (captured) {
+ return captured->Get(std::move(ctx), index, out_tensors);
+ } else {
+ return errors::FailedPrecondition(
+ "Get() failed because the MaterializedIndexedDataset has not been "
+ "initialized. Ensure that you have run the materialization operation "
+ "for this MaterializedIndexedDataset before retrieving elements.");
+ }
+ }
+
+ // TODO(saeta): Implement Save and Restore
+
+ const DataTypeVector& output_dtypes() const { return output_dtypes_; }
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ Status set_materialized_dataset(
+ const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
+ if (dataset) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
+ }
+ materialized_ = dataset;
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<MaterializedIndexedDataset> materialized_;
+ const DataTypeVector output_dtypes_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
+// tensor. Objects of the wrapper class own a reference on an instance of an
+// `IndexedTensor` and the wrapper's copy constructor and desctructor take care
+// of managing the reference count.
+//
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
+// implemented.
+//
+// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
+// use `tensorflow::DatasetVariantWrapper`.
+class IndexedDatasetVariantWrapper {
+ public:
+ IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
+
+ // Transfers ownership of `dataset` to `*this`.
+ explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
+ : dataset_(dataset) {}
+
+ IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
+ : dataset_(other.dataset_) {
+ if (dataset_) dataset_->Ref();
+ }
+
+ ~IndexedDatasetVariantWrapper() {
+ if (dataset_) dataset_->Unref();
+ }
+
+ IndexedDataset* get() const { return dataset_; }
+
+ string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
+ string DebugString() const {
+ if (dataset_) {
+ return dataset_->DebugString();
+ } else {
+ return "<Uninitialized IndexedDatasetVariantWrapper>";
+ }
+ }
+
+ void Encode(VariantTensorData* data) const {
+ LOG(ERROR) << "The Encode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ LOG(ERROR) << "The Decode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ return false;
+ }
+
+ private:
+ IndexedDataset* const dataset_; // Owns one reference.
+};
+
+} // namespace
+
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset) {
+ if (!(tensor.dtype() == DT_VARIANT ||
+ TensorShapeUtils::IsScalar(tensor.shape()))) {
+ return errors::InvalidArgument(
+ "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ const Variant& variant = tensor.scalar<Variant>()();
+ const IndexedDatasetVariantWrapper* wrapper =
+ variant.get<IndexedDatasetVariantWrapper>();
+ if (wrapper == nullptr) {
+ return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
+ }
+ *out_dataset = wrapper->get();
+ if (*out_dataset == nullptr) {
+ return errors::Internal("Read uninitialized IndexedDataset variant.");
+ }
+ return Status::OK();
+}
+
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor) {
+ if (!(tensor->dtype() == DT_VARIANT ||
+ TensorShapeUtils::IsScalar(tensor->shape()))) {
+ return errors::InvalidArgument(
+ "Dataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
+ return Status::OK();
+}
+
+void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
+ IndexedDataset* dataset = nullptr;
+ MakeIndexedDataset(ctx, &dataset);
+
+ if (ctx->status().ok()) {
+ OP_REQUIRES(ctx, dataset != nullptr,
+ errors::Internal("MakeIndexedDataset did not correctly "
+ "construct the IndexedDataset"));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
+ }
+}
+
+namespace {
+
+class MaterializedHandleOp : public OpKernel {
+ public:
+ explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ ~MaterializedHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MaterializedDatasetResource* resource;
+ OP_REQUIRES_OK(context,
+ mgr->LookupOrCreate<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this](MaterializedDatasetResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MaterializedDatasetResource(
+ output_dtypes_, output_shapes_);
+ return Status::OK();
+ }));
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MaterializedDatasetResource>()));
+ }
+
+ private:
+ // During the first Compute(), resource is either created or looked up using
+ // shared_name. In the latter case, the resource found should be verified if
+ // it is compatible with this op's configuration. The verification may fail in
+ // cases such as two graphs asking queues of the same shared name to have
+ // inconsistent capacities.
+ Status VerifyResource(MaterializedDatasetResource* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+// TODO(saeta): Make async.
+class MaterializeDatasetOp : public OpKernel {
+ public:
+ explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IndexedDataset* dataset;
+ OP_REQUIRES_OK(ctx,
+ GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &materialized_resource));
+ core::ScopedUnref unref(materialized_resource);
+ std::shared_ptr<MaterializedIndexedDataset> materialized;
+ OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized));
+ OP_REQUIRES_OK(
+ ctx, materialized_resource->set_materialized_dataset(materialized));
+ }
+};
+
+// TODO(saeta): Make async
+class IndexedDatasetGet : public OpKernel {
+ public:
+ explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
+ &materialized_resource));
+ auto cleanup = gtl::MakeCleanup([materialized_resource] {
+ materialized_resource->Unref(); // Note: can't use core::ScopedUnref.
+ });
+
+ const Tensor* index_t;
+ OP_REQUIRES_OK(ctx, ctx->input("index", &index_t));
+ // TODO(saeta): Support batch reads (indexes should be non-scalar!)
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()),
+ errors::InvalidArgument("index must be a scalar"));
+ const uint64 index = index_t->scalar<uint64>()();
+
+ std::vector<Tensor> out_tensors;
+ Status s =
+ materialized_resource->Get(IteratorContext(ctx), index, &out_tensors);
+
+ // Note: Unref materialized_resource to avoid destruction races. (Important
+ // in a [future] async op implementation.)
+ cleanup.release()();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else {
+ auto expected_shapes = materialized_resource->output_shapes();
+ auto expected_types = materialized_resource->output_dtypes();
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+ OP_REQUIRES(
+ ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()),
+ errors::Internal(
+ "Materialized dataset output at index ", i,
+ " is incompatible with the expected shape. (Expected: ",
+ expected_shapes[i], ", got: ", out_tensors[i].shape(), ")"));
+ OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i],
+ errors::Internal("Materialized dataset output at index ", i,
+ " was not the expected dtype. (Expected: ",
+ expected_types[i],
+ ", got: ", out_tensors[i].dtype(), ")"));
+ ctx->set_output(i, out_tensors[i]);
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ MaterializedHandleOp);
+REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/contrib/data/kernels/indexed_dataset.h
new file mode 100644
index 0000000000..7aa2d3fdbc
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/indexed_dataset.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace data {
+
+// TODO(saeta): Urgh, this is ugly.
+class MaterializedIndexedDataset {
+ public:
+ virtual ~MaterializedIndexedDataset() = default;
+
+ // Retrieve the element at a given index. The output tensors are stored in
+ // out_tensors.
+ //
+ // If `index` is greater than `Size()`, tensorflow::errors::OutOfRangeError is
+ // returned.
+ //
+ // Get is thread-safe.
+ virtual Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const = 0;
+
+ // Size determines the number of elements in this IndexedDataset.
+ //
+ // Size is thread-safe.
+ virtual Status Size(uint64* size) const = 0;
+
+ // Returns a vector of DataType values, representing the respective
+ // element types of each tuple component in the outputs of this dataset.
+ virtual const DataTypeVector& output_dtypes() const = 0;
+
+ // Returns a vector of tensor shapes, representing the respective
+ // (and possibly partially defined) shapes of each tuple component
+ // in the outputs of this dataset.
+ virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+};
+
+// IndexedDataset represents a dataset that supports random access in addition
+// to iterator-based sequential access.
+//
+// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
+// significant (backwards incompatible) changes!
+class IndexedDataset : public DatasetBase {
+ public:
+ IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
+
+ // Materialize (if necessary) the dataset, and return a pointer.
+ // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
+ virtual Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
+};
+
+// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
+// rest of the TensorFlow runtime.
+//
+// Most IndexedDataset's will be private members of classes inheriting from this
+// class.
+class IndexedDatasetOpKernel : public OpKernel {
+ public:
+ IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) final;
+
+ protected:
+ // Subclasses should implement this method. It will be called during Compute
+ // execution.
+ virtual void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) = 0;
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+};
+
+// Validates and extracts an `IndexedDataset` object from `tensor`.
+//
+// `tensor` must have been written by a call to
+// `StoreIndexedDatasetInVariantTensor`
+//
+// The retrieved pointer isa borrowed reference to the dataset, which is owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset);
+
+// Stores an `IndexedDataset` object in `tensor.`
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..d233c1f8ec
--- /dev/null
+++ b/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
@@ -0,0 +1,217 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+#include "lmdb.h" // NOLINT(build/include)
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class LMDBDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
+ : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ do {
+ if (mdb_cursor_) {
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = string(
+ static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() =
+ string(static_cast<const char*>(mdb_value_.mv_data),
+ mdb_value_.mv_size);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+ int val;
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ private:
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+ const string& filename = dataset()->filenames_[current_file_index_];
+
+ int val = mdb_env_create(&mdb_env_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
+ struct stat source_stat;
+ if (stat(filename.c_str(), &source_stat) == 0 &&
+ (source_stat.st_mode & S_IFREG)) {
+ flags |= MDB_NOSUBDIR;
+ }
+ val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ }
+ return Status::OK();
+ }
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (mdb_env_ != nullptr) {
+ if (mdb_cursor_) {
+ mdb_cursor_close(mdb_cursor_);
+ mdb_cursor_ = nullptr;
+ }
+ mdb_dbi_close(mdb_env_, mdb_dbi_);
+ mdb_txn_abort(mdb_txn_);
+ mdb_env_close(mdb_env_);
+ mdb_txn_ = nullptr;
+ mdb_dbi_ = 0;
+ mdb_env_ = nullptr;
+ }
+ }
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
+ MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
+ MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
+ MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
+
+ MDB_val mdb_key_ GUARDED_BY(mu_);
+ MDB_val mdb_value_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
index 74df1e42a8..078de717e0 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
+namespace data {
namespace {
struct BufferElement {
@@ -548,7 +549,9 @@ class MultiDeviceIterator : public ResourceBase {
devices_(devices),
flib_def_(std::move(flib_def)),
pflr_(std::move(pflr)),
- lib_(lib) {}
+ lib_(lib) {
+ CHECK_NOTNULL(lib_);
+ }
string DebugString() override {
return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
@@ -600,6 +603,11 @@ class MultiDeviceIterator : public ResourceBase {
return lib_def_;
}
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
private:
// A private class that uses a background thread to keep a per device buffer
// full.
@@ -930,8 +938,10 @@ class MultiDeviceIteratorInitOp : public OpKernel {
core::ScopedUnref unref(resource);
std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx, dataset->MakeIterator(IteratorContext(ctx), "Iterator",
- &iterator));
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
int64 incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
&incarnation_id));
@@ -1105,5 +1115,6 @@ REGISTER_KERNEL_BUILDER(
Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
MultiDeviceIteratorFromStringHandleOp);
-} // anonymous namespace
+} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
index ab584504a0..30fa97a636 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
+namespace data {
namespace {
class ThreadPoolResource : public ResourceBase {
@@ -214,4 +215,5 @@ REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
ThreadPoolDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
index 6fbf5d2ebb..57fc5697a4 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -219,5 +219,5 @@ REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc
index cc5e250ea1..ae104d55bd 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/contrib/data/ops/dataset_ops.cc
@@ -266,4 +266,13 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("LMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
new file mode 100644
index 0000000000..cd9b7c68a0
--- /dev/null
+++ b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
@@ -0,0 +1,80 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("IdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("MaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materialize the materialize handle.
+REGISTER_OP("IndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("IndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn)
+ .Doc(R"doc(
+Gets the element at `index` from `materialized` IndexedDataset.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index cd46e382eb..6f0111a2bd 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -4,7 +4,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("//tensorflow:tensorflow.bzl", "cuda_py_test", "py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
name = "batch_dataset_op_test",
@@ -134,6 +135,21 @@ py_test(
)
py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:contrib_op_loader",
+ "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
+ "//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "interleave_dataset_op_test",
size = "medium",
srcs = ["interleave_dataset_op_test.py"],
@@ -179,6 +195,31 @@ py_test(
)
py_test(
+ name = "lmdb_dataset_op_test",
+ size = "medium",
+ srcs = ["lmdb_dataset_op_test.py"],
+ data = ["//tensorflow/core:lmdb_testdata"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "map_dataset_op_test",
size = "medium",
srcs = ["map_dataset_op_test.py"],
@@ -205,6 +246,25 @@ py_test(
)
py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "map_defun_op_test",
size = "small",
srcs = ["map_defun_op_test.py"],
@@ -219,29 +279,30 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
],
)
py_test(
- name = "optimize_dataset_op_test",
+ name = "parsing_ops_test",
size = "small",
- srcs = ["optimize_dataset_op_test.py"],
+ srcs = ["parsing_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":stats_dataset_test_base",
- ":test_utils",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:check_ops",
+ "//tensorflow/contrib/data/python/ops:parsing_ops",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
],
)
@@ -331,6 +392,7 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
index 42adfd17f0..8e368bf2bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
@@ -57,7 +57,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -85,7 +85,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for start in range(0, len(components), 4):
@@ -123,7 +123,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize with an input tensor of incompatible rank.
sess.run(init_op, feed_dict={input_tensor: [[1]]})
with self.assertRaisesRegexp(errors.InvalidArgumentError,
@@ -148,7 +148,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i,) * 3, sess.run(op))
@@ -168,7 +168,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
@@ -187,7 +187,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
st_row = sess.run(next_element)
self.assertEqual([i], st_row.indices)
@@ -208,7 +208,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
dense_elem, st_row = sess.run(next_element)
self.assertEqual(i, dense_elem)
@@ -230,7 +230,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i,),) * 3, sess.run(op))
@@ -250,7 +250,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
op = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
sess.run(op))
@@ -266,7 +266,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.OutOfRangeError):
sess.run(next_element)
@@ -284,7 +284,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = data.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Mismatch in the 0th dimension.
sess.run(
iterator.initializer,
@@ -319,7 +319,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -343,7 +343,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -374,7 +374,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_batch_size in [1, 3, 7, 10]:
sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
num_batches = 7 // test_batch_size
@@ -428,10 +428,10 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
@parameterized.named_parameters(
- ("default", None, None),
- ("sequential_calls", 1, None),
- ("parallel_calls", 2, None),
- ("parallel_batches", None, 10),
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
)
def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
"""Test a dataset that maps a TF function across its input elements."""
@@ -461,7 +461,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Batch of a finite input, where the batch_size divides the
# total number of elements.
sess.run(init_op, feed_dict={count: 28, batch_size: 14})
@@ -505,8 +505,8 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op, feed_dict={count: 14, batch_size: 0})
@parameterized.named_parameters(
- ("even", False),
- ("uneven", True),
+ ("Even", False),
+ ("Uneven", True),
)
def testMapAndBatchPartialBatch(self, drop_remainder):
iterator = (
@@ -520,7 +520,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
if not drop_remainder:
@@ -535,7 +535,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_one_shot_iterator())
self.assertEqual([None, 1], iterator.output_shapes.as_list())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
self.assertAllEqual([[64], [81]], sess.run(next_element))
@@ -549,7 +549,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -569,7 +569,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
elements = []
for _ in range(100):
elements.append(iterator.get_next())
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(4):
got = sess.run(elements)
got.sort(key=lambda x: x[0])
@@ -591,7 +591,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(2):
actual = sess.run(get_next)
@@ -614,7 +614,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
.make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
sess.run(init_op, feed_dict={batch_size: 14})
@@ -635,7 +635,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(errors.InvalidArgumentError,
"number of elements does not match"):
@@ -659,11 +659,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(3):
sess.run(get_next)
- @parameterized.parameters(0, 5, 10, 90, 95, 99)
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
def testMapAndBatchOutOfRangeError(self, threshold):
def raising_py_fn(i):
@@ -679,7 +686,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
batch_size=10)).make_one_shot_iterator())
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(threshold // 10):
self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
if threshold % 10 != 0:
@@ -689,18 +696,18 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (False, dtypes.bool),
- (-42, dtypes.int8),
- (-42, dtypes.int16),
- (-42, dtypes.int32),
- (-42, dtypes.int64),
- (42, dtypes.uint8),
- (42, dtypes.uint16),
- (42.0, dtypes.float16),
- (42.0, dtypes.float32),
- (42.0, dtypes.float64),
- (b"hello", dtypes.string),
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
)
def testMapAndBatchTypes(self, element, dtype):
def gen():
@@ -711,7 +718,7 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(10):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
@@ -720,6 +727,42 @@ class RestructuredDatasetTest(test.TestCase):
def test_assert_element_shape(self):
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
def create_unknown_shape_dataset(x):
return script_ops.py_func(
lambda _: ( # pylint: disable=g-long-lambda
@@ -741,6 +784,59 @@ class RestructuredDatasetTest(test.TestCase):
iterator = result.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
with self.test_session() as sess:
sess.run(init_op)
for _ in range(5):
@@ -748,7 +844,7 @@ class RestructuredDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def test_assert_wrong_element_shape(self):
+ def test_assert_wrong_partial_element_shape(self):
def create_dataset(_):
return (array_ops.ones(2, dtype=dtypes.float32),
@@ -756,11 +852,41 @@ class RestructuredDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.range(3).map(create_dataset)
wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
+ tensor_shape.TensorShape((None, 10)))
with self.assertRaises(ValueError):
dataset.apply(batching.assert_element_shape(wrong_shapes))
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
def create_unknown_shape_dataset(x):
return script_ops.py_func(
@@ -776,13 +902,13 @@ class RestructuredDatasetTest(test.TestCase):
self.assertEqual(unknown_shapes, dataset.output_shapes)
wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
+ tensor_shape.TensorShape((None, 10)))
iterator = (
dataset.apply(batching.assert_element_shape(wrong_shapes))
.make_initializable_iterator())
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
index 2022c1f2bd..293be2bd06 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
@@ -40,7 +40,7 @@ class GroupByReducerTest(test.TestCase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for expected in values:
got = sess.run(get_next)
self.assertEqual(got, expected)
@@ -129,7 +129,7 @@ class GroupByReducerTest(test.TestCase):
self.assertIs(None, dataset.output_shapes[1].ndims)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual([0] * (2**i), x)
self.assertAllEqual(np.array(1, ndmin=i), y)
@@ -192,7 +192,7 @@ class GroupByReducerTest(test.TestCase):
(dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x, y = sess.run(get_next)
self.assertAllEqual(x, np.asarray([x for x in range(10)]))
self.assertEqual(y, 45)
@@ -210,7 +210,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -237,7 +237,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# The input is infinite, so this test demonstrates that:
# 1. We produce output without having to consume the entire input,
@@ -258,7 +258,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
@@ -275,7 +275,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
@@ -301,7 +301,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -329,7 +329,7 @@ class GroupByWindowTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
counts = []
with self.assertRaises(errors.OutOfRangeError):
@@ -376,7 +376,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
which_bucket, bucketed_values = sess.run(get_next)
@@ -411,7 +411,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches (one containing even values, one containing odds)
@@ -482,7 +482,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Get two minibatches ([0, 2, ...] and [64, 66, ...])
@@ -515,7 +515,7 @@ class BucketTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
with self.assertRaises(errors.OutOfRangeError):
batches = 0
@@ -556,7 +556,7 @@ class BucketBySequenceLength(test.TestCase):
element_len, boundaries, batch_sizes))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(4):
batches.append(sess.run(batch))
@@ -600,7 +600,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(3):
batches.append(sess.run(batch))
@@ -637,7 +637,7 @@ class BucketBySequenceLength(test.TestCase):
pad_to_bucket_boundary=True))
batch, = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batches = []
for _ in range(5):
batches.append(sess.run(batch))
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
index 9b1857de1a..eb110324d1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
@@ -38,7 +38,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(100):
for i in range(10):
@@ -67,7 +67,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
freqs = np.zeros([num_datasets])
for _ in range(num_samples):
freqs[sess.run(next_element)] += 1
@@ -84,7 +84,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
# Use chi-squared test to assert that the observed distribution matches the
# expected distribution. Based on the implementation in
# "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs]:
+ for probs in [[.85, .05, .1], rand_probs, [1.]]:
probs = np.asarray(probs)
classes = len(probs)
freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
@@ -104,7 +104,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in choice_array:
self.assertEqual(words[i], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
new file mode 100644
index 0000000000..6d01bf585c
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks FilterDataset input pipeline op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FilterBenchmark(test.Benchmark):
+
+ # This benchmark compares the performance of pipeline with multiple chained
+ # filter with and without filter fusion.
+ def benchmarkFilters(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkFilters(chain_length, False)
+ self._benchmarkFilters(chain_length, True)
+
+ def _benchmarkFilters(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Filter dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_filter_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
index e6883d53e0..f3968cdc15 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
@@ -53,7 +53,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
lambda x: (x * x, make_sparse(x))).take(take_t)
element = get_single_element.get_single_element(dataset)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
if error is None:
dense_val, sparse_val = sess.run(
element, feed_dict={
@@ -90,7 +90,7 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
dataset = dataset_ops.Dataset.range(stop_t)
element = get_single_element.reduce_dataset(dataset, sum_reducer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sess.run(element, feed_dict={stop_t: stop})
self.assertEqual(stop * (stop - 1) / 2, value)
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
new file mode 100644
index 0000000000..9c508d686d
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental indexed dataset ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
+from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class IndexedDatasetOpsTest(test.TestCase):
+
+ def testLowLevelIndexedDatasetOps(self):
+ identity = gen_dataset_ops.identity_indexed_dataset(
+ ops.convert_to_tensor(16, dtype=dtypes.uint64))
+ handle = gen_dataset_ops.materialized_index_dataset_handle(
+ container="",
+ shared_name="",
+ output_types=[dtypes.uint64],
+ output_shapes=[[]])
+ materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ index = array_ops.placeholder(dtypes.uint64)
+ get_op = gen_dataset_ops.indexed_dataset_get(
+ handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
+
+ with self.cached_session() as sess:
+ sess.run(materialize)
+ self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
+
+ def testIdentityIndexedDataset(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ materialized = ds.materialize()
+ with self.cached_session() as sess:
+ sess.run(materialized.initializer)
+ placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
+ for i in range(16):
+ output = sess.run(
+ materialized.get(placeholder), feed_dict={placeholder: i})
+ self.assertEqual([i], output)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(materialized.get(placeholder), feed_dict={placeholder: 16})
+
+ @unittest.skip("Requisite functionality currently unimplemented.")
+ def testIdentityIndexedDatasetIterator(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.cached_session() as sess:
+ sess.run(itr.initializer)
+ for i in range(16):
+ output = sess.run(n)
+ self.assertEqual(i, output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
index 7a3215f6cc..b9e74dfddb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
@@ -177,7 +177,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -212,7 +212,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -242,7 +242,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContention(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -286,7 +286,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -328,7 +328,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -373,7 +373,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
Args:
sloppy: Whether to be sloppy or not.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -413,7 +413,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
def _testEmptyInput(self, sloppy=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Empty input.
self._clear_coordination_events()
sess.run(
@@ -437,7 +437,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
# Non-empty input leading to empty output.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -461,7 +461,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
# Mixture of non-empty and empty interleaved datasets.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -500,7 +500,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -525,7 +525,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(self.next_element)
def testBlockLengthWithContentionSloppy(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
done_first_event = False
sess.run(
@@ -560,7 +560,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
def _testEarlyExit(self, sloppy=False):
# Exiting without consuming all input should not block
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -604,7 +604,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
iterator = dataset.make_one_shot_iterator()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_values = []
for _ in range(30):
output_values.append(sess.run(iterator.get_next()))
@@ -635,7 +635,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for i in range(10):
for j in range(2):
@@ -645,7 +645,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
sess.run(get_next)
def testErrorsInOutputFn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self._clear_coordination_events()
sess.run(
self.init_op,
@@ -704,7 +704,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -753,7 +753,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
self.init_op = self.iterator.initializer
self.next_element = self.iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.init_op,
feed_dict={
@@ -792,7 +792,7 @@ class ParallelInterleaveDatasetTest(test.TestCase):
next_element = iterator.get_next()
results = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2):
elements = []
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
new file mode 100644
index 0000000000..1cc5ddc9a2
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -0,0 +1,66 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LMDBDatasetOp."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import shutil
+
+from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+prefix_path = "tensorflow/core/lib"
+
+
+class LMDBDatasetTest(test.TestCase):
+
+ def setUp(self):
+ super(LMDBDatasetTest, self).setUp()
+ # Copy database out because we need the path to be writable to use locks.
+ path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb")
+ self.db_path = os.path.join(self.get_temp_dir(), "data.mdb")
+ shutil.copy(path, self.db_path)
+
+ def testReadFromFile(self):
+ filename = self.db_path
+
+ filenames = constant_op.constant([filename], dtypes.string)
+ num_repeats = 2
+
+ dataset = readers.LMDBDataset(filenames).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(num_repeats): # Dataset is repeated.
+ for i in range(10): # 10 records.
+ k = compat.as_bytes(str(i))
+ v = compat.as_bytes(str(chr(ord("a") + i)))
+ self.assertEqual((k, v), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
index dc9d56dd53..e8519381d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
@@ -54,7 +54,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -72,7 +72,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for x in [1., 2., 3., 5.]:
self.assertEqual(x, sess.run(get_next))
@@ -99,7 +99,7 @@ class MapDatasetTest(test.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# All of the files are present.
sess.run(init_op)
for filename in filenames:
@@ -209,7 +209,7 @@ class MapDatasetBenchmark(test.Benchmark):
end = time.time()
chained_deltas.append(end - start)
- fused_dataset = dataset = dataset.apply(
+ fused_dataset = dataset.apply(
batching.map_and_batch(
math_ops.matmul,
num_parallel_calls=num_calls,
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
index 73cde40305..61567bc8d7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
@@ -17,7 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import time
+
from tensorflow.contrib.data.python.ops import map_defun
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,10 +28,10 @@ from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class MapDefunTest(test.TestCase):
def testMapDefunSimple(self):
@@ -130,6 +133,121 @@ class MapDefunTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(result)
+ def testMapDefunCancelledCorrectly(self):
+
+ @function.Defun(dtypes.int64)
+ def defun(x):
+ # x has leading dimension 5, this will raise an error
+ return array_ops.gather(x, 10)
+
+ c = array_ops.tile(
+ array_ops.expand_dims(
+ constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
+ [100, 1])
+ map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(map_defun_op)
+
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
new file mode 100644
index 0000000000..459bdf66f3
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
@@ -0,0 +1,88 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "assert_next_dataset_op_test",
+ size = "medium",
+ srcs = ["assert_next_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "latency_all_edges_test",
+ size = "small",
+ srcs = ["latency_all_edges_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/contrib/data/python/ops:stats_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_vectorization_test",
+ size = "small",
+ srcs = ["map_vectorization_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/kernel_tests:test_utils",
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_and_filter_fusion_test",
+ size = "medium",
+ srcs = ["map_and_filter_fusion_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_op_test",
+ size = "small",
+ srcs = ["optimize_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/data/python/ops:optimization",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000000..bd7b50b902
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test.TestCase):
+
+ def testAssertNext(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertNextInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."):
+ sess.run(get_next)
+
+ def testAssertNextShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
new file mode 100644
index 0000000000..db380c02a9
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the LatencyAllEdges optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testLatencyStatsOptimization(self):
+
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.from_tensors(1).apply(
+ optimization.assert_next(
+ ["LatencyStats", "Map", "LatencyStats", "Prefetch",
+ "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
+ optimization.optimize(["latency_all_edges"])).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertEqual(1 * 1, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_TensorDataset/_1", 1)
+ self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
+ 1)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_PrefetchDataset/_6", 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
new file mode 100644
index 0000000000..dde115925e
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -0,0 +1,224 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndFilterFusion optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ functions = [identity, increment, increment_and_square]
+ tests = []
+ for i, fun1 in enumerate(functions):
+ for j, fun2 in enumerate(functions):
+ tests.append((
+ "Test{}{}".format(i, j),
+ [fun1, fun2],
+ ))
+ for k, fun3 in enumerate(functions):
+ tests.append((
+ "Test{}{}{}".format(i, j, k),
+ [fun1, fun2, fun3],
+ ))
+
+ swap = lambda x, n: (n, x)
+ tests.append((
+ "Swap1",
+ [lambda x: (x, 42), swap],
+ ))
+ tests.append((
+ "Swap2",
+ [lambda x: (x, 42), swap, swap],
+ ))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapFusion(self, functions):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Prefetch"]))
+ for function in functions:
+ dataset = dataset.map(function)
+
+ dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ r = x
+ for function in functions:
+ if isinstance(r, tuple):
+ r = function(*r) # Pass tuple as multiple arguments.
+ else:
+ r = function(r)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @staticmethod
+ def map_and_filter_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+ minus_five = lambda x: x - 5
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ is_odd = lambda x: math_ops.equal(x % 2, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ functions = [identity, increment, minus_five, increment_and_square]
+ filters = [take_all, is_zero, is_odd, greater]
+ tests = []
+
+ for x, fun in enumerate(functions):
+ for y, predicate in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
+
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ lambda x, y: constant_op.constant(True)))
+ tests.append(
+ ("Multi2", lambda x: (x, 2),
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_and_filter_functions.__func__())
+ def testMapFilterFusion(self, function, predicate):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map",
+ "FilterByLastComponent"])).map(function).filter(predicate).apply(
+ optimization.optimize(["map_and_filter_fusion"]))
+ self._testMapAndFilter(dataset, function, predicate)
+
+ def _testMapAndFilter(self, dataset, function, predicate):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(10):
+ r = function(x)
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if sess.run(b):
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(3, dtype=dtypes.int64)
+ b = constant_op.constant(4, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+ function = lambda x: x * x
+
+ def predicate(y):
+ return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)
+
+ # We are currently not supporting functions with additional inputs.
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Filter"])).map(function).filter(predicate).apply(
+ optimization.optimize(["map_and_filter_fusion"]))
+
+ self._testMapAndFilter(dataset, function, predicate)
+
+ @staticmethod
+ def filter_functions():
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ tests = []
+ filters = [take_all, is_zero, greater]
+ identity = lambda x: x
+ for x, predicate_1 in enumerate(filters):
+ for y, predicate_2 in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), identity,
+ [predicate_1, predicate_2]))
+ for z, predicate_3 in enumerate(filters):
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
+ [predicate_1, predicate_2, predicate_3]))
+
+ take_all_multiple = lambda x, y: constant_op.constant(True)
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ [take_all_multiple, take_all_multiple]))
+ tests.append(("Multi2", lambda x: (x, 2), [
+ take_all_multiple,
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+ ]))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*filter_functions.__func__())
+ def testFilterFusion(self, map_function, predicates):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Filter",
+ "Prefetch"])).map(map_function)
+ for predicate in predicates:
+ dataset = dataset.filter(predicate)
+
+ dataset = dataset.prefetch(0).apply(
+ optimization.optimize(["filter_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.test_session() as sess:
+ for x in range(5):
+ r = map_function(x)
+ filtered = False
+ for predicate in predicates:
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if not sess.run(b):
+ filtered = True
+ break
+
+ if not filtered:
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
new file mode 100644
index 0000000000..e2c9bc82df
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
@@ -0,0 +1,219 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapVectorization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests import test_utils
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+
+ def _get_test_datasets(self,
+ base_dataset,
+ map_fn,
+ num_parallel_calls=None,
+ expect_optimized=True):
+ """Given base dataset and map fn, creates test datasets.
+
+ Returns a tuple of (unoptimized, dataset, optimized dataset). The
+ unoptimized dataset has the assertion that Batch follows Map. The optimized
+ dataset has the assertion that Map follows Batch, and has the
+ "map_vectorization" optimization applied.
+
+ Args:
+ base_dataset: Input dataset to map->batch
+ map_fn: Map function to use
+ num_parallel_calls: (Optional.) num_parallel_calls argument for map
+ expect_optimized: (Optional.) Whether we expect the optimization to take
+ place, in which case we will assert that Batch is followed by Map,
+ otherwise Map followed by Batch. Defaults to True.
+
+ Returns:
+ Tuple of (unoptimized dataset, optimized dataset).
+ """
+ map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
+ batch_size = 100
+
+ def _make_dataset(node_names):
+ return base_dataset.apply(optimization.assert_next(node_names)).map(
+ map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
+
+ unoptimized = _make_dataset([map_node_name, "Batch"])
+ optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
+ [map_node_name, "Batch"]).apply(
+ optimization.optimize(["map_vectorization"]))
+
+ return unoptimized, optimized
+
+ @parameterized.named_parameters(
+ ("Basic", lambda x: (x, x + 1), None),
+ ("Parallel", lambda x: (x, x + 1), 12),
+ ("Gather", lambda x: array_ops.gather(x, 0), 12),
+ )
+ def testOptimization(self, map_fn, num_parallel_calls):
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+ num_parallel_calls)
+ self._assert_datasets_equal(unoptimized, optimized)
+
+ def testOptimizationBadMapFn(self):
+ # Test map functions that give an error
+ def map_fn(x):
+ # x has leading dimension 5, this will raise an error
+ return array_ops.gather(x, 10)
+
+ base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
+ 5, drop_remainder=True)
+ _, optimized = self._get_test_datasets(base_dataset, map_fn)
+ nxt = optimized.make_one_shot_iterator().get_next()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(nxt)
+
+ def testOptimizationWithCapturedInputs(self):
+ # Tests that vectorization works with captured inputs
+ def map_fn(x):
+ return x + y
+
+ y = constant_op.constant(1, shape=(2,))
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ # TODO(rachelim): when this optimization works, turn on expect_optimized
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_equal(optimized, unoptimized)
+
+ def testOptimizationIgnoreStateful(self):
+
+ def map_fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+ def testOptimizationIgnoreRagged(self):
+ # Make sure we ignore inputs that might not be uniformly sized
+ def map_fn(x):
+ return array_ops.gather(x, 0)
+
+ # output_shape = (?,)
+ base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_equal(unoptimized, optimized)
+
+ def testOptimizationIgnoreRaggedMap(self):
+ # Don't optimize when the output of the map fn shapes are unknown.
+ def map_fn(x):
+ return array_ops.tile(x, x)
+
+ base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self._assert_datasets_raise_same_error(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+
+class MapVectorizationBenchmark(test.Benchmark):
+ # TODO(rachelim): Add a benchmark for more expensive transformations, such as
+ # vgg_preprocessing.
+
+ def _run(self, x, num_iters=100, name=None):
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ # Warm up session...
+ sess.run(x)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(x)
+ end = time.time()
+ deltas.append(end - start)
+ median_time = np.median(deltas)
+ self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
+ return median_time
+
+ def benchmark_CheapFns(self):
+
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ for map_fn, str_id in self._get_known_cheap_fns():
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
+
+ def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
+ num_elems = np.prod(input_size)
+ name_template = "{}__batch_size_{}_input_size_{}_{}"
+ unoptimized = input_dataset.map(map_fn).batch(batch_size)
+ unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
+
+ optimized = unoptimized.apply(optimization.optimize(["map_vectorization"]))
+ optimized_op = optimized.make_one_shot_iterator().get_next()
+
+ unoptimized_time = self._run(
+ unoptimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
+ optimized_time = self._run(
+ optimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "optimized"))
+
+ print("Batch size: {}\n"
+ "Input size: {}\n"
+ "Transformation: {}\n"
+ "Speedup: {}\n".format(batch_size, input_size, str_id,
+ (unoptimized_time / optimized_time)))
+
+ def _get_known_cheap_fns(self):
+ return [
+ (lambda *args: [array_ops.identity(x) for x in args], "identity"),
+ (lambda *args: [x + 1 for x in args], "add_const"),
+ (lambda *args: args[0], "select"),
+ (lambda *args: [math_ops.cast(x, dtypes.float64) for x in args],
+ "cast"),
+ ]
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..909da5aee0
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetTest(test.TestCase):
+
+ def testOptimizationDefault(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize())
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationEmpty(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize([]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationFusion(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationStatefulFunction(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda _: random_ops.random_uniform([])).batch(10).apply(
+ optimization.optimize(["map_and_batch_fusion"]))
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensor(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+ dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensorSlices(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
+ optimization.optimize())
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
deleted file mode 100644
index 76aa1c3cfd..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimize_dataset_op_test.py
+++ /dev/null
@@ -1,404 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.kernel_tests import test_utils
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class OptimizeDatasetTest(test.TestCase, parameterized.TestCase):
-
- def testAssertSuffix(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertEqual(0, sess.run(get_next))
-
- def testAssertSuffixInvalid(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted Whoops transformation at offset 0 but encountered "
- "Map transformation instead."):
- sess.run(get_next)
-
- def testAssertSuffixShort(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted next 2 transformations but encountered only 1."):
- sess.run(get_next)
-
- def testDefaultOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize())
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testEmptyOptimizations(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize([]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimization(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testFunctionLibraryDefinitionModification(self):
- dataset = dataset_ops.Dataset.from_tensors(0).map(lambda x: x).apply(
- optimization.optimize(["_test_only_function_rename"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- with self.assertRaisesRegexp(errors.NotFoundError,
- "Function .* is not defined."):
- sess.run(get_next)
-
- @staticmethod
- def map_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- functions = [identity, increment, increment_and_square]
- tests = []
- for i, fun1 in enumerate(functions):
- for j, fun2 in enumerate(functions):
- tests.append((
- "test_{}_{}".format(i, j),
- [fun1, fun2],
- ))
- for k, fun3 in enumerate(functions):
- tests.append((
- "test_{}_{}_{}".format(i, j, k),
- [fun1, fun2, fun3],
- ))
-
- swap = lambda x, n: (n, x)
- tests.append((
- "swap1",
- [lambda x: (x, 42), swap],
- ))
- tests.append((
- "swap2",
- [lambda x: (x, 42), swap, swap],
- ))
- return tuple(tests)
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testMapFusion(self, functions):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(["Map", "Prefetch"]))
- for function in functions:
- dataset = dataset.map(function)
-
- dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.test_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
- r = x
- for function in functions:
- if isinstance(r, tuple):
- r = function(*r) # Pass tuple as multiple arguments.
- else:
- r = function(r)
- self.assertAllEqual(r, result)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @staticmethod
- def map_and_filter_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
- minus_five = lambda x: x - 5
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
- is_odd = lambda x: math_ops.equal(x % 2, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
-
- functions = [identity, increment, minus_five, increment_and_square]
- filters = [take_all, is_zero, is_odd, greater]
- tests = []
-
- for x, fun in enumerate(functions):
- for y, predicate in enumerate(filters):
- tests.append(("mixed_{}_{}".format(x, y), fun, predicate))
-
- # Multi output
- tests.append(("multiOne", lambda x: (x, x),
- lambda x, y: constant_op.constant(True)))
- tests.append(
- ("multiTwo", lambda x: (x, 2),
- lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
- return tuple(tests)
-
- @parameterized.named_parameters(*map_and_filter_functions.__func__())
- def testMapFilterFusion(self, function, predicate):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map",
- "FilterByLastComponent"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
- self._testMapAndFilter(dataset, function, predicate)
-
- def _testMapAndFilter(self, dataset, function, predicate):
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.test_session() as sess:
- for x in range(10):
- r = function(x)
- if isinstance(r, tuple):
- b = predicate(*r) # Pass tuple as multiple arguments.
- else:
- b = predicate(r)
- if sess.run(b):
- result = sess.run(get_next)
- self.assertAllEqual(r, result)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testAdditionalInputs(self):
- a = constant_op.constant(3, dtype=dtypes.int64)
- b = constant_op.constant(4, dtype=dtypes.int64)
- some_tensor = math_ops.mul(a, b)
- function = lambda x: x * x
-
- def predicate(y):
- return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)
-
- # We are currently not supporting functions with additional inputs.
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Filter"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
- self._testMapAndFilter(dataset, function, predicate)
-
-
-class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
-
- def testLatencyStatsOptimization(self):
-
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.from_tensors(1).apply(
- optimization.assert_next(
- ["LatencyStats", "Map", "LatencyStats", "Prefetch",
- "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- optimization.optimize(["latency_all_edges"])).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- get_next = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- self.assertEqual(1 * 1, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str,
- "record_latency_TensorDataset/_1", 1)
- self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
- 1)
- self._assertSummaryHasCount(summary_str,
- "record_latency_PrefetchDataset/_6", 1)
-
-
-class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
-
- def _get_test_datasets(self,
- base_dataset,
- map_fn,
- num_parallel_calls=None,
- expect_optimized=True):
- """Given base dataset and map fn, creates test datasets.
-
- Returns a tuple of (unoptimized, dataset, optimized dataset). The
- unoptimized dataset has the assertion that Batch follows Map. The optimized
- dataset has the assertion that Map follows Batch, and has the
- "map_vectorization" optimization applied.
-
- Args:
- base_dataset: Input dataset to map->batch
- map_fn: Map function to use
- num_parallel_calls: (Optional.) num_parallel_calls argument for map
- expect_optimized: (Optional.) Whether we expect the optimization to take
- place, in which case we will assert that Batch is followed by Map,
- otherwise Map followed by Batch. Defaults to True.
-
- Returns:
- Tuple of (unoptimized dataset, optimized dataset).
- """
- map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
- batch_size = 100
-
- def _make_dataset(node_names):
- return base_dataset.apply(optimization.assert_next(node_names)).map(
- map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
-
- unoptimized = _make_dataset([map_node_name, "Batch"])
- optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
- [map_node_name, "Batch"]).apply(
- optimization.optimize(["map_vectorization"]))
-
- return unoptimized, optimized
-
- @parameterized.named_parameters(
- ("Basic", lambda x: (x, x + 1), None),
- ("Parallel", lambda x: (x, x + 1), 12),
- ("Gather", lambda x: array_ops.gather(x, 0), 12),
- )
- def testOptimization(self, map_fn, num_parallel_calls):
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
- num_parallel_calls)
- self._assert_datasets_equal(unoptimized, optimized)
-
- def testOptimizationBadMapFn(self):
- # Test map functions that give an error
- def map_fn(x):
- # x has leading dimension 5, this will raise an error
- return array_ops.gather(x, 10)
-
- base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
- 5, drop_remainder=True)
- _, optimized = self._get_test_datasets(base_dataset, map_fn)
- nxt = optimized.make_one_shot_iterator().get_next()
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- r"indices = 10 is not in \[0, 5\)"):
- self.evaluate(nxt)
-
- def testOptimizationWithCapturedInputs(self):
- # Tests that vectorization works with captured inputs
- def map_fn(x):
- return x + y
-
- y = constant_op.constant(1, shape=(2,))
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- # TODO(rachelim): when this optimization works, turn on expect_optimized
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(optimized, unoptimized)
-
- def testOptimizationIgnoreStateful(self):
-
- def map_fn(x):
- with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
- return array_ops.identity(x)
-
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- _, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- nxt = optimized.make_one_shot_iterator().get_next()
-
- # NOTE: Right now, it raises an error because we can't save datasets that
- # are stateful, and we rely on this saving mechanism to optimize datasets,
- # so stateful functions can't be optimized.
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "[Ss]tateful"):
- self.evaluate(nxt)
-
- def testOptimizationIgnoreRagged(self):
- # Make sure we ignore inputs that might not be uniformly sized
- def map_fn(x):
- return array_ops.gather(x, 0)
-
- # output_shape = (?,)
- base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(unoptimized, optimized)
-
- def testOptimizationIgnoreRaggedMap(self):
- # Don't optimize when the output of the map fn shapes are unknown.
- def map_fn(x):
- return array_ops.tile(x, x)
-
- base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(unoptimized, optimized,
- errors.InvalidArgumentError)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
new file mode 100644
index 0000000000..c4623bca73
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
@@ -0,0 +1,850 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.parsing_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+# Helpers for creating Example objects
+example = example_pb2.Example
+feature = feature_pb2.Feature
+features = lambda d: feature_pb2.Features(feature=d)
+bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
+int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
+float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
+# Helpers for creating SequenceExample objects
+feature_list = lambda l: feature_pb2.FeatureList(feature=l)
+feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
+sequence_example = example_pb2.SequenceExample
+
+
+def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
+ flat_output):
+ tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
+
+ i = 0 # Index into the flattened output of session.run()
+ for k, v in sorted(dict_tensors.items()):
+ # TODO(shivaniagrawal): flat_output is same as v.
+ expected_v = expected_tensors[k]
+ tf_logging.info("Comparing key: %s", k)
+ print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
+ if sparse_tensor.is_sparse(v):
+ # Three outputs for SparseTensor : indices, values, shape.
+ tester.assertEqual([k, len(expected_v)], [k, 3])
+ print("i", i, "flat_output", flat_output[i].indices, "expected_v",
+ expected_v[0])
+ tester.assertAllEqual(expected_v[0], flat_output[i].indices)
+ tester.assertAllEqual(expected_v[1], flat_output[i].values)
+ tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
+ else:
+ # One output for standard Tensor.
+ tester.assertAllEqual(expected_v, flat_output[i])
+ i += 1
+
+
+class ParseExampleTest(test.TestCase):
+
+ def _test(self,
+ input_tensor,
+ feature_val,
+ expected_values=None,
+ expected_err=None):
+
+ with self.cached_session() as sess:
+ if expected_err:
+ with self.assertRaisesWithPredicateMatch(expected_err[0],
+ expected_err[1]):
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ sess.run(get_next)
+ return
+ else:
+ # Returns dict w/ Tensors and SparseTensors.
+ # Check values.
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ result = sess.run(get_next)
+ flattened = nest.flatten(result)
+ print("result", result, "expected_values", expected_values)
+ _compare_output_to_expected(self, result, expected_values, flattened)
+
+ # Check shapes; if serialized is a Tensor we need its size to
+ # properly check.
+ batch_size = (
+ input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
+ np.asarray(input_tensor).size)
+ for k, f in feature_val.items():
+ print("output_shapes as list ",
+ tuple(dataset.output_shapes[k].as_list()))
+ if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
+ elif isinstance(f, parsing_ops.VarLenFeature):
+ self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
+
+ def testEmptySerializedWithAllDefaults(self):
+ sparse_name = "st_a"
+ a_name = "a"
+ b_name = "b"
+ c_name = "c:has_a_tricky_name"
+ a_default = [0, 42, 0]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ c_default = np.random.rand(2).astype(np.float32)
+
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ expected_output = {
+ sparse_name: expected_st_a,
+ a_name: np.array(2 * [[a_default]]),
+ b_name: np.array(2 * [b_default]),
+ c_name: np.array(2 * [c_default]),
+ }
+
+ self._test(
+ ops.convert_to_tensor(["", ""]), {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ },
+ expected_values=expected_output)
+
+ def testEmptySerializedWithoutDefaultsShouldFail(self):
+ input_features = {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=[0, 42, 0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3),
+ dtypes.string,
+ default_value=np.random.rand(3, 3).astype(bytes)),
+ # Feature "c" is missing a default, this gap will cause failure.
+ "c":
+ parsing_ops.FixedLenFeature(
+ (2,), dtype=dtypes.float32),
+ }
+
+ # Edge case where the key is there but the feature value is empty
+ original = example(features=features({"c": feature()}))
+ self._test(
+ [original.SerializeToString()],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ # Standard case of missing key and value.
+ self._test(
+ ["", ""],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ def testDenseNotMatchingShapeShouldFail(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })), example(features=features({
+ "a": float_feature([-1, -1]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Key: a, Index: 1. Number of float values"))
+
+ def testDenseDefaultNoShapeShouldFail(self):
+ original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
+ expected_err=(ValueError, "Missing shape for feature a"))
+
+ def testSerializedContainingSparse(self):
+ original = [
+ example(features=features({
+ "st_c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_st_c = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
+ [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
+ [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
+
+ expected_st_d = ( # indices, values, shape
+ np.array(
+ [[3, 0]], dtype=np.int64), np.array(
+ ["hi"], dtype=bytes), np.array(
+ [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
+
+ expected_output = {
+ "st_c": expected_st_c,
+ "st_d": expected_st_d,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "st_c": parsing_ops.VarLenFeature(dtypes.float32),
+ "st_d": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array(
+ [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
+ [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeatureReuse(self):
+ original = [
+ example(features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp1 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_sp2 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [5.0, 6.0], dtype=np.float32), np.array(
+ [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_output = {
+ "sp1": expected_sp1,
+ "sp2": expected_sp2,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp1":
+ parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
+ "sp2":
+ parsing_ops.SparseFeature(
+ "idx", "val2", dtypes.float32, size=7, already_sorted=True)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContaining3DSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = (
+ # indices
+ np.array(
+ [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
+ dtype=np.int64),
+ # values
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ # shape batch == 4, max_elems = 13
+ np.array([4, 13, 3], dtype=np.int64))
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDense(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ original = [
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })), example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ # This test is identical as the previous one except
+ # for the creation of 'serialized'.
+ def testSerializedContainingDenseWithConcat(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ # TODO(lew): Feature appearing twice should be an error in future.
+ original = [
+ (example(features=features({
+ aname: float_feature([10, 10]),
+ })), example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
+ (
+ example(features=features({
+ bname: bytes_feature([b"b100"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),),
+ ]
+
+ serialized = [
+ m.SerializeToString() + n.SerializeToString() for (m, n) in original
+ ]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseScalar(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1]),
+ })), example(features=features({}))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseWithDefaults(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ })),
+ example(features=features({
+ "b": bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ "b": feature()
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
+ 1),
+ "b":
+ np.array(
+ ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
+ 1),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "c"], dtype="|S"), np.array(
+ [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ original = [
+ example(features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ a_default = [1, 2, 3]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ expected_output = {
+ "st_a": expected_st_a,
+ "sp": expected_sp,
+ "a": np.array(2 * [[a_default]]),
+ "b": np.array(2 * [b_default]),
+ "c": np.array(
+ [[3, 4], [1, 2]], dtype=np.float32),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c":
+ parsing_ops.FixedLenFeature((2,), dtypes.float32),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
+ expected_idx = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]), np.array(
+ [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "d", "c"], dtype="|S"), np.array(
+ [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ original = [
+ example(features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "idx": expected_idx,
+ "sp": expected_sp,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
+ },
+ expected_values=expected_output)
+
+ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
+ # During parsing, data read from the serialized proto is stored in buffers.
+ # For small batch sizes, a buffer will contain one minibatch entry.
+ # For larger batch sizes, a buffer may contain several minibatch
+ # entries. This test identified a bug where the code that copied
+ # data out of the buffers and into the output tensors assumed each
+ # buffer only contained one minibatch entry. The bug has since been fixed.
+ truth_int = [i for i in range(batch_size)]
+ truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
+ for i in range(batch_size)]
+
+ expected_str = copy.deepcopy(truth_str)
+
+ # Delete some intermediate entries
+ for i in range(batch_size):
+ col = 1
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry
+ expected_str[i][col] = b"default"
+ col -= 1
+ truth_str[i].pop()
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry (possibly again)
+ expected_str[i][col] = b"default"
+ truth_str[i].pop()
+
+ expected_output = {
+ # Batch size batch_size, 1 time step.
+ "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
+ # Batch size batch_size, 2 time steps.
+ "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
+ }
+
+ original = [
+ example(features=features(
+ {"a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])}))
+ for i in range(batch_size)
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized, dtype=dtypes.string), {
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingVarLenDenseLargerBatch(self):
+ np.random.seed(3456)
+ for batch_size in (1, 10, 20, 100, 256):
+ self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
+
+ def testSerializedContainingVarLenDense(self):
+ aname = "a"
+ bname = "b"
+ cname = "c"
+ dname = "d"
+ original = [
+ example(features=features({
+ cname: int64_feature([2]),
+ })),
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [
+ [0, 0, 0, 0],
+ [1, 1, 0, 0],
+ [-1, -1, 2, 2],
+ [0, 0, 0, 0],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1),
+ bname:
+ np.array(
+ [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
+ dtype=bytes).reshape(4, 2, 1, 1, 1),
+ cname:
+ np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
+ dname:
+ np.empty(shape=(4, 0), dtype=bytes),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_values=expected_output)
+
+ # Test with padding values.
+ expected_output_custom_padding = dict(expected_output)
+ expected_output_custom_padding[aname] = np.array(
+ [
+ [-2, -2, -2, -2],
+ [1, 1, -2, -2],
+ [-1, -1, 2, 2],
+ [-2, -2, -2, -2],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1)
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=-2.0),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }, expected_output_custom_padding)
+
+ # Change number of required values so the inputs are not a
+ # multiple of this size.
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(
+ errors_impl.OpError, "Key: b, Index: 2. "
+ "Number of bytes values is not a multiple of stride length."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=[]),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Cannot reshape a tensor with 0 elements to shape"))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "First dimension of shape for feature a unknown. "
+ "Consider using FixedLenSequenceFeature."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ cname:
+ parsing_ops.FixedLenFeature(
+ (1, None), dtype=dtypes.int64, default_value=[[1]]),
+ },
+ expected_err=(ValueError,
+ "All dimensions of shape for feature c need to be known "
+ r"but received \(1, None\)."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
+
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
index 361fe0dd39..0166ba0d44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
@@ -235,7 +235,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
destroy_op = resource_variable_ops.destroy_resource_op(
buffer_resource_handle, ignore_lookup_error=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([b"a"], sess.run(prefetch_op))
self.assertEqual([b"b"], sess.run(prefetch_op))
self.assertEqual([b"c"], sess.run(prefetch_op))
@@ -301,7 +301,7 @@ class PrefetchToDeviceTest(test.TestCase):
self.assertEqual(dtypes.int64, next_element.dtype)
self.assertEqual([], next_element.shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -384,7 +384,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -435,7 +435,7 @@ class PrefetchToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -683,7 +683,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -702,7 +702,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -721,7 +721,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -739,7 +739,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -757,7 +757,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -775,7 +775,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -796,7 +796,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = back_to_cpu_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(10):
self.assertEqual(i, sess.run(next_element))
@@ -875,7 +875,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -897,7 +897,7 @@ class CopyToDeviceTest(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for i in range(5):
self.assertEqual(i, sess.run(next_element))
@@ -920,7 +920,7 @@ class CopyToDeviceTest(test.TestCase):
elem_has_value_t = next_elem.has_value()
elem_value_t = next_elem.get_value()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError.
with self.assertRaises(errors.FailedPreconditionError):
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
index 592642da0c..db8fe6aa1b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
@@ -43,7 +43,7 @@ class RangeDatasetTest(test.TestCase):
self.assertEqual([tensor_shape.TensorShape([])] * 3,
[t.shape for t in get_next[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
@@ -63,7 +63,7 @@ class RangeDatasetTest(test.TestCase):
.make_one_shot_iterator())
negative_get_next = negative_iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(3, sess.run(get_next))
self.assertEqual(3 + 4, sess.run(get_next))
self.assertEqual(3 + 2 * 4, sess.run(get_next))
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
index 64fe6dae24..ed75b27a44 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
@@ -47,22 +47,50 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 0, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
# Basic test: read from file 1.
self.outputs = self.make_batch_feature(
filenames=self.test_filenames[1],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, 1, num_epochs=num_epochs)
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
with ops.Graph().as_default() as g:
with self.session(graph=g) as sess:
@@ -88,9 +116,9 @@ class ReadBatchFeaturesTest(
init_op = iterator.initializer
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
- for file_batch, _, _, _, record_batch in self._next_expected_batch(
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
range(self._num_files), 2, 10):
actual_batch = sess.run(next_element)
self.assertAllEqual(file_batch, actual_batch["file"])
@@ -155,6 +183,25 @@ class ReadBatchFeaturesTest(
with self.session(graph=g) as sess:
self.outputs = self.make_batch_feature(
filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
num_epochs=num_epochs,
batch_size=batch_size,
reader_num_threads=reader_num_threads,
@@ -175,16 +222,20 @@ class ReadBatchFeaturesTest(
# Basic test: read from file 0.
outputs = self.make_batch_feature(
filenames=self.test_filenames[0],
+ label_key="label",
num_epochs=num_epochs,
batch_size=batch_size,
drop_final_batch=True).make_one_shot_iterator().get_next()
- for _, tensor in outputs.items():
+ for tensor in nest.flatten(outputs):
if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
self.assertEqual(tensor.shape[0], batch_size)
def testIndefiniteRepeatShapeInference(self):
dataset = self.make_batch_feature(
- filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
for shape, clazz in zip(nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_classes)):
if issubclass(clazz, ops.Tensor):
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
index e63bc4c720..08b9f03816 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
@@ -76,6 +76,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames,
num_epochs,
batch_size,
+ label_key=None,
reader_num_threads=1,
parser_num_threads=1,
shuffle=False,
@@ -91,8 +92,10 @@ class ReadBatchFeaturesTestBase(test.TestCase):
features={
"file": parsing_ops.FixedLenFeature([], dtypes.int64),
"record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string)
+ "keywords": parsing_ops.VarLenFeature(dtypes.string),
+ "label": parsing_ops.FixedLenFeature([], dtypes.string),
},
+ label_key=label_key,
reader=core_readers.TFRecordDataset,
num_epochs=self.num_epochs,
shuffle=shuffle,
@@ -101,7 +104,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
parser_num_threads=parser_num_threads,
drop_final_batch=drop_final_batch)
- def _record(self, f, r):
+ def _record(self, f, r, l):
example = example_pb2.Example(
features=feature_pb2.Features(
feature={
@@ -114,7 +117,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
"keywords":
feature_pb2.Feature(
bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r)))
+ value=self._get_keywords(f, r))),
+ "label":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[compat.as_bytes(l)]))
}))
return example.SerializeToString()
@@ -139,23 +146,30 @@ class ReadBatchFeaturesTestBase(test.TestCase):
filenames.append(fn)
writer = python_io.TFRecordWriter(fn)
for j in range(self._num_records):
- writer.write(self._record(i, j))
+ writer.write(self._record(i, j, "fake-label"))
writer.close()
return filenames
- def _run_actual_batch(self, outputs, sess):
- file_op = outputs["file"]
- keywords_indices_op = outputs["keywords"].indices
- keywords_values_op = outputs["keywords"].values
- keywords_dense_shape_op = outputs["keywords"].dense_shape
- record_op = outputs["record"]
+ def _run_actual_batch(self, outputs, sess, label_key_provided=False):
+ if label_key_provided:
+ # outputs would be a tuple of (feature dict, label)
+ label_op = outputs[1]
+ features_op = outputs[0]
+ else:
+ features_op = outputs
+ label_op = features_op["label"]
+ file_op = features_op["file"]
+ keywords_indices_op = features_op["keywords"].indices
+ keywords_values_op = features_op["keywords"].values
+ keywords_dense_shape_op = features_op["keywords"].dense_shape
+ record_op = features_op["record"]
return sess.run([
file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op
+ keywords_dense_shape_op, record_op, label_op
])
- def _next_actual_batch(self, sess):
- return self._run_actual_batch(self.outputs, sess)
+ def _next_actual_batch(self, sess, label_key_provided=False):
+ return self._run_actual_batch(self.outputs, sess, label_key_provided)
def _interleave(self, iterators, cycle_length):
pending_iterators = iterators
@@ -188,7 +202,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
def _next_record(file_indices):
for j in file_indices:
for i in range(self._num_records):
- yield j, i
+ yield j, i, compat.as_bytes("fake-label")
def _next_record_interleaved(file_indices, cycle_length):
return self._interleave([_next_record([i]) for i in file_indices],
@@ -200,6 +214,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
for _ in range(num_epochs):
if cycle_length == 1:
next_records = _next_record(file_indices)
@@ -208,6 +223,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
for record in next_records:
f = record[0]
r = record[1]
+ label_batch.append(record[2])
file_batch.append(f)
record_batch.append(r)
keywords = self._get_keywords(f, r)
@@ -219,7 +235,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
if len(file_batch) == batch_size:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch
+ [batch_size, keywords_batch_max_len], record_batch, label_batch
]
file_batch = []
keywords_batch_indices = []
@@ -227,10 +243,11 @@ class ReadBatchFeaturesTestBase(test.TestCase):
keywords_batch_max_len = 0
record_batch = []
batch_index = 0
+ label_batch = []
if file_batch:
yield [
file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch
+ [len(file_batch), keywords_batch_max_len], record_batch, label_batch
]
def verify_records(self,
@@ -238,6 +255,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
batch_size,
file_index=None,
num_epochs=1,
+ label_key_provided=False,
interleave_cycle_length=1):
if file_index is not None:
file_indices = [file_index]
@@ -245,8 +263,12 @@ class ReadBatchFeaturesTestBase(test.TestCase):
file_indices = range(self._num_files)
for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length):
- actual_batch = self._next_actual_batch(sess)
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=interleave_cycle_length):
+ actual_batch = self._next_actual_batch(
+ sess, label_key_provided=label_key_provided)
for i in range(len(expected_batch)):
self.assertAllEqual(expected_batch[i], actual_batch[i])
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
index c5cfddb72b..16b1441baa 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
@@ -77,7 +77,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
class_func=lambda c, _: c,
seed=27)).make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
while len(returned) < 4000:
returned.append(sess.run(get_next))
@@ -115,7 +115,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
@@ -146,7 +146,7 @@ class ResampleTest(test.TestCase, parameterized.TestCase):
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
returned = []
with self.assertRaises(errors.OutOfRangeError):
while True:
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
index 42cada0b97..dde678bd54 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
@@ -50,7 +50,7 @@ class ScanDatasetTest(test.TestCase):
start, make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -100,7 +100,7 @@ class ScanDatasetTest(test.TestCase):
make_scan_fn(step)).take(take).make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
(10, 2, 10), (10, -1, 10),
@@ -133,7 +133,7 @@ class ScanDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(5):
(longer_vector_val, larger_rank_val), _ = sess.run(next_element)
self.assertAllEqual([0] * (2**i), longer_vector_val)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
index 7b9ea191a4..aa89674c6e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
@@ -210,6 +210,7 @@ py_test(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -318,6 +319,19 @@ py_test(
)
py_test(
+ name = "parse_example_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parse_example_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
name = "prefetch_dataset_serialization_test",
size = "small",
srcs = ["prefetch_dataset_serialization_test.py"],
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
index 9fdbcb66bf..595cecef4d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -510,7 +510,6 @@ class DatasetSerializationTestBase(test.TestCase):
else:
init_op, get_next_op, saver = self._build_graph(
ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
return init_op, get_next_op, saver
for i in range(len(break_points) + 1):
@@ -616,29 +615,40 @@ class DatasetSerializationTestBase(test.TestCase):
# `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
# do not support tuples we flatten the tensors and restore the shape in
# `_get_iterator_ops_from_collection`.
-
- # TODO(shivaniagrwal): `output_classes` is a nested structure of classes,
- # this base class is specific to current test cases. Update when tests are
- # added with `output_classes` as a nested structure with at least one of the
- # component being `tf.SparseTensor`.
- if (sparse_tensors or
- self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
ops.add_to_collection("iterator_ops", get_next.indices)
ops.add_to_collection("iterator_ops", get_next.values)
ops.add_to_collection("iterator_ops", get_next.dense_shape)
- else:
- for el in nest.flatten(get_next):
- ops.add_to_collection("iterator_ops", el)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
all_ops = ops.get_collection("iterator_ops")
- if (sparse_tensors or
- self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor):
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
init_op, indices, values, dense_shape = all_ops
return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- else:
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), all_ops[1:])
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
def _get_output_types(self, ds_fn):
with ops.Graph().as_default():
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
index ac3892fe81..243f6405a1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
@@ -27,42 +28,38 @@ from tensorflow.python.platform import test
class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
- def _build_iterator_graph(self, input_values, cycle_length, block_length):
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
repeat_count = 2
return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
repeat_count).interleave(
lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length)
+ cycle_length, block_length, num_parallel_calls)
- def testSerializationCore(self):
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
input_values = np.array([4, 5, 6], dtype=np.int64)
num_outputs = np.sum(input_values) * 2
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
# pylint: disable=g-long-lambda
self.run_core_tests(
lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
+ input_values, cycle_length, block_length, num_parallel_calls),
lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length * 1),
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length),
- None, num_outputs)
# pylint: enable=g-long-lambda
def testSparseCore(self):
@@ -82,5 +79,5 @@ class InterleaveDatasetSerializationTest(
self.run_core_tests(_build_dataset, None, 20)
-if __name__ == '__main__':
+if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
new file mode 100644
index 0000000000..d3fa84e74c
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParseExampleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.platform import test
+
+
+class ParseExampleDatasetSerializationTest(
+ reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def ParseExampleDataset(self, num_repeat, batch_size):
+ return self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_repeat,
+ batch_size=batch_size,
+ reader_num_threads=5,
+ parser_num_threads=10)
+
+ def testSerializationCore(self):
+ num_repeat = 5
+ batch_size = 2
+ num_outputs = self._num_records * self._num_files * num_repeat // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self.ParseExampleDataset(
+ num_repeat=num_repeat, batch_size=batch_size),
+ lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 077abd6b30..440e48db30 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -35,7 +35,7 @@ class ShuffleAndRepeatTest(test.TestCase):
def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
get_next = ds_fn().make_one_shot_iterator().get_next()
outputs = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(num_outputs):
outputs.append(sess.run(get_next))
if verify_exhausted:
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 8b2f846494..90d18dca2a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -32,18 +32,18 @@ from tensorflow.python.platform import test
class SlideDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDataset(self, count, window_size, window_shift, window_stride):
"""Tests a dataset that slides a window its input elements."""
@@ -75,7 +75,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -96,18 +96,18 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (20, 14, 7, 1),
- (20, 17, 9, 1),
- (20, 14, 14, 1),
- (20, 10, 14, 1),
- (20, 14, 19, 1),
- (20, 4, 1, 2),
- (20, 2, 1, 6),
- (20, 4, 7, 2),
- (20, 2, 7, 6),
- (1, 10, 4, 1),
- (0, 10, 4, 1),
+ @parameterized.named_parameters(
+ ("1", 20, 14, 7, 1),
+ ("2", 20, 17, 9, 1),
+ ("3", 20, 14, 14, 1),
+ ("4", 20, 10, 14, 1),
+ ("5", 20, 14, 19, 1),
+ ("6", 20, 4, 1, 2),
+ ("7", 20, 2, 1, 6),
+ ("8", 20, 4, 7, 2),
+ ("9", 20, 2, 7, 6),
+ ("10", 1, 10, 4, 1),
+ ("11", 0, 10, 4, 1),
)
def testSlideDatasetDeprecated(self, count, window_size, stride,
window_stride):
@@ -139,7 +139,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
self.assertEqual([[None] + list(c.shape[1:]) for c in components],
[t.shape.as_list() for t in get_next])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -160,10 +160,10 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- @parameterized.parameters(
- (14, 0, 3, 1),
- (14, 3, 0, 1),
- (14, 3, 3, 0),
+ @parameterized.named_parameters(
+ ("1", 14, 0, 3, 1),
+ ("2", 14, 3, 0, 1),
+ ("3", 14, 3, 3, 0),
)
def testSlideDatasetInvalid(self, count, window_size, window_shift,
window_stride):
@@ -180,7 +180,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
window_stride=window_stride_t)).make_initializable_iterator())
init_op = iterator.initializer
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -214,7 +214,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -243,7 +243,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
num_batches = (10 - 5) // 3 + 1
for i in range(num_batches):
@@ -277,7 +277,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
# Slide: 1st batch.
actual = sess.run(get_next)
@@ -316,7 +316,7 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
.make_initializable_iterator())
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 2c2cfbebff..52823d3fca 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -30,7 +30,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(2): # Run twice to verify statelessness of db operations.
sess.run(
init_op,
@@ -48,7 +48,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetJoinQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -67,7 +67,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetNullTerminator(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -86,7 +86,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetReuseSqlDataset(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -114,7 +114,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadEmptyResultSet(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -128,7 +128,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidDriverName(self):
init_op = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
init_op,
@@ -142,7 +142,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithInvalidColumnName(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -157,7 +157,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfQueryWithSyntaxError(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -173,7 +173,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -190,7 +190,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetOfInsertQuery(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.string))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -205,7 +205,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -222,7 +222,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -238,7 +238,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -256,7 +256,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -273,7 +273,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -289,7 +289,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -307,7 +307,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -321,7 +321,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -337,7 +337,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -355,7 +355,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# table and place it in an `int32` tensor.
def testReadResultSetInt32VarCharColumnAsInt(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -371,7 +371,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -387,7 +387,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -403,7 +403,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -422,7 +422,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -438,7 +438,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -456,7 +456,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -472,7 +472,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -491,7 +491,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# in `bool` tensors.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -508,7 +508,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -525,7 +525,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -544,7 +544,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
@@ -570,7 +570,7 @@ class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
init_op,
feed_dict={
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
index a41d21f8c1..e25570c5ad 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
@@ -76,6 +75,31 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
sess.run(next_element)
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+ def testPrefetchBufferUtilization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ float(i + 1))
+ self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+ 0, 1)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ 100)
+
def testReinitialize(self):
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).apply(
@@ -175,44 +199,5 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-class FeatureStatsDatasetTest(
- stats_dataset_test_base.StatsDatasetTestBase,
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testFeaturesStats(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- batch_size = 2
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5,
- drop_final_batch=True).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(total_records // batch_size):
- sess.run(next_element)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats:features", total_records)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_stats:feature-values", total_records)
- self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:features", total_records * 3)
- self._assertSummaryHasSum(
- sess.run(summary_t), "record_stats:feature-values",
- self._sum_keywords(1) * num_epochs + 2 * total_records)
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
index 9a13acf8f0..2f5a44408f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
@@ -34,6 +34,16 @@ class StatsDatasetTestBase(test.TestCase):
return
self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+ def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertLessEqual(min_value, value.histo.min)
+ self.assertGreaterEqual(max_value, value.histo.max)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
def _assertSummaryHasSum(self, summary_str, tag, expected_value):
summary_proto = summary_pb2.Summary()
summary_proto.ParseFromString(summary_str)
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
index 1b962b3418..4c3353fe40 100644
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -29,7 +31,7 @@ class DatasetTestBase(test.TestCase):
# TODO(rachelim): support sparse tensor outputs
next1 = dataset1.make_one_shot_iterator().get_next()
next2 = dataset2.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
while True:
try:
op1 = sess.run(next1)
@@ -45,16 +47,27 @@ class DatasetTestBase(test.TestCase):
for i in range(len(op1)):
self.assertAllEqual(op1[i], op2[i])
- def _assert_datasets_raise_same_error(self, dataset1, dataset2, exc_class):
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ def _assert_datasets_raise_same_error(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ # We are defining next1 and next2 in the same line so that we get identical
+ # file:line_number in the error messages
+ # pylint: disable=line-too-long
+ next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
+ # pylint: enable=line-too-long
+ with self.cached_session() as sess:
try:
sess.run(next1)
raise ValueError(
"Expected dataset to raise an error of type %s, but it did not." %
- repr(exc_class))
- except exc_class as e:
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
# Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exc_class, e.message.split(". ")[0]):
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
sess.run(next2)
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
index 0486e2bce2..8d335e87d5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
@@ -33,8 +33,17 @@ from tensorflow.python.platform import test
class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
- @parameterized.parameters((1, None), (2, None), (4, None), (8, None),
- (16, None), (4, -1), (4, 0), (4, 1), (4, 4))
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
def testNumThreads(self, num_threads, max_intra_op_parallelism):
def get_thread_id(_):
@@ -60,7 +69,7 @@ class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
thread_ids = []
try:
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
index d79a842e7a..f994c8563f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
@@ -45,7 +45,7 @@ class UniqueDatasetTest(test.TestCase):
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for test_case, expected in test_cases:
current_test_case = test_case
sess.run(iterator.initializer)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
index 33d95d6754..6eaa0b1959 100644
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
@@ -64,15 +64,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
else:
self.assertEqual(xs, ys)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetFlatMap(self, structure, shape, dtype):
"""Tests windowing by chaining it with flat map.
@@ -92,20 +92,20 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).apply(
grouping.window_dataset(5)).flat_map(fn)
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(self._structuredElement(structure, shape, dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchDense(self, structure, shape, dtype):
"""Tests batching of dense tensor windows.
@@ -128,17 +128,17 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredElement(structure, np.concatenate(
([5], shape), axis=0), dtype))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchDenseDynamicShape(self, shape):
"""Tests batching of dynamically shaped dense tensor windows.
@@ -155,7 +155,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredElement(None, np.concatenate(([5], shape), axis=0),
@@ -203,15 +203,15 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int32([]), dtypes.bool),
- (None, np.int32([]), dtypes.int32),
- (None, np.int32([]), dtypes.float32),
- (None, np.int32([]), dtypes.string),
- (None, np.int32([2]), dtypes.int32),
- (None, np.int32([2, 2]), dtypes.int32),
- ((None, None, None), np.int32([]), dtypes.int32),
- ((None, (None, None)), np.int32([]), dtypes.int32),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([]), dtypes.bool),
+ ("2", None, np.int32([]), dtypes.int32),
+ ("3", None, np.int32([]), dtypes.float32),
+ ("4", None, np.int32([]), dtypes.string),
+ ("5", None, np.int32([2]), dtypes.int32),
+ ("6", None, np.int32([2, 2]), dtypes.int32),
+ ("7", (None, None, None), np.int32([]), dtypes.int32),
+ ("8", (None, (None, None)), np.int32([]), dtypes.int32),
)
def testWindowDatasetBatchSparse(self, structure, shape, dtype):
"""Tests batching of sparse tensor windows.
@@ -235,7 +235,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shape, dtype).repeat(5).apply(
grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredSparseElement(structure,
np.concatenate(([5], shape), axis=0),
@@ -243,10 +243,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([]),),
- (np.int32([1]),),
- (np.int32([1, 2, 3]),),
+ @parameterized.named_parameters(
+ ("1", np.int32([])),
+ ("2", np.int32([1])),
+ ("3", np.int32([1, 2, 3])),
)
def testWindowDatasetBatchSparseDynamicShape(self, shape):
"""Tests batching of dynamically shaped sparse tensor windows.
@@ -263,7 +263,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shape_t: shape})
expected = sess.run(
self._structuredSparseElement(None,
@@ -284,17 +284,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
]))
- @parameterized.parameters(
- (None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
)
def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
padded_shape):
@@ -320,7 +321,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping.window_dataset(len(shapes))).apply(
grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
self._structuredElement(
@@ -329,10 +330,10 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1], [2], [3]]), [-1]),
- (np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1], [2], [3]]), [-1]),
+ ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
"""Tests padded batching of dynamically shaped dense tensor windows.
@@ -351,7 +352,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
expected = sess.run(
@@ -361,9 +362,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int32([[1]]), np.int32([0])),
- (np.int32([[10], [20]]), np.int32([15])),
+ @parameterized.named_parameters(
+ ("1", np.int32([[1]]), np.int32([0])),
+ ("2", np.int32([[10], [20]]), np.int32([15])),
)
def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of dense tensor windows.
@@ -379,7 +380,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
@@ -420,17 +421,18 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
for substructure in structure
])
- @parameterized.parameters(
- (None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- (None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- (None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ((None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ((None, (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- (None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
+ @parameterized.named_parameters(
+ ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
+ ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
+ ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
+ ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
+ ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
+ ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("8", (None,
+ (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
+ ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
)
def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
padded_shape):
@@ -456,17 +458,17 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
structure, shapes, dtype).apply(grouping.window_dataset(
len(shapes))).apply(grouping._map_x_dataset(fn))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
expected = sess.run(
self._structuredRaggedSparseElement(structure, shapes, dtype,
padded_shape))
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1], [2], [3]]), [-1]),
- (np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- (np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1], [2], [3]]), [-1]),
+ ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
+ ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
)
def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
padded_shape):
@@ -487,7 +489,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op, {shapes_t: shapes})
expected = sess.run(
self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
@@ -495,9 +497,9 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
actual = sess.run(get_next)
self._assertEqual(expected, actual)
- @parameterized.parameters(
- (np.int64([[1]]), [0]),
- (np.int64([[10], [20]]), [15]),
+ @parameterized.named_parameters(
+ ("1", np.int64([[1]]), [0]),
+ ("2", np.int64([[10], [20]]), [15]),
)
def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
"""Tests invalid padded batching of sparse tensor windows.
@@ -514,7 +516,7 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
grouping._map_x_dataset(
lambda x: batching.padded_batch_window(x, padded_shape)))
get_next = dataset.make_one_shot_iterator().get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
index c603ecc5ab..867ee2ba37 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
@@ -61,7 +61,7 @@ class TFRecordWriterTest(test.TestCase):
return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
def testWrite(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer, feed_dict={
self.filename: self._createFile(),
@@ -71,7 +71,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteZLIB(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
@@ -84,7 +84,7 @@ class TFRecordWriterTest(test.TestCase):
def testWriteGZIP(self):
options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(
self.writer,
feed_dict={
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index ad9378dfb9..4b45cc7e36 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -80,17 +80,14 @@ py_library(
":batching",
":gen_dataset_ops",
":interleave_ops",
+ ":parsing_ops",
":shuffle_ops",
- ":stats_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
- "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:dataset_ops",
@@ -211,6 +208,22 @@ py_library(
)
py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
name = "map_defun",
srcs = ["map_defun.py"],
srcs_version = "PY2AND3",
@@ -331,7 +344,10 @@ py_library(
tf_gen_op_wrapper_py(
name = "gen_dataset_ops",
out = "gen_dataset_ops.py",
- deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"],
+ deps = [
+ "//tensorflow/contrib/data:dataset_ops_op_lib",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
+ ],
)
tf_kernel_library(
@@ -349,6 +365,7 @@ tf_custom_op_py_library(
dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
kernels = [
":dataset_ops_kernels",
+ "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/data:dataset_ops_op_lib",
],
srcs_version = "PY2AND3",
@@ -360,6 +377,19 @@ tf_custom_op_py_library(
)
py_library(
+ name = "indexed_dataset_ops",
+ srcs = ["indexed_dataset_ops.py"],
+ deps = [
+ ":contrib_op_loader",
+ ":gen_dataset_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
@@ -380,6 +410,7 @@ py_library(
":error_ops",
":get_single_element",
":grouping",
+ ":indexed_dataset_ops",
":interleave_ops",
":map_defun",
":optimization",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 9f059942a6..367c159dc5 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
padding_value = 0
def batch_init_fn(_):
- return array_ops.fill(
- array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0),
- constant_op.constant(padding_value, dtype=dataset.output_types))
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
def batch_reduce_fn(state, value):
return array_ops.concat([state, [value]], 0)
@@ -647,15 +647,17 @@ def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
```python
- shapes = [tf.TensorShape([16, 256]), tf.TensorShape(None)]
+ shapes = [tf.TensorShape([16, 256]), tf.TensorShape([None, 2])]
result = dataset.apply(tf.contrib.data.assert_element_shape(shapes))
- print(result.output_shapes) # ==> "((16, 256), <unknown>)"
+ print(result.output_shapes) # ==> "((16, 256), (<unknown>, 2))"
```
If dataset shapes and expected_shape, are fully defined, assert they match.
Otherwise, add assert op that will validate the shapes when tensors are
evaluated, and set shapes on tensors, respectively.
+ Note that unknown dimension in `expected_shapes` will be ignored.
+
Args:
expected_shapes: A nested structure of `tf.TensorShape` objects.
@@ -664,20 +666,31 @@ def assert_element_shape(expected_shapes):
`tf.data.Dataset.apply`
"""
+ def _merge_output_shapes(original_shapes, expected_shapes):
+ flat_original_shapes = nest.flatten(original_shapes)
+ flat_new_shapes = nest.flatten_up_to(original_shapes, expected_shapes)
+ flat_merged_output_shapes = [
+ original_shape.merge_with(new_shape)
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes)]
+ return nest.pack_sequence_as(original_shapes, flat_merged_output_shapes)
+
def _check_shape(*elements):
flatten_tensors = nest.flatten(elements)
flatten_shapes = nest.flatten(expected_shapes)
checked_tensors = [
- with_shape(shape, tensor)
+ with_shape(shape, tensor) if shape else tensor # Ignore unknown shape
for shape, tensor in zip(flatten_shapes, flatten_tensors)
]
return nest.pack_sequence_as(elements, checked_tensors)
def _apply_fn(dataset):
+ output_shapes = _merge_output_shapes(dataset.output_shapes,
+ expected_shapes)
return _RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
- output_shapes=expected_shapes,
+ output_shapes=output_shapes,
output_classes=dataset.output_classes)
return _apply_fn
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
new file mode 100644
index 0000000000..a0932b4081
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
@@ -0,0 +1,173 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for indexed datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
+from tensorflow.contrib.data.python.ops import gen_dataset_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class MaterializedIndexedDataset(object):
+ """MaterializedIndexedDataset is highly experimental!
+ """
+
+ def __init__(self, materialized_resource, materializer, output_classes,
+ output_types, output_shapes):
+ self._materialized_resource = materialized_resource
+ self._materializer = materializer
+ self._output_classes = output_classes
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+
+ @property
+ def initializer(self):
+ if self._materializer is not None:
+ return self._materializer
+ raise ValueError("MaterializedDataset does not have a materializer")
+
+ def get(self, index):
+ """Get retrieves a value (or set of values) from the IndexedDataset.
+
+ Args:
+ index: A uint64 scalar or vector tensor with the indices to retrieve.
+
+ Returns:
+ A tensor containing the values corresponding to `index`.
+ """
+ # TODO(saeta): nest.pack_sequence_as(...)
+ return gen_dataset_ops.indexed_dataset_get(
+ self._materialized_resource,
+ index,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self._output_shapes, self._output_classes)))
+
+
+class IndexedDataset(dataset_ops.Dataset):
+ """IndexedDataset is highly experimental!
+ """
+
+ def __init__(self):
+ pass
+
+ def materialize(self, shared_name=None, container=None):
+ """Materialize creates a MaterializedIndexedDataset.
+
+ IndexedDatasets can be combined through operations such as TBD. Therefore,
+ they are only materialized when absolutely required.
+
+ Args:
+ shared_name: a string for the shared name to use for the resource.
+ container: a string for the container to store the resource.
+
+ Returns:
+ A MaterializedIndexedDataset.
+ """
+ if container is None:
+ container = ""
+ if shared_name is None:
+ shared_name = ""
+ materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes, self.output_classes)))
+
+ with ops.colocate_with(materialized_resource):
+ materializer = gen_dataset_ops.indexed_dataset_materialize(
+ self._as_variant_tensor(), materialized_resource)
+ return MaterializedIndexedDataset(materialized_resource, materializer,
+ self.output_classes, self.output_types,
+ self.output_shapes)
+
+ @abc.abstractproperty
+ def output_types(self):
+ """Returns the type of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.DType` objects corresponding to each component
+ of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_types")
+
+ @abc.abstractproperty
+ def output_classes(self):
+ """Returns the class of each component of an element of this IndexedDataset.
+
+ The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+ Returns:
+ A nested structure of Python `type` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_classes")
+
+ @abc.abstractproperty
+ def output_shapes(self):
+ """Returns the shape of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.TensorShape` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_shapes")
+
+ @abc.abstractmethod
+ def _as_variant_tensor(self):
+ """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.
+
+ Returns:
+ A scalar `tf.Tensor` of `tf.variant` type, which represents this
+ IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset._as_variant_tensor")
+
+
+class IdentityIndexedDataset(IndexedDataset):
+ """IdentityIndexedDataset is a trivial indexed dataset used for testing.
+ """
+
+ def __init__(self, size):
+ super(IdentityIndexedDataset, self).__init__()
+ # TODO(saeta): Verify _size is a scalar!
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
+
+ @property
+ def output_types(self):
+ return dtypes.uint64
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.identity_indexed_dataset(self._size)
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 5a1a35199a..92d4251a86 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -163,7 +163,7 @@ class _DirectedInterleaveDataset(dataset_ops.Dataset):
for data_input in data_inputs[1:]:
if (data_input.output_types != data_inputs[0].output_types or
data_input.output_classes != data_inputs[0].output_classes):
- raise TypeError("All datasets must have the same type.")
+ raise TypeError("All datasets must have the same type and class.")
def _as_variant_tensor(self):
# pylint: disable=protected-access
@@ -216,25 +216,59 @@ def sample_from_datasets(datasets, weights=None, seed=None):
length of the `datasets` element.
"""
num_datasets = len(datasets)
- if weights is None:
- weights = dataset_ops.Dataset.from_tensors([1.0] * num_datasets).repeat()
- elif not isinstance(weights, dataset_ops.Dataset):
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError("`weights` must be a vector of length `len(datasets)`.")
- weights = dataset_ops.Dataset.from_tensors(weights).repeat()
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
- def select_dataset(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
- selector_input = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2))).map(select_dataset)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
+
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
return _DirectedInterleaveDataset(selector_input, datasets)
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
index 54d5cd6da0..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/contrib/data/python/ops/map_defun.py
@@ -53,6 +53,4 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- if not all(s.is_fully_defined() for s in output_shapes):
- raise ValueError("All fn output shapes must be fully defined.")
return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
new file mode 100644
index 0000000000..2701605e64
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -0,0 +1,150 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+
+
+class _ParseExampleDataset(dataset_ops.Dataset):
+ """A `Dataset` that parses `example` dataset into a `dict` dataset."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__()
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/111553342): add arguments names and example names as well.
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+ Parses a number of serialized `Example` protos given in `serialized`. We refer
+ to `serialized` as a batch with `batch_size` many entries of individual
+ `Example` protos.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of parsing processes to call in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if features argument is None.
+ """
+ if features is None:
+ raise ValueError("Missing: features was %s." % features)
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+ if any([
+ isinstance(feature, parsing_ops.SparseFeature)
+ for _, feature in features.items()
+ ]):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 3882d4bfdb..4c466781f7 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -25,8 +25,8 @@ import numpy as np
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.contrib.data.python.ops import parsing_ops
from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.contrib.data.python.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import convert
@@ -37,7 +37,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
@@ -326,7 +325,6 @@ def make_csv_dataset(
shuffle_seed=None,
prefetch_buffer_size=1,
num_parallel_reads=1,
- num_parallel_parser_calls=2,
sloppy=False,
num_rows_for_inference=100,
compression_type=None,
@@ -393,8 +391,6 @@ def make_csv_dataset(
batches consumed per training step.
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
- num_parallel_parser_calls: Number of parallel invocations of the CSV parsing
- function on CSV records.
sloppy: If `True`, reading performance will be improved at
the cost of non-deterministic ordering. If `False`, the order of elements
produced is deterministic prior to shuffling (elements are still
@@ -503,7 +499,8 @@ def make_csv_dataset(
# indefinitely, and all batches will be full-sized.
dataset = dataset.batch(batch_size=batch_size,
drop_remainder=num_epochs is None)
- dataset = dataset.map(map_fn, num_parallel_calls=num_parallel_parser_calls)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
@@ -663,6 +660,7 @@ def make_batched_features_dataset(file_pattern,
batch_size,
features,
reader=core_readers.TFRecordDataset,
+ label_key=None,
reader_args=None,
num_epochs=None,
shuffle=True,
@@ -675,6 +673,9 @@ def make_batched_features_dataset(file_pattern,
drop_final_batch=False):
"""Returns a `Dataset` of feature dictionaries from `Example` protos.
+ If label_key argument is provided, returns a `Dataset` of tuple
+ comprising of feature dictionaries and label.
+
Example:
```
@@ -725,6 +726,9 @@ def make_batched_features_dataset(file_pattern,
reader: A function or class that can be
called with a `filenames` tensor and (optional) `reader_args` and returns
a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional) A string corresponding to the key labels are stored in
+ `tf.Examples`. If provided, it must be one of the `features` key,
+ otherwise results in `ValueError`.
reader_args: Additional arguments to pass to the reader class.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. Defaults to `None`.
@@ -750,8 +754,11 @@ def make_batched_features_dataset(file_pattern,
`False`.
Returns:
- A dataset of `dict` elements. Each `dict` maps feature keys to
- `Tensor` or `SparseTensor` objects.
+ A dataset of `dict` elements, (or a tuple of `dict` elements and label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
"""
# Create dataset of all matching filenames
filenames = _get_file_names(file_pattern, False)
@@ -772,14 +779,13 @@ def make_batched_features_dataset(file_pattern,
# Extract values if the `Example` tensors are stored as key-value tuples.
if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset.map(lambda _, v: v)
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
# Apply dataset repeat and shuffle transformations.
dataset = _maybe_shuffle_and_repeat(
dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
- dataset = dataset.apply(stats_ops.feature_stats("record_stats"))
-
# NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
# improve the shape inference, because it makes the batch dimension static.
# It is safe to do this because in that case we are repeating the input
@@ -788,13 +794,17 @@ def make_batched_features_dataset(file_pattern,
batch_size, drop_remainder=drop_final_batch or num_epochs is None)
# Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.map(
- lambda x: parsing_ops.parse_example(x, features),
- num_parallel_calls=parser_num_threads)
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
+
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
- # TODO(rachelim): Add an optional label_name argument for extracting the label
- # from the features dictionary, to comply with the type expected by the
- # input_fn to a `tf.Estimator.train` or `tf.Estimator.evaluate` function.
dataset = dataset.prefetch(prefetch_buffer_size)
return dataset
@@ -974,3 +984,49 @@ class SqlDataset(dataset_ops.Dataset):
@property
def output_types(self):
return self._output_types
+
+
+class LMDBDataset(dataset_ops.Dataset):
+ """A LMDB Dataset that reads the lmdb file."""
+
+ def __init__(self, filenames):
+ """Create a `LMDBDataset`.
+
+ `LMDBDataset` allows a user to read data from a mdb file as
+ (key value) pairs sequentially.
+ For example:
+ ```python
+ dataset = tf.contrib.lmdb.LMDBDataset("/foo/bar.mdb")
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ # Prints the (key, value) pairs inside a lmdb file.
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ """
+ super(LMDBDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+
+ def _as_variant_tensor(self):
+ return contrib_gen_dataset_ops.lmdb_dataset(
+ self._filenames,
+ output_types=nest.flatten(self.output_types),
+ output_shapes=nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor, ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
+
+ @property
+ def output_types(self):
+ return dtypes.string, dtypes.string
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
index 3b4e981402..8426228992 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/contrib/data/python/ops/stats_ops.py
@@ -178,29 +178,6 @@ def latency_stats(tag):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def feature_stats(tag):
- """Records the features stats from `Example` records of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will be
- associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.feature_stats_dataset, tag)
-
- return _apply_fn
-
-
class _StatsDataset(dataset_ops.Dataset):
"""A `Dataset` that acts as an identity, and also records statistics."""
diff --git a/tensorflow/contrib/distribute/BUILD b/tensorflow/contrib/distribute/BUILD
index c16f1d6035..a87a5624c8 100644
--- a/tensorflow/contrib/distribute/BUILD
+++ b/tensorflow/contrib/distribute/BUILD
@@ -35,5 +35,7 @@ py_library(
"//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/distribute:distribute_config",
+ "//tensorflow/python/distribute:distribute_coordinator",
],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 2f5dd10550..30e1992c01 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -1,6 +1,6 @@
# Distribution Strategy
-> *NOTE*: This is a experimental feature. The API and performance
+> *NOTE*: This is an experimental feature. The API and performance
> characteristics are subject to change.
## Overview
@@ -9,29 +9,111 @@
API is an easy way to distribute your training
across multiple devices/machines. Our goal is to allow users to use existing
models and training code with minimal changes to enable distributed training.
-Moreover, we've design the API in such a way that it works with both eager and
+Moreover, we've designed the API in such a way that it works with both eager and
graph execution.
-Currently we support one type of strategy, called
-[`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy).
-It does in-graph replication with synchronous training
+Currently we support several types of strategies:
+
+* [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy):
+This does in-graph replication with synchronous training
on many GPUs on one machine. Essentially, we create copies of all variables in
the model's layers on each device. We then use all-reduce to combine gradients
across the devices before applying them to the variables to keep them in sync.
-In the future, we intend to support other kinds of training configurations such
-as multi-node, synchronous,
-[asynchronous](https://www.tensorflow.org/deploy/distributed#putting_it_all_together_example_trainer_program),
-parameter servers and model parallelism.
+* [`CollectiveAllReduceStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/CollectiveAllReduceStrategy):
+This is a version of `MirroredStrategy` for multi-working training. It uses
+a collective op to do all-reduce. This supports between-graph communication and
+synchronization, and delegates the specifics of the all-reduce implementation to
+the runtime (as opposed to encoding it in the graph). This allows it to perform
+optimizations like batching and switch between plugins that support different
+hardware or algorithms. In the future, this strategy will implement
+fault-tolerance to allow training to continue when there is worker failure.
+
+* [`ParameterServerStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/ParameterServerStrategy):
+This strategy supports using parameter servers either for multi-GPU local
+training or asynchronous multi-machine training. When used to train locally,
+variables are not mirrored, instead they placed on the CPU and operations are
+replicated across all local GPUs. In a multi-machine setting, some are
+designated as workers and some as parameter servers. Each variable is placed on
+one parameter server. Computation operations are replicated across all GPUs of
+the workers.
+
+## Multi-GPU Training
+
+## Example with Keras API
+
+Let's see how to scale to multiple GPUs on one machine using `MirroredStrategy` with [tf.keras] (https://www.tensorflow.org/guide/keras).
+
+Take a very simple model consisting of a single layer:
+
+```python
+inputs = tf.keras.layers.Input(shape=(1,))
+predictions = tf.keras.layers.Dense(1)(inputs)
+model = tf.keras.models.Model(inputs=inputs, outputs=predictions)
+```
-## Example
+Let's also define a simple input dataset for training this model. Note that currently we require using
+[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
+with `DistributionStrategy`.
+
+```python
+features = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10)
+labels = tf.data.Dataset.from_tensors([1.]).repeat(10000).batch(10)
+train_dataset = tf.data.Dataset.zip((features, labels))
+```
-Let's demonstrate how to use this API with a simple example. We will use the
-[`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator)
-approach, and show you how to scale your model to run on multiple GPUs on one
-machine using `MirroredStrategy`.
-Let's consider a very simple model function which tries to learn a simple
-function.
+To distribute this Keras model on multiple GPUs using `MirroredStrategy` we
+first instantiate a `MirroredStrategy` object.
+
+```python
+distribution = tf.contrib.distribute.MirroredStrategy()
+```
+
+We then compile the Keras model and pass the `MirroredStrategy` object in the
+`distribute` argument (apart from other usual arguments like `loss` and
+`optimizer`).
+
+```python
+model.compile(loss='mean_squared_error',
+ optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.2),
+ distribute=strategy)
+```
+
+To train the model we call Keras `fit` API using the input dataset that we
+created earlier, same as how we would in a non-distributed case.
+
+```python
+model.fit(train_dataset, epochs=5, steps_per_epoch=10)
+```
+
+Similarly, we can also call `evaluate` and `predict` as before using appropriate
+datasets.
+
+```python
+model.evaluate(eval_dataset)
+model.predict(predict_dataset)
+```
+
+That's all you need to train your model with Keras on multiple GPUs with
+`MirroredStrategy`. It will take care of splitting up
+the input dataset, replicating layers and variables on each device, and
+combining and applying gradients.
+
+The model and input code does not have to change because we have changed the
+underlying components of TensorFlow (such as
+optimizer, batch norm and summaries) to become distribution-aware.
+That means those components know how to
+combine their state across devices. Further, saving and checkpointing works
+seamlessly, so you can save with one or no distribution strategy and resume with
+another.
+
+
+## Example with Estimator API
+
+You can also use Distribution Strategy API with [`Estimator`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator). Let's see a simple example of it's usage with `MirroredStrategy`.
+
+
+Consider a very simple model function which tries to learn a simple function.
```python
def model_fn(features, labels, mode):
@@ -53,17 +135,14 @@ def model_fn(features, labels, mode):
return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
-Let's also define a simple input function to feed data for training this model.
-Note that we require using
-[`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)
-with `DistributionStrategy`.
+Again, let's define a simple input function to feed data for training this model.
```python
def input_fn():
features = tf.data.Dataset.from_tensors([[1.]]).repeat(100)
labels = tf.data.Dataset.from_tensors(1.).repeat(100)
- return dataset_ops.Dataset.zip((features, labels))
+ return tf.data.Dataset.zip((features, labels))
```
Now that we have a model function and input function defined, we can define the
@@ -80,20 +159,14 @@ distribution = tf.contrib.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=distribution)
classifier = tf.estimator.Estimator(model_fn=model_fn, config=config)
classifier.train(input_fn=input_fn)
+classifier.evaluate(input_fn=input_fn)
```
That's it! This change will now configure estimator to run on all GPUs on your
-machine, with the `MirroredStrategy` approach. It will take care of distributing
-the input dataset, replicating layers and variables on each device, and
-combining and applying gradients.
+machine.
-The model and input functions do not have to change because we have changed the
-underlying components of TensorFlow (such as
-optimizer, batch norm and summaries) to become distribution-aware.
-That means those components know how to
-combine their state across devices. Further, saving and checkpointing works
-seamlessly, so you can save with one or no distribution strategy and resume with
-another.
+
+## Customization and Performance Tips
Above, we showed the easiest way to use [`MirroredStrategy`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/distribute/MirroredStrategy#__init__).
There are few things you can customize in practice:
@@ -103,8 +176,6 @@ of GPUs (using param `num_gpus`), in case you don't want auto detection.
* You can specify various parameters for all reduce with the `cross_tower_ops`
param, such as the all reduce algorithm to use, and gradient repacking.
-## Performance Tips
-
We've tried to make it such that you get the best performance for your existing
model. We also recommend you follow the tips from
[Input Pipeline Performance Guide](https://www.tensorflow.org/performance/datasets_performance).
@@ -113,15 +184,177 @@ and [`dataset.prefetch`](https://www.tensorflow.org/performance/datasets_perform
in the input function gives a solid boost in performance. When using
`dataset.prefetch`, use `buffer_size=None` to let it detect optimal buffer size.
+## Multi-worker Training
+### Overview
+
+For multi-worker training, no code change is required to the `Estimator` code.
+You can run the same model code for all tasks in your cluster including
+parameter servers and the evaluator. But you need to use
+`tf.estimator.train_and_evaluator`, explicitly specify `num_gpus_per_workers`
+for your strategy object, and set "TF\_CONFIG" environment variables for each
+binary running in your cluster. We'll provide a Kubernetes template in the
+[tensorflow/ecosystem](https://github.com/tensorflow/ecosystem) repo which sets
+"TF\_CONFIG" for your training tasks.
+
+### TF\_CONFIG environment variable
+
+The "TF\_CONFIG" environment variables is a JSON string which specifies what
+tasks constitute a cluster, their addresses and each task's role in the cluster.
+One example of "TF\_CONFIG" is:
+
+```python
+TF_CONFIG='{
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"],
+ "ps": ["host4:port", "host5:port"]
+ },
+ "task": {"type": "worker", "index": 1}
+}'
+```
+
+This "TF\_CONFIG" specifies that there are three workers and two ps tasks in the
+cluster along with their hosts and ports. The "task" part specifies that the
+role of the current task in the cluster, worker 1. Valid roles in a cluster is
+"chief", "worker", "ps" and "evaluator". There should be no "ps" job for
+`CollectiveAllReduceStrategy` and `MirroredStrategy`. The "evaluator" job is
+optional and can have at most one task. It does single machine evaluation and if
+you don't want to do evaluation, you can pass in a dummy `input_fn` to the
+`tf.estimator.EvalSpec` of `tf.estimator.train_and_evaluate`.
+
+### Dataset
+
+The `input_fn` you provide to estimator code is for one worker. So remember to
+scale up your batch if you have multiple GPUs on each worker.
+
+The same `input_fn` will be used for all workers if you use
+`CollectiveAllReduceStrategy` and `ParameterServerStrategy`. Therefore it is
+important to shuffle your dataset in your `input_fn`.
+
+`MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you
+`input_fn`. As a result, each worker gets a fraction of your input data.
+
+### Performance Tips
+
+We have been actively working on multi-worker performance. Currently, prefer
+`CollectiveAllReduceStrategy` for synchronous multi-worker training.
+
+### Example
+
+Let's use the same example for multi-worker. We'll start a cluster with 3
+workers doing synchronous all-reduce training. In the following code snippet, we
+start multi-worker training using `tf.estimator.train_and_evaluate`:
+
+
+```python
+def model_main():
+ estimator = ...
+ distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=2)
+ config = tf.estimator.RunConfig(train_distribute=distribution)
+ train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+ eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+ tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+```
+
+
+**Note**: You don't have to set "TF\_CONFIG" manually if you use our provided
+Kubernetes template.
+
+You'll then need 3 machines, find out their host addresses and one available
+port on each machine. Then set "TF\_CONFIG" in each binary and run the above
+model code.
+
+In your worker 0, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 0}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+In your worker 1, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 1}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+In your worker 2, run:
+
+```python
+os.environ["TF_CONFIG"] = json.dumps({
+ "cluster": {
+ "worker": ["host1:port", "host2:port", "host3:port"]
+ },
+ "task": {"type": "worker", "index": 2}
+})
+
+# Call the model_main function defined above.
+model_main()
+```
+
+Then you'll find your cluster has started training! You can inspect the logs of
+workers or start a tensorboard.
+
+### Standalone client mode
+
+We have a new way to run distributed training. You can bring up standard
+tensorflow servers in your cluster and run your model code anywhere such as on
+your laptop.
+
+In the above example, instead of calling `model_main`, you can call
+`tf.contrib.distribute.run_standard_tensorflow_server().join()`. This will bring
+up a cluster running standard tensorflow servers which wait for your request to
+start training.
+
+On your laptop, you can run
+
+```python
+estimator = ...
+distribution = tf.contrib.distribute.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=2)
+config = tf.estimator.RunConfig(
+ experimental_distribute=tf.contrib.distribute.DistributeConfig(
+ train_distribute=distribution,
+ remote_cluster={"worker": ["host1:port", "host2:port", "host3:port"]}))
+train_spec = tf.estimator.TrainSpec(input_fn=input_fn)
+eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
+tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
+```
+
+Then you will see the training logs on your laptop. You can terminate the
+training by terminating your process on your laptop. You can also modify your
+code and run a new model against the same cluster.
+
+We've been optimizing the performance of standalone client mode. If you notice
+high latency between your laptop and your cluster, you can reduce that latency
+by running your model binary in the cluster.
+
## Caveats
+
This feature is in early stages and there are a lot of improvements forthcoming:
* Summaries are only computed in the first tower in `MirroredStrategy`.
-* Evaluation is not yet distributed.
* Eager support is in the works; performance can be more challenging with eager
execution.
-* As mentioned earlier, multi-node and other distributed strategies will be
-introduced in the future.
+* We currently support the following predefined Keras callbacks:
+`ModelCheckpointCallback`, `TensorBoardCallback`. We will soon be adding support for
+some of the other callbacks such as `EarlyStopping`, `ReduceLROnPlateau`, etc. If you
+create your own callback, you will not have access to all model properties and
+validation data.
* If you are [`batching`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch)
your input data, we will place one batch on each GPU in each step. So your
effective batch size will be `num_gpus * batch_size`. Therefore, consider
diff --git a/tensorflow/contrib/distribute/__init__.py b/tensorflow/contrib/distribute/__init__.py
index 588a4f2898..350f81f60f 100644
--- a/tensorflow/contrib/distribute/__init__.py
+++ b/tensorflow/contrib/distribute/__init__.py
@@ -27,6 +27,8 @@ from tensorflow.contrib.distribute.python.one_device_strategy import OneDeviceSt
from tensorflow.contrib.distribute.python.parameter_server_strategy import ParameterServerStrategy
from tensorflow.contrib.distribute.python.step_fn import *
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy
+from tensorflow.python.distribute.distribute_config import DistributeConfig
+from tensorflow.python.distribute.distribute_coordinator import run_standard_tensorflow_server
from tensorflow.python.training.distribute import *
from tensorflow.python.training.distribution_strategy_context import *
@@ -37,6 +39,7 @@ _allowed_symbols = [
'AllReduceCrossTowerOps',
'CollectiveAllReduceStrategy',
'CrossTowerOps',
+ 'DistributeConfig',
'DistributionStrategy',
'MirroredStrategy',
'Monitor',
@@ -54,6 +57,7 @@ _allowed_symbols = [
'get_tower_context',
'has_distribution_strategy',
'require_tower_context',
+ 'run_standard_tensorflow_server',
'UpdateContext',
]
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 59efd17746..87f76eaa94 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -23,8 +23,6 @@ py_library(
deps = [
":input_ops",
":prefetching_ops_v2",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/eager/python:datasets",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
@@ -85,6 +83,7 @@ py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
],
@@ -105,6 +104,38 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "parameter_server_strategy_test",
+ srcs = ["parameter_server_strategy_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ ":values",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:session",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -138,6 +169,7 @@ py_library(
"//tensorflow/python:collective_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/distribute:multi_worker_util",
"//tensorflow/python/eager:context",
],
)
@@ -237,48 +269,21 @@ py_test(
],
)
-py_test(
- name = "parameter_server_strategy_test",
- srcs = ["parameter_server_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- ":combinations",
- ":multi_worker_test_base",
- ":parameter_server_strategy",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:layers",
- "//tensorflow/python:session",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
cuda_py_test(
name = "mirrored_strategy_multigpu_test",
srcs = ["mirrored_strategy_multigpu_test.py"],
additional_deps = [
":mirrored_strategy",
+ ":multi_worker_test_base",
":values",
":strategy_test_lib",
"//tensorflow/python:distribute",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:layers",
+ "//tensorflow/python:state_ops",
"//tensorflow/python:variable_scope",
- "//tensorflow/python:array_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
@@ -339,19 +344,17 @@ py_library(
],
)
-py_test(
+cuda_py_test(
name = "collective_all_reduce_strategy_test",
srcs = ["collective_all_reduce_strategy_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
+ additional_deps = [
":collective_all_reduce_strategy",
":combinations",
":cross_tower_utils",
":multi_worker_test_base",
":strategy_test_lib",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -365,8 +368,10 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/estimator:estimator_py",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
+ ],
+ tags = [
+ "multi_and_single_gpu",
+ "no_pip",
],
)
@@ -446,6 +451,35 @@ cuda_py_test(
],
)
+cuda_py_test(
+ name = "estimator_training_test",
+ size = "large",
+ srcs = ["estimator_training_test.py"],
+ additional_deps = [
+ ":combinations",
+ ":mirrored_strategy",
+ ":multi_worker_test_base",
+ ":parameter_server_strategy",
+ "//third_party/py/numpy",
+ "//tensorflow/contrib/optimizer_v2:training",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/distribute",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:summary",
+ ],
+ tags = [
+ "manual",
+ "multi_and_single_gpu",
+ "no_pip",
+ "nogpu",
+ "notap",
+ ],
+)
+
py_library(
name = "single_loss_example",
srcs = ["single_loss_example.py"],
@@ -601,6 +635,7 @@ cuda_py_test(
":combinations",
":cross_tower_ops",
":multi_worker_test_base",
+ ":mirrored_strategy",
":values",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
@@ -673,19 +708,32 @@ cuda_py_test(
],
)
-cuda_py_test(
- name = "keras_test",
+py_library(
+ name = "keras_test_lib",
+ testonly = 1,
srcs = ["keras_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
+ deps = [
+ ":combinations",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "keras_test",
+ srcs = ["keras_test.py"],
+ additional_deps = [
+ ":keras_test_lib",
],
tags = [
"multi_and_single_gpu",
+ "no_pip",
"no_windows_gpu",
"notsan",
],
diff --git a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
index 95b824e51a..865dba803f 100644
--- a/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
+++ b/tensorflow/contrib/distribute/python/checkpoint_utils_test.py
@@ -48,7 +48,7 @@ class CheckpointUtilsWithDistributionStrategyTest(
mode=["graph"]))
def testInitFromCheckpoint(self, distribution, in_tower_mode):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1_value, v2_value, _, _ = checkpoint_utils_test._create_checkpoints(
session, checkpoint_dir)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index 9afcaecf78..77079d0df9 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -18,96 +18,96 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import json
-import os
-
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
-from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
-from tensorflow.python.training import server_lib
-
+from tensorflow.python.platform import tf_logging as logging
-# TODO(yuefengz): move this function to a common util file.
-def _normalize_cluster_spec(cluster_spec):
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- return server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
- return cluster_spec
-
-# TODO(yuefengz): shard the dataset.
# TODO(yuefengz): support in-graph replication.
-# TODO(yuefengz): it only works with a cluster without a chief node, maybe
-# support chief node?
class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Distribution strategy that uses collective ops for all-reduce.
It is similar to the MirroredStrategy but it uses collective ops for
- reduction. It currently only works for between-graph replication and its
- reduction will reduce across all workers.
+ reduction.
+
+ When `cluster_spec` is given by the `configure` method, it turns into the
+ mulit-worker version that works on multiple workers with between-graph
+ replication.
+
+ Note: `configure` will be called by higher-level APIs if running in
+ distributed environment.
"""
- def __init__(self,
- num_gpus_per_worker=0,
- cluster_spec=None,
- task_type="worker",
- task_id=0):
+ def __init__(self, num_gpus_per_worker=0):
"""Initializes the object.
Args:
- num_gpus_per_worker: number of local GPUs or GPUs per worker.
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type, such as "worker".
- task_id: the current task id.
-
- Raises:
- ValueError: if `task_type` is not in the `cluster_spec`.
+ num_gpus_per_worker: number of local GPUs or GPUs per worker, the default
+ is 0 meaning CPU only.
"""
self._num_gpus_per_worker = num_gpus_per_worker
- self._initialize(cluster_spec, task_type, task_id)
+ self._initialize_local_worker(num_gpus_per_worker)
+
+ def _initialize_local_worker(self, num_gpus_per_worker):
+ """Initializes the object for local training."""
+ self._is_chief = True
+ self._num_workers = 1
- def _initialize(self, cluster_spec, task_type, task_id):
+ if num_gpus_per_worker:
+ local_devices = [
+ "/device:GPU:%d" % i for i in range(num_gpus_per_worker)
+ ]
+ else:
+ local_devices = ["/device:CPU:0"]
+
+ self._collective_keys = cross_tower_utils.CollectiveKeys()
+ super(CollectiveAllReduceStrategy, self).__init__(
+ devices=local_devices,
+ cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
+ num_workers=1,
+ num_gpus_per_worker=num_gpus_per_worker,
+ collective_keys=self._collective_keys))
+
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info("CollectiveAllReduceStrategy with local_devices = %r",
+ local_devices)
+
+ def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
+ task_type, task_id):
+ """Initializes the object for multi-worker training."""
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
if task_type not in ["chief", "worker"]:
raise ValueError(
"Unrecognized task_type: %r, valid task types are: \"chief\", "
"\"worker\"." % task_type)
- if cluster_spec:
- self._cluster_spec = _normalize_cluster_spec(cluster_spec)
- worker_device = "/job:%s/task:%d" % (task_type, task_id)
- num_workers = len(self._cluster_spec.as_dict().get(task_type, []))
- if "chief" in self._cluster_spec.as_dict():
- num_workers += 1
- if not num_workers:
- raise ValueError("`task_type` shoud be in `cluster_spec`.")
-
- # TODO(yuefengz): create a utility to infer chief.
- if "chief" in self._cluster_spec.as_dict() and task_type == "chief":
- assert task_id == 0
- self._is_chief = True
- else:
- assert task_type == "worker"
- self._is_chief = task_id == 0
- else:
- self._cluster_spec = None
- self._is_chief = True
- worker_device = ""
- num_workers = 1
- self._num_workers = num_workers
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._num_workers = len(cluster_spec.as_dict().get("worker", [])) + len(
+ cluster_spec.as_dict().get("chief", []))
+ if not self._num_workers:
+ raise ValueError("No `worker` or `chief` tasks can be found in "
+ "`cluster_spec`.")
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
- if self._num_gpus_per_worker:
+ worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ if num_gpus_per_worker:
local_devices = [
"%s/device:GPU:%d" % (worker_device, i)
- for i in range(self._num_gpus_per_worker)
+ for i in range(num_gpus_per_worker)
]
else:
local_devices = [worker_device]
@@ -116,14 +116,23 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
super(CollectiveAllReduceStrategy, self).__init__(
devices=local_devices,
cross_tower_ops=cross_tower_ops_lib.CollectiveAllReduce(
- num_workers=num_workers,
- num_gpus_per_worker=self._num_gpus_per_worker,
+ num_workers=self._num_workers,
+ num_gpus_per_worker=num_gpus_per_worker,
collective_keys=self._collective_keys))
# Add a default device so that ops without specified devices will not end up
# on other workers.
- if cluster_spec:
- self._default_device = "/job:%s/replica:0/task:%d" % (task_type, task_id)
+ self._default_device = "/job:%s/task:%d" % (task_type, task_id)
+
+ self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker CollectiveAllReduceStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_workers = %r, local_devices = %r", cluster_spec.as_dict(),
+ task_type, task_id, self._num_workers, local_devices)
def _create_variable(self, next_creator, *args, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
@@ -187,19 +196,81 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
return mirrored_strategy._create_mirrored_variable(
devices, _real_mirrored_creator, *args, **kwargs)
- def configure(self, session_config=None):
- # Use TF_CONFIG to get the cluster spec and the current job.
- if not self._cluster_spec:
- tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
- cluster_spec = _normalize_cluster_spec(tf_config.get("cluster", {}))
+ def distribute_dataset(self, dataset_fn):
+ """Distributes the dataset to each local GPU."""
+ # TODO(yuefengz): shard the dataset.
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices, True)
- task_env = tf_config.get("task", {})
- if task_env:
- task_type = task_env.get("type", "worker")
- task_id = int(task_env.get("index", "0"))
- else:
- task_type = "worker"
- task_id = 0
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the object.
+
+ Args:
+ session_config: a @{tf.ConfigProto}
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type, such as "worker".
+ task_id: the current task id.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec`.
+ """
+ if not self._cluster_spec and cluster_spec:
+ # If a `cluster_spec` is already passed in, do nothing here.
+ # TODO(yuefengz): check `cluster_spec` is the same if this object has
+ # already been initialized with a `cluster_spec`.
+ self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
+ task_type, task_id)
+
+ if not session_config or not self._cluster_spec:
+ return
+
+ session_config.isolate_session_state = True
+
+ assert self._task_type
+ assert self._task_id is not None
+
+ # Collective group leader is needed for collective ops to coordinate
+ # workers.
+ if "chief" in self._cluster_spec.jobs:
+ session_config.experimental.collective_group_leader = (
+ "/job:chief/replica:0/task:0")
+ else:
+ if "worker" not in self._cluster_spec.jobs:
+ raise ValueError(
+ "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+ session_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ # The device filters prevent communication between workers.
+ del session_config.device_filters[:]
+ session_config.device_filters.append(
+ "/job:%s/task:%d" % (self._task_type, self._task_id))
+
+ # The scoped_allocator_optimization is to optimize graphs for collective
+ # ops.
+ rewrite_options = session_config.graph_options.rewrite_options
+ rewrite_options.scoped_allocator_optimization = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ del rewrite_options.scoped_allocator_opts.enable_op[:]
+ rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
+
+ @property
+ def between_graph(self):
+ return True
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return self._is_chief
- if cluster_spec:
- self._initialize(cluster_spec, task_type, task_id)
+ @property
+ def should_save_summary(self):
+ return self._is_chief
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index b5e54e3b7d..36e9761073 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -25,10 +25,8 @@ from tensorflow.contrib.distribute.python import collective_all_reduce_strategy
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_utils
from tensorflow.contrib.distribute.python import multi_worker_test_base
-from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -41,53 +39,46 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class DistributedCollectiveAllReduceStrategyTest(
- multi_worker_test_base.MultiWorkerTestBase, parameterized.TestCase):
+class CollectiveAllReduceStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
collective_key_base = 0
- @classmethod
- def setUpClass(cls):
- """Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ]
- }
-
def setUp(self):
self._run_options = config_pb2.RunOptions()
self._run_options.experimental.collective_graph_key = 6
self._sess_config = config_pb2.ConfigProto()
- self._sess_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
# We use a different key_base for each test so that collective keys won't be
# reused.
# TODO(yuefengz, tucker): enable it to reuse collective keys in different
# tests.
- DistributedCollectiveAllReduceStrategyTest.collective_key_base += 100000
- super(DistributedCollectiveAllReduceStrategyTest, self).setUp()
+ CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
+ super(CollectiveAllReduceStrategyTestBase, self).setUp()
def _get_test_object(self, task_type, task_id, num_gpus=0):
distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus,
- cluster_spec=self._cluster_spec,
- task_type=task_type,
- task_id=task_id)
+ num_gpus_per_worker=num_gpus)
+ if task_type and task_id is not None:
+ distribution.configure(
+ session_config=self._sess_config,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
collective_keys = cross_tower_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_start=num_gpus * 100 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base,
+ CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
- DistributedCollectiveAllReduceStrategyTest.collective_key_base)
+ CollectiveAllReduceStrategyTestBase.collective_key_base)
distribution._collective_keys = collective_keys
distribution._cross_tower_ops._collective_keys = collective_keys
- return distribution, self._workers[task_id].target
+ if task_type and task_id is not None:
+ return distribution, 'grpc://' + self._cluster_spec[task_type][task_id]
+ else:
+ return distribution, ''
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
d, master_target = self._get_test_object(task_type, task_id, num_gpus)
@@ -155,12 +146,6 @@ class DistributedCollectiveAllReduceStrategyTest(
self.assertLess(error_after, error_before)
return error_after < error_before
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testMinimizeLossGraph(self, num_gpus):
- self._run_between_graph_clients(self._test_minimize_loss_graph,
- self._cluster_spec, num_gpus)
-
def _test_variable_initialization(self, task_type, task_id, num_gpus):
distribution, master_target = self._get_test_object(task_type, task_id,
num_gpus)
@@ -182,16 +167,37 @@ class DistributedCollectiveAllReduceStrategyTest(
distribution.reduce(
variable_scope.VariableAggregation.MEAN, x,
destinations='/cpu:0'))[0]
+ x = distribution.unwrap(x)[0]
sess.run(
variables.global_variables_initializer(), options=self._run_options)
+
x_value, reduced_x_value = sess.run(
[x, reduced_x], options=self._run_options)
- self.assertTrue(np.array_equal(x_value, reduced_x_value))
- return np.array_equal(x_value, reduced_x_value)
+ self.assertTrue(
+ np.allclose(x_value, reduced_x_value, atol=1e-5),
+ msg=('x_value = %r, reduced_x_value = %r' % (x_value,
+ reduced_x_value)))
+ return np.allclose(x_value, reduced_x_value, atol=1e-5)
+
+
+class DistributedCollectiveAllReduceStrategyTest(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
def testVariableInitialization(self, num_gpus):
if context.num_gpus() < num_gpus:
return
@@ -201,16 +207,44 @@ class DistributedCollectiveAllReduceStrategyTest(
num_gpus=num_gpus)
-class LocalCollectiveAllReduceStrategy(strategy_test_lib.DistributionTestBase,
- parameterized.TestCase):
+class DistributedCollectiveAllReduceStrategyTestWithChief(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 3 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=0, has_chief=True)
+
+ def setUp(self):
+ super(DistributedCollectiveAllReduceStrategyTestWithChief, self).setUp()
+ self._run_options.experimental.collective_graph_key = 7
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
+ def testVariableInitialization(self, num_gpus):
+ if context.num_gpus() < num_gpus:
+ return
+ self._run_between_graph_clients(
+ self._test_variable_initialization,
+ self._cluster_spec,
+ num_gpus=num_gpus)
+
+
+class LocalCollectiveAllReduceStrategy(
+ CollectiveAllReduceStrategyTestBase, parameterized.TestCase):
def testMinimizeLossGraph(self, num_gpus=2):
# Collective ops doesn't support strategy with one device.
if context.num_gpus() < num_gpus:
return
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus)
- self._test_minimize_loss_graph(distribution)
+ self._test_minimize_loss_graph(None, None, num_gpus)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 2fbadfe0f5..1133be6d0b 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -328,6 +328,10 @@ tpu_strategy = NamedDistribution(
"TPU", lambda: tpu_lib.TPUStrategy(
TPUClusterResolver(""), steps_per_run=5),
required_tpu=True)
+tpu_strategy_one_step = NamedDistribution(
+ "TPU", lambda: tpu_lib.TPUStrategy(
+ TPUClusterResolver(""), steps_per_run=1),
+ required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
@@ -341,33 +345,6 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
["/gpu:0", "/gpu:1"], prefetch_on_device=False),
required_gpus=2)
-multi_worker_strategy_with_cpu = NamedDistribution(
- "MultiWorkerCPU",
- lambda: mirrored_lib.MirroredStrategy(
- cluster_spec={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus=0), 0)
-multi_worker_strategy_with_one_gpu = NamedDistribution(
- "MultiWorker1GPU",
- lambda: mirrored_lib.MirroredStrategy(
- cluster_spec={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus=1), 1)
-multi_worker_strategy_with_two_gpus = NamedDistribution(
- "MultiWorker2GPUs",
- lambda: mirrored_lib.MirroredStrategy(
- cluster_spec={
- "worker": [
- "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"
- ]
- },
- num_gpus=2), 2)
adam_optimizer_v1_fn = NamedObject(
"AdamV1", lambda: adam.AdamOptimizer(0.2, epsilon=1))
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops.py b/tensorflow/contrib/distribute/python/cross_tower_ops.py
index 163559587d..e08ba9c2a6 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops.py
@@ -35,13 +35,13 @@ from tensorflow.python.training import device_util
def check_destinations(destinations):
- """Checks whether `destinations` is not None and not empty.
+ """Checks whether `destinations` is not empty.
Args:
destinations: a DistributedValues, Variable, string or a list of strings.
Returns:
- Boolean indicating whether `destinations` is not None and not empty.
+ Boolean which is True if `destinations` is not empty.
"""
# Calling bool() on a ResourceVariable is not allowed.
if isinstance(destinations, resource_variable_ops.ResourceVariable):
@@ -56,13 +56,50 @@ def validate_destinations(destinations):
value_lib.AggregatingVariable, six.string_types, list)):
raise ValueError("destinations must be one of a `DistributedValues` object,"
" a tf.Variable object, a device string, a list of device "
- "strings or None")
+ "strings")
if not check_destinations(destinations):
raise ValueError("destinations can not be empty")
+def _make_tensor_into_per_device(input_tensor):
+ """Converts a single tensor into a PerDevice object."""
+ if isinstance(input_tensor, (tuple, list)):
+ raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object, "
+ "got %r but expected a object that is not a tuple or list."
+ % (input_tensor,))
+ if isinstance(input_tensor, value_lib.PerDevice):
+ return input_tensor
+
+ try:
+ device = input_tensor.device
+ except AttributeError:
+ raise ValueError("Cannot convert `input_tensor` to a `PerDevice` object "
+ "because it doesn't have device set.")
+
+ return value_lib.PerDevice({device: input_tensor})
+
+
+def _normalize_value_destination_pairs(value_destination_pairs):
+ """Converts each tensor into a PerDevice object in the input list."""
+ result = []
+ if not isinstance(value_destination_pairs, (list, tuple)):
+ raise ValueError("`value_destination_pairs` should be a list or tuple")
+ for pair in value_destination_pairs:
+ if not isinstance(pair, tuple):
+ raise ValueError(
+ "Each element of `value_destination_pairs` should be a tuple.")
+ if len(pair) != 2:
+ raise ValueError("Each element of `value_destination_pairs` should be a "
+ "tuple of size 2.")
+
+ per_device = _make_tensor_into_per_device(pair[0])
+ result.append((per_device, pair[1]))
+ return result
+
+
def _validate_value_destination_pairs(value_destination_pairs):
+ # TODO(yuefengz): raise exceptions instead of returning False.
# pylint: disable=g-missing-docstring
if not value_destination_pairs: return False
if not isinstance(value_destination_pairs, (list, tuple)): return False
@@ -83,8 +120,10 @@ def get_devices_from(destinations):
return [destinations.device]
elif isinstance(destinations, six.string_types):
return [device_util.resolve(destinations)]
- else:
+ elif isinstance(destinations, (list, tuple)):
return [device_util.resolve(destination) for destination in destinations]
+ else:
+ return [destinations.device]
def _devices_match(left, right):
@@ -92,8 +131,7 @@ def _devices_match(left, right):
def _all_devices_match(value_destination_pairs):
- if not all([d is None or _devices_match(v, d)
- for v, d in value_destination_pairs]):
+ if not all([_devices_match(v, d) for v, d in value_destination_pairs]):
return False
if not all([_devices_match(v, value_destination_pairs[0][0])
for v, _ in value_destination_pairs[1:]]):
@@ -150,7 +188,7 @@ class CrossTowerOps(object):
def __init__(self):
pass
- def reduce(self, aggregation, per_device_value, destinations=None):
+ def reduce(self, aggregation, per_device_value, destinations):
"""Reduce `per_device_value` to `destinations`.
It runs the reduction operation defined by `aggregation` and put the
@@ -159,7 +197,7 @@ class CrossTowerOps(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
- per_device_value: a PerDevice object.
+ per_device_value: a PerDevice object or a tensor with device set.
destinations: the reduction destinations.
Returns:
@@ -169,9 +207,9 @@ class CrossTowerOps(object):
ValueError: if per_device_value is not a PerDevice object.
"""
if not isinstance(per_device_value, value_lib.PerDevice):
- raise ValueError("`per_device_value` must be a `PerDevice` object.")
- if destinations is not None:
- validate_destinations(destinations)
+ per_device_value = _make_tensor_into_per_device(per_device_value)
+
+ validate_destinations(destinations)
return self._reduce(aggregation, per_device_value, destinations)
def batch_reduce(self, aggregation, value_destination_pairs):
@@ -184,8 +222,7 @@ class CrossTowerOps(object):
aggregation: Indicates how a variable will be aggregated. Accepted values
are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
value_destination_pairs: a list or a tuple of tuples of PerDevice objects
- and destinations. If a destination is None, then the destinations
- are set to match the devices of the input PerDevice object.
+ (or tensors with device set if there is one tower) and destinations.
Returns:
a list of Mirrored objects.
@@ -195,11 +232,13 @@ class CrossTowerOps(object):
tuples of PerDevice objects and destinations
"""
if not _validate_value_destination_pairs(value_destination_pairs):
- raise ValueError("`value_destination_pairs` must be a list or a tuple of "
- "tuples of PerDevice objects and destinations")
+ # If the first element of each pair is a tensor, we try to turn it into a
+ # PerDevice object.
+ value_destination_pairs = _normalize_value_destination_pairs(
+ value_destination_pairs)
+
for _, d in value_destination_pairs:
- if d is not None:
- validate_destinations(d)
+ validate_destinations(d)
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -529,7 +568,7 @@ class AllReduceCrossTowerOps(CrossTowerOps):
def _reduce(self, aggregation, per_device_value, destinations):
contains_indexed_slices = cross_tower_utils.contains_indexed_slices(
per_device_value)
- if ((destinations is None or _devices_match(per_device_value, destinations))
+ if (_devices_match(per_device_value, destinations)
and not context.executing_eagerly()
and not contains_indexed_slices):
return self._batch_all_reduce(aggregation, [per_device_value])[0]
@@ -558,8 +597,10 @@ class AllReduceCrossTowerOps(CrossTowerOps):
[v[0] for v in value_destination_pairs])
else:
if not all_devices_match:
- logging.warning("Efficient batch_reduce is not supported if "
- "destinations are different.")
+ logging.log_first_n(logging.WARN,
+ "Efficient batch_reduce is not supported if "
+ "destinations are different.",
+ 10)
return [
self._reduce(aggregation, t, destinations=v)
@@ -738,7 +779,7 @@ class CollectiveAllReduce(CrossTowerOps):
def __init__(self,
num_workers=1,
num_gpus_per_worker=0,
- all_reduce_merge_scope=1,
+ all_reduce_merge_scope=32,
collective_keys=None):
"""Initializes the object.
@@ -759,8 +800,15 @@ class CollectiveAllReduce(CrossTowerOps):
# TODO(yuefengz, tucker): is indexed slices supported by collective ops?
def _reduce(self, aggregation, per_device_value, destinations):
+ if cross_tower_utils.contains_indexed_slices(per_device_value):
+ raise ValueError(
+ "`IndexSlices` is not supported for Collective All-Reduce.")
+ if context.executing_eagerly():
+ raise ValueError(
+ "Eager execution is not supported for Collective All-Reduce")
+
all_reduced = self._batch_all_reduce(aggregation, [per_device_value])[0]
- if destinations is None or _devices_match(per_device_value, destinations):
+ if _devices_match(per_device_value, destinations):
return all_reduced
else:
index = {}
@@ -776,15 +824,33 @@ class CollectiveAllReduce(CrossTowerOps):
return value_lib.Mirrored(index)
def _batch_reduce(self, aggregation, value_destination_pairs):
- return [
- self._reduce(aggregation, t, destinations=v)
- for t, v in value_destination_pairs
- ]
+ if cross_tower_utils.contains_indexed_slices(value_destination_pairs):
+ raise ValueError(
+ "`IndexSlices` is not supported for Collective All-Reduce.")
+ if context.executing_eagerly():
+ raise ValueError(
+ "Eager execution is not supported for Collective All-Reduce")
+
+ all_devices_match = _all_devices_match(value_destination_pairs)
+ if all_devices_match:
+ return self._batch_all_reduce(aggregation,
+ [v[0] for v in value_destination_pairs])
+ else:
+ if not all_devices_match:
+ logging.log_first_n(
+ logging.WARN, "Efficient batch_reduce is not supported if "
+ "destinations are different.", 10)
+
+ return [
+ self._reduce(aggregation, t, destinations=v)
+ for t, v in value_destination_pairs
+ ]
def _batch_all_reduce(self, aggregation, per_device_values):
"""All-reduce across all workers in a batch."""
if context.executing_eagerly():
- raise ValueError("Eager mode with collective ops is not supported yet.")
+ raise ValueError(
+ "Eager execution with collective ops is not supported yet.")
logging.log_first_n(
logging.INFO, "Collective All-reduce invoked with batches size = %d, "
diff --git a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
index 3508c9d599..490371477a 100644
--- a/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
+++ b/tensorflow/contrib/distribute/python/cross_tower_ops_test.py
@@ -26,12 +26,12 @@ import numpy as np
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import cross_tower_utils
+from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import values as value_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import test
-from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -40,9 +40,17 @@ from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.training import device_util
-def _make_per_device(values, devices):
+def _make_per_device(values, devices, regroup=False):
devices = cross_tower_ops_lib.get_devices_from(devices)
assert len(values) == len(devices)
+
+ # We simulate the result of regroup called on PerDevice which strips the
+ # PerDevice wrapper if it has only one value.
+ if len(values) == 1 and regroup:
+ with ops.device(devices[0]):
+ placed_v = array_ops.identity(values[0])
+ return placed_v
+
index = {}
for d, v in zip(devices, values):
with ops.device(d):
@@ -127,7 +135,7 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
destination_list = devices
all_destinations = [
- None, destination_mirrored, destination_different, destination_str,
+ destination_mirrored, destination_different, destination_str,
destination_list
]
@@ -138,24 +146,24 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device))
+ _fake_mirrored(mean, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device))
+ _fake_mirrored(mean_2, destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM, per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices), destinations or per_device))
+ _fake_mirrored(mean * len(devices), destinations))
self._assert_values_equal(
cross_tower_ops.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices), destinations or per_device))
+ _fake_mirrored(mean_2 * len(devices), destinations))
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -163,25 +171,22 @@ class CrossTowerOpsTestBase(test.TestCase, parameterized.TestCase):
cross_tower_ops.batch_reduce(vs.VariableAggregation.MEAN,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
])
self._assert_values_equal(
cross_tower_ops.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1), (per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices), d1 or per_device),
- _fake_mirrored(mean_2 * len(devices), d2 or per_device_2)
+ _fake_mirrored(mean * len(devices), d1),
+ _fake_mirrored(mean_2 * len(devices), d2)
])
# test broadcast()
for destinations in all_destinations:
- if destinations is None:
- continue
- else:
- self._assert_values_equal(
- cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
- _fake_mirrored(1., destinations))
+ self._assert_values_equal(
+ cross_tower_ops.broadcast(constant_op.constant(1.), destinations),
+ _fake_mirrored(1., destinations))
class SingleWorkerCrossTowerOpsTest(CrossTowerOpsTestBase):
@@ -368,14 +373,27 @@ class MultiWorkerCrossTowerOpsTest(multi_worker_test_base.MultiWorkerTestBase,
("xring", 2, -1)], 0, 0, 0)),
],
distribution=[
- combinations.multi_worker_strategy_with_cpu,
- combinations.multi_worker_strategy_with_one_gpu,
- combinations.multi_worker_strategy_with_two_gpus
+ combinations.NamedDistribution(
+ "MirroredCPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=0),
+ required_gpus=0),
+ combinations.NamedDistribution(
+ "Mirrored1GPU",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=1),
+ required_gpus=1),
+ combinations.NamedDistribution(
+ "Mirrored2GPUs",
+ lambda: mirrored_strategy.MirroredStrategy(num_gpus=2),
+ required_gpus=2),
],
mode=["graph"])
@combinations.generate(multi_worker_allreduce_combinations)
def testReductionAndBroadcast(self, cross_tower_ops, distribution):
+ distribution.configure(cluster_spec={
+ "worker":
+ ["/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1"]
+ })
with distribution.scope():
self._testReductionAndBroadcast(cross_tower_ops, distribution)
@@ -388,13 +406,8 @@ class MultiWorkerCollectiveAllReduceTest(
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- "fake_worker_0", "fake_worker_1", "fake_worker_2"
- ]
- }
def setUp(self):
super(MultiWorkerCollectiveAllReduceTest, self).setUp()
@@ -428,7 +441,8 @@ class MultiWorkerCollectiveAllReduceTest(
]
else:
devices = ["/job:%s/task:%d" % (task_type, task_id)]
- return collective_all_reduce_ops, devices, self._workers[task_id].target
+ return (collective_all_reduce_ops, devices,
+ "grpc://" + self._cluster_spec[task_type][task_id])
def _assert_values_equal(self, left, right, sess):
if isinstance(left, list):
@@ -455,7 +469,8 @@ class MultiWorkerCollectiveAllReduceTest(
num_workers = 1
worker_device = None
else:
- num_workers = len(self._workers)
+ num_workers = len(self._cluster_spec.get("chief", [])) + len(
+ self._cluster_spec.get("worker", []))
worker_device = "/job:%s/task:%d" % (task_type, task_id)
with ops.Graph().as_default(), \
ops.device(worker_device), \
@@ -463,7 +478,7 @@ class MultiWorkerCollectiveAllReduceTest(
# Collective ops doesn't support scalar tensors, so we have to construct
# 1-d tensors.
values = [constant_op.constant([float(d)]) for d in range(len(devices))]
- per_device = _make_per_device(values, devices)
+ per_device = _make_per_device(values, devices, regroup=True)
mean = np.array([(len(devices) - 1.) / 2.])
values_2 = [constant_op.constant([d + 1.0]) for d in range(len(devices))]
@@ -476,7 +491,7 @@ class MultiWorkerCollectiveAllReduceTest(
destination_list = devices
all_destinations = [
- destination_different, None, destination_mirrored, destination_str,
+ destination_different, destination_mirrored, destination_str,
destination_list
]
@@ -487,27 +502,27 @@ class MultiWorkerCollectiveAllReduceTest(
vs.VariableAggregation.MEAN,
per_device,
destinations=destinations),
- _fake_mirrored(mean, destinations or per_device), sess)
+ _fake_mirrored(mean, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.MEAN,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2, destinations or per_device), sess)
+ _fake_mirrored(mean_2, destinations), sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device,
destinations=destinations),
- _fake_mirrored(mean * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean * len(devices) * num_workers, destinations),
+ sess)
self._assert_values_equal(
collective_all_reduce.reduce(
vs.VariableAggregation.SUM,
per_device_2,
destinations=destinations),
- _fake_mirrored(mean_2 * len(devices) * num_workers, destinations or
- per_device), sess)
+ _fake_mirrored(mean_2 * len(devices) * num_workers, destinations),
+ sess)
# test batch_reduce()
for d1, d2 in itertools.product(all_destinations, all_destinations):
@@ -516,24 +531,22 @@ class MultiWorkerCollectiveAllReduceTest(
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean, d1 or per_device),
- _fake_mirrored(mean_2, d2 or per_device_2)
+ _fake_mirrored(mean, d1),
+ _fake_mirrored(mean_2, d2)
], sess)
self._assert_values_equal(
collective_all_reduce.batch_reduce(vs.VariableAggregation.SUM,
[(per_device, d1),
(per_device_2, d2)]),
[
- _fake_mirrored(mean * len(devices) * num_workers, d1 or
- per_device),
- _fake_mirrored(mean_2 * len(devices) * num_workers, d2 or
- per_device_2)
+ _fake_mirrored(mean * len(devices) * num_workers, d1),
+ _fake_mirrored(mean_2 * len(devices) * num_workers, d2)
], sess)
return True
@combinations.generate(
- combinations.combine(mode=["graph"], num_gpus=[0, 1, 2]))
+ combinations.combine(mode=["graph"], num_gpus=[0, 1, 2], required_gpus=1))
def testReductionDistributed(self, num_gpus):
if context.num_gpus() < num_gpus:
return
diff --git a/tensorflow/contrib/distribute/python/estimator_training_test.py b/tensorflow/contrib/distribute/python/estimator_training_test.py
new file mode 100644
index 0000000000..5348512016
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/estimator_training_test.py
@@ -0,0 +1,659 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests that show Distribute Coordinator works with Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import json
+import os
+import sys
+import tempfile
+import threading
+from absl.testing import parameterized
+import numpy as np
+import six
+
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.optimizer_v2 import adagrad
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import estimator_training as dc_training
+from tensorflow.python.distribute.distribute_config import DistributeConfig
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import exporter as exporter_lib
+from tensorflow.python.estimator import run_config as run_config_lib
+from tensorflow.python.estimator import training as estimator_training
+from tensorflow.python.estimator.canned import dnn_linear_combined
+from tensorflow.python.estimator.canned import prediction_keys
+from tensorflow.python.estimator.export import export as export_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.summary import summary_iterator
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import server_lib
+
+BATCH_SIZE = 10
+LABEL_DIMENSION = 2
+DATA = np.linspace(
+ 0., 2., BATCH_SIZE * LABEL_DIMENSION, dtype=np.float32).reshape(
+ BATCH_SIZE, LABEL_DIMENSION)
+EVAL_NAME = "foo"
+EXPORTER_NAME = "saved_model_exporter"
+MAX_STEPS = 10
+
+CHIEF = dc._TaskType.CHIEF
+EVALUATOR = dc._TaskType.EVALUATOR
+WORKER = dc._TaskType.WORKER
+PS = dc._TaskType.PS
+
+original_run_distribute_coordinator = dc.run_distribute_coordinator
+
+
+# TODO(yuefengz): merge this method back to test_util.
+def _create_local_cluster(num_workers,
+ num_ps,
+ has_eval=False,
+ protocol="grpc",
+ worker_config=None,
+ ps_config=None):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {
+ "worker": ["localhost:%s" % port for port in worker_ports],
+ "ps": ["localhost:%s" % port for port in ps_ports]
+ }
+ if has_eval:
+ cluster_dict["evaluator"] = ["localhost:%s" % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ workers = [
+ server_lib.Server(
+ cs,
+ job_name="worker",
+ protocol=protocol,
+ task_index=ix,
+ config=worker_config,
+ start=True) for ix in range(num_workers)
+ ]
+ ps_servers = [
+ server_lib.Server(
+ cs,
+ job_name="ps",
+ protocol=protocol,
+ task_index=ix,
+ config=ps_config,
+ start=True) for ix in range(num_ps)
+ ]
+ if has_eval:
+ evals = [
+ server_lib.Server(
+ cs,
+ job_name="evaluator",
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+ ]
+ else:
+ evals = []
+
+ return workers, ps_servers, evals
+
+
+def _create_in_process_cluster(num_workers, num_ps, has_eval=False):
+ """Create an in-process cluster that consists of only standard server."""
+ # Leave some memory for cuda runtime.
+ if has_eval:
+ gpu_mem_frac = 0.7 / (num_workers + 1)
+ else:
+ gpu_mem_frac = 0.7 / num_workers
+
+ worker_config = config_pb2.ConfigProto()
+ worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
+
+ # Enable collective ops which has no impact on non-collective ops.
+ # TODO(yuefengz, tucker): removing this after we move the initialization of
+ # collective mgr to the session level.
+ worker_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ ps_config = config_pb2.ConfigProto()
+ ps_config.device_count["GPU"] = 0
+
+ return _create_local_cluster(
+ num_workers,
+ num_ps=num_ps,
+ has_eval=has_eval,
+ worker_config=worker_config,
+ ps_config=ps_config,
+ protocol="grpc")
+
+
+def _create_cluster_spec(has_chief=False,
+ num_workers=1,
+ num_ps=0,
+ has_eval=False):
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+
+ cluster_spec = {}
+ if has_chief:
+ cluster_spec[CHIEF] = ["localhost:%s" % portpicker.pick_unused_port()]
+ if num_workers:
+ cluster_spec[WORKER] = [
+ "localhost:%s" % portpicker.pick_unused_port()
+ for _ in range(num_workers)
+ ]
+ if num_ps:
+ cluster_spec[PS] = [
+ "localhost:%s" % portpicker.pick_unused_port() for _ in range(num_ps)
+ ]
+ if has_eval:
+ cluster_spec[EVALUATOR] = ["localhost:%s" % portpicker.pick_unused_port()]
+ return cluster_spec
+
+
+def _bytes_to_str(maybe_bytes):
+ if isinstance(maybe_bytes, six.string_types):
+ return maybe_bytes
+ else:
+ return str(maybe_bytes, "utf-8")
+
+
+def _strip_protocol(target):
+ # cluster_spec expects "host:port" strings.
+ if "//" in target:
+ return target.split("//")[1]
+ else:
+ return target
+
+
+class DistributeCoordinatorIntegrationTest(test.TestCase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers."""
+ cls._workers, cls._ps, cls._evals = _create_in_process_cluster(
+ num_workers=3, num_ps=2, has_eval=True)
+ cls._cluster_spec = {
+ "worker": [
+ _strip_protocol(_bytes_to_str(w.target)) for w in cls._workers
+ ],
+ "ps": [_strip_protocol(_bytes_to_str(ps.target)) for ps in cls._ps],
+ "evaluator": [
+ _strip_protocol(_bytes_to_str(e.target)) for e in cls._evals
+ ]
+ }
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+ self._event = threading.Event()
+ super(DistributeCoordinatorIntegrationTest, self).setUp()
+
+ def dataset_input_fn(self, x, y, batch_size, shuffle):
+
+ def input_fn():
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y))
+ if shuffle:
+ dataset = dataset.shuffle(batch_size)
+ dataset = dataset.repeat(100).batch(batch_size)
+ return dataset
+
+ return input_fn
+
+ def _get_exporter(self, name, fc):
+ feature_spec = feature_column.make_parse_example_spec(fc)
+ serving_input_receiver_fn = (
+ export_lib.build_parsing_serving_input_receiver_fn(feature_spec))
+ return exporter_lib.LatestExporter(
+ name, serving_input_receiver_fn=serving_input_receiver_fn)
+
+ def _extract_loss_and_global_step(self, event_folder):
+ """Returns the loss and global step in last event."""
+ event_paths = glob.glob(os.path.join(event_folder, "events*"))
+
+ loss = None
+ global_step_count = None
+
+ for e in summary_iterator.summary_iterator(event_paths[-1]):
+ current_loss = None
+ for v in e.summary.value:
+ if v.tag == "loss":
+ current_loss = v.simple_value
+
+ # If loss is not found, global step is meaningless.
+ if current_loss is None:
+ continue
+
+ current_global_step = e.step
+ if global_step_count is None or current_global_step > global_step_count:
+ global_step_count = current_global_step
+ loss = current_loss
+
+ return (loss, global_step_count)
+
+ def _get_estimator(self,
+ train_distribute,
+ eval_distribute,
+ remote_cluster=None):
+ input_dimension = LABEL_DIMENSION
+ linear_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+
+ return dnn_linear_combined.DNNLinearCombinedRegressor(
+ linear_feature_columns=linear_feature_columns,
+ dnn_hidden_units=(2, 2),
+ dnn_feature_columns=dnn_feature_columns,
+ label_dimension=LABEL_DIMENSION,
+ model_dir=self._model_dir,
+ dnn_optimizer=adagrad.AdagradOptimizer(0.001),
+ linear_optimizer=adagrad.AdagradOptimizer(0.001),
+ config=run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=train_distribute,
+ eval_distribute=eval_distribute,
+ remote_cluster=remote_cluster)))
+
+ def _complete_flow(self,
+ train_distribute,
+ eval_distribute,
+ remote_cluster=None):
+ estimator = self._get_estimator(train_distribute, eval_distribute,
+ remote_cluster)
+
+ input_dimension = LABEL_DIMENSION
+ train_input_fn = self.dataset_input_fn(
+ x={"x": DATA},
+ y=DATA,
+ batch_size=BATCH_SIZE // len(train_distribute.worker_devices),
+ shuffle=True)
+ if eval_distribute:
+ eval_batch_size = BATCH_SIZE // len(eval_distribute.worker_devices)
+ else:
+ eval_batch_size = BATCH_SIZE
+ eval_input_fn = self.dataset_input_fn(
+ x={"x": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)
+
+ linear_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ dnn_feature_columns = [
+ feature_column.numeric_column("x", shape=(input_dimension,))
+ ]
+ feature_columns = linear_feature_columns + dnn_feature_columns
+
+ estimator_training.train_and_evaluate(
+ estimator,
+ estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),
+ estimator_training.EvalSpec(
+ name=EVAL_NAME,
+ input_fn=eval_input_fn,
+ steps=None,
+ exporters=self._get_exporter(EXPORTER_NAME, feature_columns),
+ start_delay_secs=0,
+ throttle_secs=1))
+ return estimator
+
+ def _inspect_train_and_eval_events(self, estimator):
+ # Make sure nothing is stuck in limbo.
+ writer_cache.FileWriterCache.clear()
+
+ # Examine the training events. Use a range to check global step to avoid
+ # flakyness due to global step race condition.
+ training_loss, _ = self._extract_loss_and_global_step(self._model_dir)
+ self.assertIsNotNone(training_loss)
+
+ # Examine the eval events. The global step should be accurate.
+ eval_dir = os.path.join(self._model_dir, "eval_" + EVAL_NAME)
+ eval_loss, eval_global_step = self._extract_loss_and_global_step(
+ event_folder=eval_dir)
+ self.assertIsNotNone(eval_loss)
+ self.assertGreaterEqual(eval_global_step, MAX_STEPS)
+
+ # Examine the export folder.
+ export_dir = os.path.join(
+ os.path.join(self._model_dir, "export"), EXPORTER_NAME)
+ self.assertTrue(gfile.Exists(export_dir))
+
+ # Examine the ckpt for predict.
+ def predict_input_fn():
+ return dataset_ops.Dataset.from_tensor_slices({
+ "x": DATA
+ }).batch(BATCH_SIZE)
+
+ predicted_proba = np.array([
+ x[prediction_keys.PredictionKeys.PREDICTIONS]
+ for x in estimator.predict(predict_input_fn)
+ ])
+ self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[
+ mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ eval_distribute_cls=[
+ None, mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ required_gpus=1))
+ def test_complete_flow_standalone_client(self, train_distribute_cls,
+ eval_distribute_cls):
+ try:
+ train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
+ except TypeError:
+ train_distribute = train_distribute_cls(num_gpus_per_worker=2)
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ estimator = self._complete_flow(
+ train_distribute, eval_distribute, remote_cluster=self._cluster_spec)
+ self._inspect_train_and_eval_events(estimator)
+
+ def _mock_run_distribute_coordinator(
+ self,
+ worker_fn,
+ strategy,
+ eval_fn,
+ eval_strategy,
+ mode=dc.CoordinatorMode.STANDALONE_CLIENT,
+ cluster_spec=None,
+ session_config=None):
+ # Calls the origial `run_distribute_coordinator` method but gets task config
+ # from environment variables and then signals the caller.
+ task_type = None
+ task_id = None
+ if not cluster_spec:
+ cluster_spec = None
+ tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
+ if not cluster_spec:
+ cluster_spec = tf_config.get("cluster", {})
+ task_env = tf_config.get("task", {})
+ if task_env:
+ task_type = task_env.get("type", task_type)
+ task_id = int(task_env.get("index", task_id))
+ self._event.set()
+ original_run_distribute_coordinator(
+ worker_fn,
+ strategy,
+ eval_fn,
+ eval_strategy,
+ mode=mode,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ session_config=session_config)
+
+ def _task_thread(self, train_distribute, eval_distribute):
+ with test.mock.patch.object(dc, "run_distribute_coordinator",
+ self._mock_run_distribute_coordinator):
+ self._complete_flow(train_distribute, eval_distribute)
+
+ def _run_task_in_thread(self, cluster_spec, task_type, task_id,
+ train_distribute, eval_distribute):
+ if task_type:
+ tf_config = {
+ "cluster": cluster_spec,
+ "task": {
+ "type": task_type,
+ "index": task_id
+ }
+ }
+ else:
+ tf_config = {
+ "cluster": cluster_spec,
+ "task": {
+ "type": task_type,
+ "index": task_id
+ }
+ }
+ self._event.clear()
+ t = threading.Thread(
+ target=self._task_thread, args=(train_distribute, eval_distribute))
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(tf_config)}):
+ t.start()
+ self._event.wait()
+ return t
+
+ def _run_multiple_tasks_in_threads(self, cluster_spec, train_distribute,
+ eval_distribute):
+ threads = {}
+ for task_type in cluster_spec.keys():
+ threads[task_type] = []
+ for task_id in range(len(cluster_spec[task_type])):
+ t = self._run_task_in_thread(cluster_spec, task_type, task_id,
+ train_distribute, eval_distribute)
+ threads[task_type].append(t)
+ return threads
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[
+ parameter_server_strategy.ParameterServerStrategy,
+ ],
+ eval_distribute_cls=[
+ None, mirrored_strategy.MirroredStrategy,
+ parameter_server_strategy.ParameterServerStrategy
+ ],
+ required_gpus=1))
+ def test_complete_flow_indepedent_worker_between_graph(
+ self, train_distribute_cls, eval_distribute_cls):
+ train_distribute = train_distribute_cls(
+ num_gpus_per_worker=context.num_gpus())
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ threads = self._run_multiple_tasks_in_threads(
+ cluster_spec, train_distribute, eval_distribute)
+ for task_type, ts in threads.items():
+ if task_type == PS:
+ continue
+ for t in ts:
+ t.join()
+
+ estimator = self._get_estimator(train_distribute, eval_distribute)
+ self._inspect_train_and_eval_events(estimator)
+
+ @combinations.generate(
+ combinations.combine(
+ mode=["graph"],
+ train_distribute_cls=[mirrored_strategy.MirroredStrategy],
+ eval_distribute_cls=[None, mirrored_strategy.MirroredStrategy],
+ required_gpus=1))
+ def test_complete_flow_indepedent_worker_in_graph(self, train_distribute_cls,
+ eval_distribute_cls):
+ train_distribute = train_distribute_cls(num_gpus=context.num_gpus())
+
+ if eval_distribute_cls:
+ eval_distribute = eval_distribute_cls()
+ else:
+ eval_distribute = None
+
+ cluster_spec = _create_cluster_spec(num_workers=3, num_ps=2, has_eval=True)
+ threads = self._run_multiple_tasks_in_threads(
+ cluster_spec, train_distribute, eval_distribute)
+ threads[WORKER][0].join()
+ threads[EVALUATOR][0].join()
+
+ estimator = self._get_estimator(train_distribute, eval_distribute)
+ self._inspect_train_and_eval_events(estimator)
+
+
+TF_CONFIG_WITH_CHIEF = {
+ "cluster": {
+ "chief": ["fake_chief"],
+ },
+ "task": {
+ "type": "chief",
+ "index": 0
+ }
+}
+
+TF_CONFIG_WITH_MASTER = {
+ "cluster": {
+ "master": ["fake_master"],
+ },
+ "task": {
+ "type": "master",
+ "index": 0
+ }
+}
+
+TF_CONFIG_WITHOUT_TASK = {"cluster": {"chief": ["fake_worker"]}}
+
+
+class RunConfigTest(test.TestCase):
+
+ def test_previously_unexpected_cluster_spec(self):
+ with test.mock.patch.dict(
+ "os.environ", {"TF_CONFIG": json.dumps(TF_CONFIG_WITHOUT_TASK)}):
+ run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+
+ def test_should_run_distribute_coordinator(self):
+ """Tests that should_run_distribute_coordinator return a correct value."""
+ # We don't use distribute coordinator for local training.
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ run_config_lib.RunConfig()))
+
+ # When `train_distribute` is not specified, don't use distribute
+ # coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ run_config_lib.RunConfig()))
+
+ # When `train_distribute` is specified and TF_CONFIG is detected, use
+ # distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config_with_train_distribute = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ config_with_eval_distribute = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ self.assertTrue(
+ dc_training.should_run_distribute_coordinator(
+ config_with_train_distribute))
+ self.assertFalse(
+ dc_training.should_run_distribute_coordinator(
+ config_with_eval_distribute))
+
+ # With a master in the cluster, don't run distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
+ config = run_config_lib.RunConfig(
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(num_gpus=2)))
+ self.assertFalse(dc_training.should_run_distribute_coordinator(config))
+
+ def test_init_run_config_duplicate_distribute(self):
+ with self.assertRaises(ValueError):
+ run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy()))
+
+ with self.assertRaises(ValueError):
+ run_config_lib.RunConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ eval_distribute=mirrored_strategy.MirroredStrategy()))
+
+ def test_init_run_config_none_distribute_coordinator_mode(self):
+ # We don't use distribute coordinator for local training.
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ dc_training.init_run_config(config, {})
+ self.assertIsNone(config._distribute_coordinator_mode)
+
+ # With a master in the cluster, don't run distribute coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_MASTER)}):
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ self.assertIsNone(config._distribute_coordinator_mode)
+
+ # When `train_distribute` is not specified, don't use distribute
+ # coordinator.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config = run_config_lib.RunConfig()
+ self.assertFalse(hasattr(config, "_distribute_coordinator_mode"))
+
+ def test_init_run_config_independent_worker(self):
+ # When `train_distribute` is specified and TF_CONFIG is detected, use
+ # distribute coordinator with INDEPENDENT_WORKER mode.
+ with test.mock.patch.dict("os.environ",
+ {"TF_CONFIG": json.dumps(TF_CONFIG_WITH_CHIEF)}):
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy())
+ self.assertEqual(config._distribute_coordinator_mode,
+ dc.CoordinatorMode.INDEPENDENT_WORKER)
+
+ def test_init_run_config_standalone_client(self):
+ # When `train_distribute` is specified, TF_CONFIG is detected and
+ # `experimental.remote_cluster` is set use distribute coordinator with
+ # STANDALONE_CLIENT mode.
+ config = run_config_lib.RunConfig(
+ train_distribute=mirrored_strategy.MirroredStrategy(),
+ experimental_distribute=DistributeConfig(
+ remote_cluster={"chief": ["fake_worker"]}))
+ self.assertEqual(config._distribute_coordinator_mode,
+ dc.CoordinatorMode.STANDALONE_CLIENT)
+
+
+if __name__ == "__main__":
+ with test.mock.patch.object(sys, "exit", os._exit):
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/examples/BUILD b/tensorflow/contrib/distribute/python/examples/BUILD
index cbfd178502..84b106545e 100644
--- a/tensorflow/contrib/distribute/python/examples/BUILD
+++ b/tensorflow/contrib/distribute/python/examples/BUILD
@@ -19,9 +19,20 @@ py_binary(
)
py_binary(
- name = "simple_tfkeras_example",
+ name = "keras_model_with_estimator",
srcs = [
- "simple_tfkeras_example.py",
+ "keras_model_with_estimator.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_binary(
+ name = "keras_mnist",
+ srcs = [
+ "keras_mnist.py",
],
deps = [
"//tensorflow:tensorflow_py",
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
new file mode 100644
index 0000000000..a84ef04196
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""An example training a Keras Model using MirroredStrategy and native APIs."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+NUM_CLASSES = 10
+
+
+def get_input_datasets():
+ """Downloads the MNIST dataset and creates train and eval dataset objects.
+
+ Returns:
+ Train dataset, eval dataset and input shape.
+
+ """
+ # input image dimensions
+ img_rows, img_cols = 28, 28
+
+ # the data, split between train and test sets
+ (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
+
+ if tf.keras.backend.image_data_format() == 'channels_first':
+ x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
+ x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
+ input_shape = (1, img_rows, img_cols)
+ else:
+ x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
+ x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
+ input_shape = (img_rows, img_cols, 1)
+
+ x_train = x_train.astype('float32')
+ x_test = x_test.astype('float32')
+ x_train /= 255
+ x_test /= 255
+
+ # convert class vectors to binary class matrices
+ y_train = tf.keras.utils.to_categorical(y_train, NUM_CLASSES)
+ y_test = tf.keras.utils.to_categorical(y_test, NUM_CLASSES)
+
+ # train dataset
+ train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
+ train_ds = train_ds.repeat()
+ train_ds = train_ds.shuffle(100)
+ train_ds = train_ds.batch(64, drop_remainder=True)
+
+ # eval dataset
+ eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
+ eval_ds = eval_ds.repeat()
+ eval_ds = eval_ds.batch(64, drop_remainder=True)
+
+ return train_ds, eval_ds, input_shape
+
+
+def get_model(input_shape):
+ """Builds a Sequential CNN model to recognize MNIST digits.
+
+ Args:
+ input_shape: Shape of the input depending on the `image_data_format`.
+
+ Returns:
+ a Keras model
+
+ """
+ # Define a CNN model to recognize MNIST digits.
+ model = tf.keras.models.Sequential()
+ model.add(tf.keras.layers.Conv2D(32, kernel_size=(3, 3),
+ activation='relu',
+ input_shape=input_shape))
+ model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
+ model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
+ model.add(tf.keras.layers.Dropout(0.25))
+ model.add(tf.keras.layers.Flatten())
+ model.add(tf.keras.layers.Dense(128, activation='relu'))
+ model.add(tf.keras.layers.Dropout(0.5))
+ model.add(tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'))
+ return model
+
+
+def main(_):
+ # Build the train and eval datasets from the MNIST data. Also return the
+ # input shape which is constructed based on the `image_data_format`
+ # i.e channels_first or channels_last.
+ train_ds, eval_ds, input_shape = get_input_datasets()
+ model = get_model(input_shape)
+
+ # Instantiate the MirroredStrategy object. If we don't specify `num_gpus` or
+ # the `devices` argument then all the GPUs available on the machine are used.
+ strategy = tf.contrib.distribute.MirroredStrategy()
+
+ # Compile the model by passing the distribution strategy object to the
+ # `distribute` argument. `fit`, `evaluate` and `predict` will be distributed
+ # based on the strategy instantiated.
+ model.compile(loss=tf.keras.losses.categorical_crossentropy,
+ optimizer=tf.train.RMSPropOptimizer(learning_rate=0.001),
+ metrics=['accuracy'],
+ distribute=strategy)
+
+ # Train the model with the train dataset.
+ model.fit(x=train_ds, epochs=20, steps_per_epoch=310)
+
+ # Evaluate the model with the eval dataset.
+ score = model.evaluate(eval_ds, steps=10, verbose=0)
+ print('Test loss:', score[0])
+ print('Test accuracy:', score[1])
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py
index 518ec9c423..8d117eb7e8 100644
--- a/tensorflow/contrib/distribute/python/examples/simple_tfkeras_example.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_model_with_estimator.py
@@ -42,19 +42,19 @@ def main(args):
model_dir = args[1]
print('Using %s to store checkpoints.' % model_dir)
- # Define tf.keras Model.
+ # Define a Keras Model.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(10,)))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
- # Compile tf.keras Model.
+ # Compile the model.
optimizer = tf.train.GradientDescentOptimizer(0.2)
model.compile(loss='binary_crossentropy', optimizer=optimizer)
model.summary()
tf.keras.backend.set_learning_phase(True)
- # Define a DistributionStrategy and convert the tf.keras Model to a
- # tf.Estimator that utilizes the DistributionStrategy.
+ # Define a DistributionStrategy and convert the Keras Model to an
+ # Estimator that utilizes the DistributionStrategy.
strategy = tf.contrib.distribute.MirroredStrategy(
['/device:GPU:0', '/device:GPU:1'])
config = tf.estimator.RunConfig(
@@ -62,7 +62,7 @@ def main(args):
keras_estimator = tf.keras.estimator.model_to_estimator(
keras_model=model, config=config, model_dir=model_dir)
- # Train and evaluate the tf.Estimator.
+ # Train and evaluate the model.
keras_estimator.train(input_fn=input_fn, steps=10)
eval_result = keras_estimator.evaluate(input_fn=input_fn)
print('Eval result: {}'.format(eval_result))
diff --git a/tensorflow/contrib/distribute/python/input_ops.py b/tensorflow/contrib/distribute/python/input_ops.py
index 1f24f62947..f07ec8234d 100644
--- a/tensorflow/contrib/distribute/python/input_ops.py
+++ b/tensorflow/contrib/distribute/python/input_ops.py
@@ -47,11 +47,8 @@ def auto_shard_dataset(dataset, num_shards, index):
Returns:
A modified `Dataset` obtained by updating the pipeline sharded by the
- files.
-
- Raises:
- NotImplementedError: If we cannot automatically determine a good way to
- shard the input dataset.
+ files. The input dataset will be returned if we cannot automatically
+ determine a good way to shard the input dataset.
"""
# TODO(priyag): Clone datasets instead of updating in place, similar to the
@@ -127,8 +124,10 @@ def auto_shard_dataset(dataset, num_shards, index):
tf_logging.warn(
"Could not find a standard reader in the input pipeline"
"(one of TextLineDataset, TFRecordDataset, FixedLengthRecordDataset)."
- "Falling back to sharding the dataset anyway. Please verify"
- "correctness of auto-sharding for your input.")
+ "So auto-sharding is not done. Please verify correctness of "
+ "auto-sharding for your input.")
+ # TODO(yuefengz): maybe still shard it?
+ return dataset
# TODO(priyag): What do we want to do if the number of filenames is
# uneven in the number of shards? By default, this will just return as
diff --git a/tensorflow/contrib/distribute/python/input_ops_test.py b/tensorflow/contrib/distribute/python/input_ops_test.py
index 16179c3a49..c5acb7ced4 100644
--- a/tensorflow/contrib/distribute/python/input_ops_test.py
+++ b/tensorflow/contrib/distribute/python/input_ops_test.py
@@ -91,7 +91,7 @@ class AutoShardDatasetTest(test.TestCase):
def _verifySimpleShardingOutput(self, dataset, record_fn):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(record_fn(r, f), sess.run(next_element))
@@ -150,7 +150,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual, expected = [], []
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
@@ -182,7 +182,7 @@ class AutoShardDatasetTest(test.TestCase):
# Verify output.
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
actual = []
num_iterations = (self._num_files * self._num_records * num_epochs) // (
self._num_shards * batch_size)
@@ -218,7 +218,7 @@ class AutoShardDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for f in range(self._shard_index, self._num_files, self._num_shards):
for r in range(self._num_records):
self.assertAllEqual(self._record(r, f), sess.run(next_element))
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index a262d7666e..9e1762d92c 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -18,9 +18,12 @@ from __future__ import division
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
+from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
@@ -91,6 +94,25 @@ def get_ds_test_input_fn():
return dataset
+def batch_wrapper(dataset, batch_size, distribution):
+ # TPUs currently require fully defined input shapes, drop_remainder ensures
+ # the input will have fully defined shapes.
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ return dataset.batch(batch_size, drop_remainder=True)
+ else:
+ return dataset.batch(batch_size)
+
+
+def all_combinations():
+ return combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step],
+ mode=['graph'])
+
+
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
def setUp(self):
@@ -116,7 +138,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
model_dir=self._base_dir,
train_distribute=dist,
eval_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=config)
before_eval_results = est_keras.evaluate(
@@ -139,7 +161,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
train_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(
keras_model=keras_model, config=config)
before_eval_results = est_keras.evaluate(
@@ -163,7 +185,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
config = run_config_lib.RunConfig(tf_random_seed=_RANDOM_SEED,
model_dir=self._base_dir,
train_distribute=dist)
- with self.test_session():
+ with self.cached_session():
est_keras = keras_lib.model_to_estimator(keras_model=keras_model,
config=config)
with self.assertRaisesRegexp(ValueError,
@@ -175,10 +197,10 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.DeleteRecursively(self._config.model_dir)
-class TestWithDistributionStrategy(test.TestCase):
+class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_validating_dataset_input_tensors_with_shape_mismatch(self):
- with self.test_session():
+ with self.cached_session():
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
a = constant_op.constant([1, 2], shape=(1, 2))
@@ -197,7 +219,7 @@ class TestWithDistributionStrategy(test.TestCase):
strategy, x, y)
def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
- with self.test_session():
+ with self.cached_session():
strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
'/device:CPU:0'])
a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
@@ -215,8 +237,8 @@ class TestWithDistributionStrategy(test.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
- def test_calling_model_on_same_dataset(self):
- with self.test_session():
+ def test_calling_model_with_numpy_arrays(self):
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -228,11 +250,44 @@ class TestWithDistributionStrategy(test.TestCase):
'/device:GPU:0'])
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ inputs = np.zeros((64, 3), dtype=np.float32)
+ targets = np.zeros((64, 4), dtype=np.float32)
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
+ validation_data=(inputs, targets))
+
+ # TODO(anjalisridhar): We need tests for when the batch size and steps are
+ # smaller and results in a 0 batch_size and steps value.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
+ @combinations.generate(all_combinations())
+ def test_calling_model_on_same_dataset(self, distribution):
+ with self.cached_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
+
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -241,8 +296,11 @@ class TestWithDistributionStrategy(test.TestCase):
validation_data=dataset, validation_steps=2)
model.predict(dataset, steps=2)
+ # TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work
+ # as clone_model's input_tensors argument only seems to accept list and not
+ # tuples or dict.
def test_fit_with_tuple_and_dict_dataset_inputs(self):
- with self.test_session():
+ with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
b = keras.layers.Input(shape=(3,), name='input_b')
@@ -282,8 +340,9 @@ class TestWithDistributionStrategy(test.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- def test_fit_eval_and_predict_methods_on_dataset(self):
- with self.test_session():
+ @combinations.generate(all_combinations())
+ def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -291,16 +350,13 @@ class TestWithDistributionStrategy(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
- '/device:CPU:0'])
-
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -320,7 +376,7 @@ class TestWithDistributionStrategy(test.TestCase):
def __call__(self, y_true, y_pred):
return y_pred - y_true
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -336,7 +392,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
def test_unsupported_features(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -389,7 +445,7 @@ class TestWithDistributionStrategy(test.TestCase):
model.predict(dataset, verbose=0)
def test_calling_with_unsupported_predefined_callbacks(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -428,7 +484,7 @@ class TestWithDistributionStrategy(test.TestCase):
callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)])
def test_dataset_input_shape_validation(self):
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
model = keras.Model(x, y)
@@ -446,8 +502,7 @@ class TestWithDistributionStrategy(test.TestCase):
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- with self.assertRaisesRegexp(ValueError,
- 'expected input to have 2 dimensions'):
+ with self.assertRaisesRegexp(ValueError, 'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
# Wrong input shape
@@ -465,7 +520,7 @@ class TestWithDistributionStrategy(test.TestCase):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
# Lambda layer uses the learning phase.
- with self.test_session():
+ with self.cached_session():
x = keras.layers.Input(shape=(16,), name='input')
y = keras.layers.Dense(16)(x)
z = keras.layers.Dropout(0.9999)(y)
@@ -497,8 +552,10 @@ class TestWithDistributionStrategy(test.TestCase):
class LossMaskingWithDistributionStrategyTest(test.TestCase):
+ # TODO(priyag): Enable all strategies for this test. Currently it does not
+ # work for TPU due to some invalid datatype.
def test_masking(self):
- with self.test_session():
+ with self.cached_session():
np.random.seed(1337)
x = np.array([[[1], [1]], [[0], [0]]])
model = keras.models.Sequential()
@@ -520,24 +577,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
self.assertEqual(hist.history['loss'][0], 0)
-class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
+class NormalizationLayerWithDistributionStrategyTest(
+ test.TestCase, parameterized.TestCase):
- def test_batchnorm_correctness(self):
- with self.test_session():
+ @combinations.generate(all_combinations())
+ def test_batchnorm_correctness(self, distribution):
+ with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
- strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
- '/device:GPU:0'])
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
- distribute=strategy)
+ distribute=distribution)
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+ x = x.astype('float32')
dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
dataset = dataset.repeat(100)
- dataset = dataset.batch(32)
+ dataset = batch_wrapper(dataset, 32, distribution)
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
out = model.predict(dataset, steps=2)
@@ -547,10 +605,12 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
-class CorrectnessWithDistributionStrategyTest(test.TestCase):
+class CorrectnessWithDistributionStrategyTest(test.TestCase,
+ parameterized.TestCase):
- def test_correctness(self):
- with self.test_session():
+ @combinations.generate(all_combinations())
+ def test_correctness(self, distribution):
+ with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
x_train = np.random.rand(num_samples, 1)
@@ -558,44 +618,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
-
- # With DistributionStrategy
- dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
- dataset_with = dataset_with.batch(32)
- strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
- '/device:GPU:0'],
- prefetch_on_device=False)
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5),
- distribute=strategy)
- model.fit(x=dataset_with, epochs=1, steps_per_epoch=310)
- wts_with_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
- predict_dataset_with = predict_dataset_with.batch(2)
- predict_with_ds = model.predict(predict_dataset_with, steps=1)
- predict_with_ds = np.reshape(predict_with_ds, (4, 1))
-
- # Without DistributionStrategy
- dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train,
+ def fit_and_predict(with_distribution=None):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.compile(
+ loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ distribute=with_distribution)
+
+ batch_size = 64
+ if with_distribution:
+ batch_size //= with_distribution.num_towers
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
- dataset_without = dataset_without.batch(64)
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5))
- model.fit(x=dataset_without, epochs=1, steps_per_epoch=310)
- wts_without_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_without = dataset_ops.Dataset.from_tensor_slices((
- x_predict, x_predict))
- predict_dataset_without = predict_dataset_without.batch(4)
- predict_without_ds = model.predict(predict_dataset_without, steps=1)
+ train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+ # Running only 100 steps instead of the full dataset to keep test
+ # duration small.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+
+ weights = model.get_weights()
+
+ x_predict = [[1.], [2.], [3.], [4.]]
+ predict_batch_size = 4
+ if with_distribution:
+ predict_batch_size //= with_distribution.num_towers
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
+ x_predict))
+ predict_dataset = batch_wrapper(predict_dataset,
+ predict_batch_size, distribution)
+ predict_result = model.predict(predict_dataset, steps=1)
+ predict_result = np.reshape(predict_result, (4, 1))
+
+ return weights, predict_result
+
+ wts_with_ds, predict_with_ds = fit_and_predict(
+ with_distribution=distribution)
+ wts_without_ds, predict_without_ds = fit_and_predict(
+ with_distribution=None)
# Verify that the weights are the same within some limits of tolerance.
np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3)
@@ -604,5 +663,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3)
+# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1.
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 516ede7ade..bdac4fb58c 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -71,7 +71,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -108,7 +108,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, iterator.get_next(), run_concurrently=layer.built))
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -168,7 +168,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -249,7 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -343,7 +343,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
@@ -466,7 +466,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
self.evaluate(distribution.initialize())
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables_lib.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 6981449a4c..0c6805d682 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -25,8 +25,8 @@ import threading
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.contrib.distribute.python import values
-from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python import pywrap_tensorflow
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
@@ -39,7 +39,6 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@@ -66,7 +65,7 @@ class _RequestedStop(Exception):
pass
-# Make _call_for_each_tower and _reduce_non_distributed_value not members of
+# _call_for_each_tower and _reduce_non_distributed_value are not members of
# MirroredStrategy so that they are generally not allowed to use anything
# specific to MirroredStrategy and thus can be shared with other distribution
# strategies.
@@ -198,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and equal to 0.
if value == 0:
return 0
- # If the aggregation type is MEAN, then this essentially means that the same
- # value should be on all destinations.
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return distribution.broadcast(value, destinations)
+ # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
+ # essentially means that the same value should be on all destinations.
+ if aggregation in (
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
+ return value
cross_tower_ops_lib.validate_destinations(destinations)
# We do not support an aggregation type of SUM if the value is the same across
@@ -209,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
- raise ValueError("A non-DistributedValues value cannot be reduced with the "
- "given aggregation.")
+ raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+ "the given aggregation %s." % (value, aggregation))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
@@ -255,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
- if aggregation not in [
+ if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ ):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -277,6 +279,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
else:
result = values.MirroredVariable(index, index[devices[0]], aggregation)
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
if not context.executing_eagerly():
g = ops.get_default_graph()
# If "trainable" is True, next_creator() will add the member variables
@@ -290,6 +295,9 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
for v in index.values():
l.remove(v)
g.add_to_collections(collections, result)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, result)
+
return result
@@ -299,8 +307,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
This strategy uses one tower per device and sync replication for its multi-GPU
version.
- When `cluster_spec` is given, it turns into the mulit-worker version that
- works on multiple workers with in-graph replication.
+ When `cluster_spec` is given by the `configure` method., it turns into the
+ mulit-worker version that works on multiple workers with in-graph replication.
+ Note: `configure` will be called by higher-level APIs if running in
+ distributed environment.
There are several important concepts for distributed TensorFlow, e.g.
`client`, `job`, 'task', `cluster`, `in-graph replication` and
@@ -330,8 +340,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus: number of GPUs. For local training, either specify `devices` or
`num_gpus`. In distributed training, this must be specified as number of
GPUs on each worker.
- cluster_spec: if this is set, it turns into the multi-worker version and
- `devices` must not be set but `num_gpus` must be set.
+ num_gpus_per_worker: number of GPUs per worker. This is the same as
+ `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be
+ specified.
cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
@@ -341,65 +352,83 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def __init__(self,
devices=None,
num_gpus=None,
- cluster_spec=None,
+ num_gpus_per_worker=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
- if cluster_spec:
- if devices is not None:
- raise ValueError("Specifying devices when `cluster_spec` is also given "
- "is not supported in MirroredStrategy.")
-
- # TODO(yuefengz): use the utility method to normalize cluster_spec.
- if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
- cluster_spec = server_lib.ClusterSpec(cluster_spec)
- elif not isinstance(cluster_spec, server_lib.ClusterSpec):
- raise ValueError(
- "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
- "`tf.train.ClusterDef` object")
- self._cluster_spec = cluster_spec
-
- self._workers = []
- for job in sorted(cluster_spec.jobs):
- for task in range(cluster_spec.num_tasks(job)):
- self._workers.append("/job:%s/task:%d" % (job, task))
+ self._cross_tower_ops = cross_tower_ops
+ self._prefetch_on_device = prefetch_on_device
+ # Rememeber num GPUs which might be needed by `configure` method.
+ if num_gpus is not None and num_gpus_per_worker is not None:
+ raise ValueError(
+ "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+ if num_gpus is not None:
+ self._num_gpus = num_gpus
+ else:
+ self._num_gpus = num_gpus_per_worker
+ self._initialize_local(self._num_gpus, devices)
+
+ def _initialize_local(self, num_gpus, devices):
+ """Initializes the object for local training."""
+ self._cluster_spec = None
+ # Convert `num_gpus` into `devices`, shouldn't specify both.
+ if devices is None:
if num_gpus is None:
- raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
- self._num_gpus = num_gpus
- if num_gpus > 0:
- self._worker_device_map = {
- worker: [
- device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
- for gpu in range(num_gpus)
- ] for worker in self._workers
- }
+ num_gpus = context.num_gpus()
+ if num_gpus == 0:
+ devices = ["/device:CPU:0"]
else:
- self._worker_device_map = {
- worker: [device_util.canonicalize(worker, "/device:CPU:0")]
- for worker in self._workers
- }
- devices = nest.flatten(self._worker_device_map)
-
- # Setting `_default_device` will add a device scope in the
- # distribution.scope. We set the default device to the first worker. When
- # users specify device under distribution.scope by
- # with tf.device("/cpu:0"):
- # ...
- # their ops will end up on the cpu device of its first worker, e.g.
- # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
- self._default_device = self._workers[0]
- else:
- self._cluster_spec = None
- # Convert `num_gpus` into `devices`, shouldn't specify both.
- if devices is None:
- if num_gpus is None:
- num_gpus = context.num_gpus()
devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
- elif num_gpus is not None:
- raise ValueError("Must only specify one of `devices` and `num_gpus`.")
- # TODO(yuefengz): consider setting the default device.
+ elif num_gpus is not None:
+ raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ self._num_gpus = num_gpus
+ # TODO(yuefengz): consider setting the default device.
+
+ assert devices, "Must specify at least one device."
+ assert len(set(devices)) == len(devices), (
+ "No duplicates allowed in `devices` argument.")
+ # TODO(josh11b): Require at least 2 devices?
+ self._devices = [device_util.resolve(d) for d in devices]
+ self._canonical_device_set = set(self._devices)
+ self._device_index = values.PerDevice({d: i for i, d in enumerate(devices)})
+
+ def _initialize_multi_worker(self, num_gpus, cluster_spec):
+ """Initializes the object for multi-worker training."""
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
+ self._cluster_spec = cluster_spec
+
+ self._workers = []
+ for job in ["chief", "worker"]:
+ for task in range(len(cluster_spec.as_dict().get(job, []))):
+ self._workers.append("/job:%s/task:%d" % (job, task))
+
+ if num_gpus is None:
+ raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
+ if num_gpus > 0:
+ self._worker_device_map = {
+ worker: [
+ device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
+ for gpu in range(num_gpus)
+ ] for worker in self._workers
+ }
+ else:
+ self._worker_device_map = {
+ worker: [device_util.canonicalize(worker, "/device:CPU:0")]
+ for worker in self._workers
+ }
+
+ devices = nest.flatten(self._worker_device_map)
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify device under distribution.scope by
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the cpu device of its first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+ self._default_device = self._workers[0]
assert devices, "Must specify at least one device."
assert len(set(devices)) == len(devices), (
@@ -409,8 +438,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._canonical_device_set = set(self._devices)
self._device_index = values.PerDevice(
{d: i for i, d in enumerate(devices)})
- self._cross_tower_ops = cross_tower_ops
- self._prefetch_on_device = prefetch_on_device
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
@@ -544,11 +571,25 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
cluster_spec=None,
task_type=None,
task_id=None):
- del cluster_spec, task_type, task_id
+ del task_type, task_id
+
+ if session_config:
+ session_config.isolate_session_state = True
+
+ if cluster_spec:
+ self._initialize_multi_worker(self._num_gpus, cluster_spec)
+
if self._cross_tower_ops is None:
if self._cluster_spec:
- self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
- self._workers, self._num_gpus)
+ # It currently cannot detect the toplogy of remote workers. So we
+ # hard-code the multi-worker all-reduce algorithm for now.
+ if len(self._workers) == 1:
+ # The default is "nccl".
+ self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps()
+ else:
+ # The default is hierarchical reduce and broadcast.
+ self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
else:
self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
self._devices, session_config=session_config)
@@ -567,10 +608,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# which case `value` would be a single value or value could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ value = value.get(self._devices[0])
+ if isinstance(value, (int, float)):
+ return value
+ return self.broadcast(value, destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._devices[0]), d)
+ for v, d in value_destination_pairs]
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
@@ -636,6 +685,22 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def parameter_devices(self):
return list(self._devices)
+ @property
+ def between_graph(self):
+ return False
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
def non_slot_devices(self, var_list):
del var_list
return list(self._devices)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 9a4cc0a897..c6894e9013 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import sys
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.contrib.distribute.python import values
from tensorflow.core.protobuf import config_pb2
@@ -37,10 +38,12 @@ from tensorflow.python.layers import core
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import server_lib
GPU_TEST = "test_gpu" in sys.argv[0]
@@ -126,6 +129,25 @@ class MirroredTwoDeviceDistributionTest(strategy_test_lib.DistributionTestBase):
expected = sum(range(len(dist.worker_devices)))
self.assertEqual(expected, self.evaluate(unwrapped[0]))
+ @test_util.run_in_graph_and_eager_modes
+ def testReduceOnlyFirstTowerUpdates(self):
+ if not GPU_TEST:
+ self.skipTest("Not GPU test")
+
+ def run_fn(device_id):
+ return constant_op.constant(3 + 5 * device_id)
+
+ dist = self._get_distribution_strategy()
+ with dist.scope():
+ result = dist.call_for_each_tower(run_fn, dist.worker_device_index)
+ reduced = dist.reduce(
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER,
+ result,
+ destinations="/device:CPU:0")
+ unwrapped = dist.unwrap(reduced)
+ self.assertEqual(1, len(unwrapped))
+ self.assertEqual(3, self.evaluate(unwrapped[0]))
+
@test_util.run_in_graph_and_eager_modes()
def testReduceToMultipleDestinations(self):
if not GPU_TEST:
@@ -382,6 +404,84 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
v3.aggregation)
@test_util.run_in_graph_and_eager_modes(config=config)
+ def testOnlyFirstTowerUpdatesVariables(self):
+ self._skip_eager_if_gpus_less_than(1)
+
+ def create_fn():
+ aggregation = variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ v0 = variable_scope.variable(
+ 2.0,
+ name="on_read",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=aggregation)
+ v1 = variable_scope.variable(
+ 3.0,
+ name="on_write",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE,
+ aggregation=aggregation)
+ return v0, v1
+
+ devices = ["/device:GPU:0", "/device:CPU:0"]
+ dist = mirrored_strategy.MirroredStrategy(devices)
+ with dist.scope():
+ v0, v1 = dist.call_for_each_tower(create_fn, run_concurrently=False)
+ self.evaluate(v0.initializer)
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0, self.evaluate(dist.read_var(v0)))
+ self.evaluate(v1.initializer)
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using the assign_add member function.
+ def update_member_fn(device_id):
+ update0 = v0.assign_add(5.0 * (device_id + 1))
+ update1 = v1.assign_add(7.0 * (device_id + 1))
+ return update0, update1
+
+ update0a, update1a = dist.call_for_each_tower(
+ update_member_fn, dist.worker_device_index, run_concurrently=False)
+
+ # Update "sync on read" variable.
+ self.evaluate(dist.group(update0a))
+ self.assertEqual(2.0 + 5.0, self.evaluate(v0.get(devices[0])))
+ # Writes are not synchronized for "sync on read" variables,
+ # so device[1] can end up with a different value.
+ self.assertEqual(2.0 + 2*5.0, self.evaluate(v0.get(devices[1])))
+ # Always reads from device 0.
+ self.assertEqual(2.0 + 5.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1a))
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[0])))
+ # Writes are synchronized for v1, only the argument to assign_add on
+ # device[0] is used.
+ self.assertEqual(3.0 + 7.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0, self.evaluate(dist.read_var(v1)))
+
+ # Update using state_ops.assign_add global function.
+ def update_state_ops_fn(device_id):
+ update0 = state_ops.assign_add(v0, 11.0 * (device_id + 1))
+ update1 = state_ops.assign_add(v1, 13.0 * (device_id + 1))
+ return update0, update1
+
+ update0b, update1b = dist.call_for_each_tower(
+ update_state_ops_fn, dist.worker_device_index, run_concurrently=False)
+ self.evaluate(dist.group(update0b))
+
+ # Update "sync on read" variable.
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(v0.get(devices[0])))
+ self.assertEqual(2.0 + 2*5.0 + 2*11.0, self.evaluate(v0.get(devices[1])))
+ self.assertEqual(2.0 + 5.0 + 11.0, self.evaluate(dist.read_var(v0)))
+
+ # Update "sync on write" variable.
+ self.evaluate(dist.group(update1b))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[0])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(v1.get(devices[1])))
+ self.assertEqual(3.0 + 7.0 + 13.0, self.evaluate(dist.read_var(v1)))
+
+ @test_util.run_in_graph_and_eager_modes(config=config)
def testNoneSynchronizationWithGetVariable(self):
self._skip_eager_if_gpus_less_than(1)
devices = ["/device:CPU:0", "/device:GPU:0"]
@@ -802,8 +902,8 @@ class MirroredVariableUpdateTest(test.TestCase):
return mirrored_var.assign(5.0)
with self.assertRaisesRegexp(
- ValueError, "A non-DistributedValues value cannot be reduced with "
- "the given aggregation."):
+ ValueError, "A non-DistributedValues value 5.0 cannot be reduced "
+ "with the given aggregation VariableAggregation.SUM."):
self.evaluate(dist.unwrap(dist.call_for_each_tower(model_fn)))
@test_util.run_in_graph_and_eager_modes(config=config)
@@ -886,8 +986,18 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertIsInstance(mirrored_var, values.MirroredVariable)
self.evaluate(variables.global_variables_initializer())
self.assertEquals(1.0, self.evaluate(mirrored_var))
- mirrored_var_result = self.evaluate(mirrored_var.assign_add(6.0))
+
+ # read_value == True
+ mirrored_var_result = self.evaluate(
+ mirrored_var.assign_add(6.0, read_value=True))
self.assertEquals(7.0, mirrored_var_result)
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(7.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+
+ # read_value == False
+ self.evaluate(mirrored_var.assign_add(2.0, read_value=False))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
+ self.assertEquals(9.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignAddMirroredVarTowerContext(self):
@@ -954,6 +1064,8 @@ class MirroredVariableUpdateTest(test.TestCase):
self.assertEquals(5.0, self.evaluate(mirrored_var))
mirrored_var_result = self.evaluate(mirrored_var.assign_sub(2.0))
self.assertEquals(3.0, mirrored_var_result)
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:GPU:0")))
+ self.assertEquals(3.0, self.evaluate(mirrored_var.get("/device:CPU:0")))
@test_util.run_in_graph_and_eager_modes(config=config)
def testAssignSubMirroredVarTowerContext(self):
@@ -1244,5 +1356,40 @@ class MirroredStrategyDefunTest(test.TestCase):
self._call_and_check(fn1, [factors], expected_result, [fn1])
+class MultiWorkerMirroredStrategyTest(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ def _get_distribution_strategy(self):
+ cluster_spec = server_lib.ClusterSpec({
+ "worker": ["/job:worker/task:0", "/job:worker/task:1"]
+ })
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(cluster_spec=cluster_spec)
+ return strategy
+
+ def testMinimizeLossGraph(self):
+ self._test_minimize_loss_graph(self._get_distribution_strategy(),
+ learning_rate=0.05)
+
+
+class MultiWorkerMirroredStrategyTestWithChief(
+ multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ @classmethod
+ def setUpClass(cls):
+ """Create a local cluster with 2 workers and 1 chief."""
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=2, num_ps=0, has_chief=True)
+ cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
+
+ def testMinimizeLossGraph(self):
+ strategy = mirrored_strategy.MirroredStrategy(
+ num_gpus_per_worker=context.num_gpus())
+ strategy.configure(cluster_spec=self._cluster_spec)
+ self._test_minimize_loss_graph(strategy, learning_rate=0.05)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
index 55d59adc07..969e126956 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_test.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distribute.python import mirrored_strategy
-from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import test
@@ -28,7 +27,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import distribution_strategy_context
-from tensorflow.python.training import server_lib
class MirroredOneCPUDistributionTest(strategy_test_lib.DistributionTestBase):
@@ -64,6 +62,7 @@ class VariableCreatorStackTest(test.TestCase):
def model_fn(device_id):
assert isinstance(device_id, int)
+
def thread_creator_fn(next_creator, *args, **kwargs):
return next_creator(*args, **kwargs) + ":thread_" + str(device_id)
@@ -90,32 +89,20 @@ class VariableCreatorStackTest(test.TestCase):
self.assertEquals(expected, result)
-class MultiWorkerMirroredStrategyTest(
- multi_worker_test_base.MultiWorkerTestBase,
- strategy_test_lib.DistributionTestBase):
-
- def _get_distribution_strategy(self):
- return mirrored_strategy.MirroredStrategy(
- cluster_spec=server_lib.ClusterSpec({
- 'worker': ['/job:worker/task:0', '/job:worker/task:1']
- }),
- num_gpus=context.num_gpus())
-
- def testMinimizeLossGraph(self):
- self._test_minimize_loss_graph(self._get_distribution_strategy())
+class MultiWorkerMirroredStrategyTest(test.TestCase):
def testDeviceScope(self):
"""Test the device scope of multi-worker MirroredStrategy."""
with context.graph_mode():
- strategy = mirrored_strategy.MirroredStrategy(
- cluster_spec={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
- num_gpus=context.num_gpus())
+ strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy.configure(
+ cluster_spec={"worker": ["/job:worker/task:0", "/job:worker/task:1"]})
with strategy.scope():
a = constant_op.constant(1.)
- with ops.device('/cpu:0'):
+ with ops.device("/cpu:0"):
b = constant_op.constant(1.)
- self.assertEqual(a.device, '/job:worker/task:0')
- self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
+ self.assertEqual(a.device, "/job:worker/task:0")
+ self.assertEqual(b.device, "/job:worker/task:0/device:CPU:0")
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distribute/python/monitor_test.py b/tensorflow/contrib/distribute/python/monitor_test.py
index 2892ce4394..16be839e1d 100644
--- a/tensorflow/contrib/distribute/python/monitor_test.py
+++ b/tensorflow/contrib/distribute/python/monitor_test.py
@@ -45,7 +45,7 @@ class MonitorTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
monitor = monitor_lib.Monitor(single_loss_step, None)
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
monitor = monitor_lib.Monitor(single_loss_step, sess)
monitor.run_steps(1)
diff --git a/tensorflow/contrib/distribute/python/multi_worker_test_base.py b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
index 249de01f08..18b4503eff 100644
--- a/tensorflow/contrib/distribute/python/multi_worker_test_base.py
+++ b/tensorflow/contrib/distribute/python/multi_worker_test_base.py
@@ -23,26 +23,105 @@ import copy
import threading
import numpy as np
+_portpicker_import_error = None
+try:
+ import portpicker # pylint: disable=g-import-not-at-top
+except ImportError as _error: # pylint: disable=invalid-name
+ _portpicker_import_error = _error
+ portpicker = None
+
+# pylint: disable=g-import-not-at-top
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.estimator import run_config
from tensorflow.python.platform import test
-from tensorflow.python.framework import test_util
-
-
-def create_in_process_cluster(num_workers, num_ps):
+from tensorflow.python.training import server_lib
+
+
+def _create_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False,
+ protocol='grpc',
+ worker_config=None,
+ ps_config=None):
+ """Creates and starts local servers and returns the cluster_spec dict."""
+ if _portpicker_import_error:
+ raise _portpicker_import_error # pylint: disable=raising-bad-type
+ worker_ports = [portpicker.pick_unused_port() for _ in range(num_workers)]
+ ps_ports = [portpicker.pick_unused_port() for _ in range(num_ps)]
+
+ cluster_dict = {}
+ if num_workers > 0:
+ cluster_dict['worker'] = ['localhost:%s' % port for port in worker_ports]
+ if num_ps > 0:
+ cluster_dict['ps'] = ['localhost:%s' % port for port in ps_ports]
+ if has_eval:
+ cluster_dict['evaluator'] = ['localhost:%s' % portpicker.pick_unused_port()]
+ if has_chief:
+ cluster_dict['chief'] = ['localhost:%s' % portpicker.pick_unused_port()]
+
+ cs = server_lib.ClusterSpec(cluster_dict)
+
+ for i in range(num_workers):
+ server_lib.Server(
+ cs,
+ job_name='worker',
+ protocol=protocol,
+ task_index=i,
+ config=worker_config,
+ start=True)
+
+ for i in range(num_ps):
+ server_lib.Server(
+ cs,
+ job_name='ps',
+ protocol=protocol,
+ task_index=i,
+ config=ps_config,
+ start=True)
+
+ if has_chief:
+ server_lib.Server(
+ cs,
+ job_name='chief',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ if has_eval:
+ server_lib.Server(
+ cs,
+ job_name='evaluator',
+ protocol=protocol,
+ task_index=0,
+ config=worker_config,
+ start=True)
+
+ return cluster_dict
+
+
+def create_in_process_cluster(num_workers,
+ num_ps,
+ has_chief=False,
+ has_eval=False):
"""Create an in-process cluster that consists of only standard server."""
# Leave some memory for cuda runtime.
- gpu_mem_frac = 0.7 / num_workers
+ gpu_mem_frac = 0.7 / (num_workers + int(has_chief) + int(has_eval))
worker_config = config_pb2.ConfigProto()
worker_config.gpu_options.per_process_gpu_memory_fraction = gpu_mem_frac
# Enable collective ops which has no impact on non-collective ops.
# TODO(yuefengz, tucker): removing this after we move the initialization of
# collective mgr to the session level.
- worker_config.experimental.collective_group_leader = (
- '/job:worker/replica:0/task:0')
+ if has_chief:
+ worker_config.experimental.collective_group_leader = (
+ '/job:chief/replica:0/task:0')
+ else:
+ worker_config.experimental.collective_group_leader = (
+ '/job:worker/replica:0/task:0')
ps_config = config_pb2.ConfigProto()
ps_config.device_count['GPU'] = 0
@@ -56,9 +135,10 @@ def create_in_process_cluster(num_workers, num_ps):
# 2) there is something global in CUDA such that if we initialize CUDA in the
# parent process, the child process cannot initialize it again and thus cannot
# use GPUs (https://stackoverflow.com/questions/22950047).
- return test_util.create_local_cluster(
+ return _create_cluster(
num_workers,
num_ps=num_ps,
+ has_chief=has_chief,
worker_config=worker_config,
ps_config=ps_config,
protocol='grpc')
@@ -70,7 +150,8 @@ class MultiWorkerTestBase(test.TestCase):
@classmethod
def setUpClass(cls):
"""Create a local cluster with 2 workers."""
- cls._workers, cls._ps = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._cluster_spec = create_in_process_cluster(num_workers=2, num_ps=0)
+ cls._default_target = 'grpc://' + cls._cluster_spec['worker'][0]
def setUp(self):
# We only cache the session in one test because another test may have a
@@ -111,17 +192,17 @@ class MultiWorkerTestBase(test.TestCase):
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
+ if target is None:
+ target = self._default_target
if graph is None:
if getattr(self._thread_local, 'cached_session', None) is None:
self._thread_local.cached_session = session.Session(
- graph=None, config=config, target=target or self._workers[0].target)
+ graph=None, config=config, target=target)
sess = self._thread_local.cached_session
with sess.graph.as_default(), sess.as_default():
yield sess
else:
- with session.Session(
- graph=graph, config=config, target=target or
- self._workers[0].target) as sess:
+ with session.Session(graph=graph, config=config, target=target) as sess:
yield sess
def _run_client(self, client_fn, task_type, task_id, num_gpus, *args,
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 68561b5bbf..23b220f64b 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -67,6 +67,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
self._prefetch_on_device)
def _broadcast(self, tensor, destinations):
+ del destinations
return tensor
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
@@ -127,6 +128,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
return values.MapOutput([fn(m, *args, **kwargs) for m in map_over])
def _reduce(self, aggregation, value, destinations):
+ del destinations
if not isinstance(value, values.MapOutput):
return value
l = value.get()
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index a2d736e422..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -51,7 +51,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, iterator.get_next(), run_concurrently=layer.built)))
if not context.executing_eagerly():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 96b6519bc4..1125d027f6 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -22,11 +22,13 @@ from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_
from tensorflow.contrib.distribute.python import mirrored_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import device_setter
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
@@ -81,37 +83,29 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
create conflicts of device assignment.
"""
- def __init__(self,
- num_gpus_per_worker=0,
- cluster_spec=None,
- task_type=None,
- task_id=None):
+ def __init__(self, num_gpus_per_worker=0):
"""Initializes this strategy.
Args:
- num_gpus_per_worker: number of local GPUs or GPUs per worker.
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type.
- task_id: the current task id.
+ num_gpus_per_worker: number of local GPUs or GPUs per worker, the default
+ is 0 meaning CPU only.
+
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
"""
super(ParameterServerStrategy, self).__init__()
self._num_gpus_per_worker = num_gpus_per_worker
- if cluster_spec:
- cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- self._cluster_spec = cluster_spec
+ self._initialize_local(num_gpus_per_worker)
# We typically don't need to do all-reduce in this strategy.
self._cross_tower_ops = (
cross_tower_ops_lib.ReductionToOneDeviceCrossTowerOps(
reduce_to_device=_LOCAL_CPU))
- self._initialize_devices(num_gpus_per_worker, cluster_spec, task_type,
- task_id)
-
- def _initialize_devices(self, num_gpus_per_worker, cluster_spec, task_type,
- task_id):
- """Initialize internal devices.
+ def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
+ task_type, task_id):
+ """Initialize devices for multiple workers.
It creates variable devices and compute devices. Variables and operations
will be assigned to them respectively. We have one compute device per tower.
@@ -129,85 +123,103 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
Raises:
ValueError: if the cluster_spec doesn't have ps jobs.
"""
- self._task_type = task_type or "worker"
- self._task_id = task_id or 0
- self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
+ assert cluster_spec
+ if not task_type or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id`")
+ cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- # TODO(yuefengz): maybe clearer to split it into two classes, one for
- # the distribuetd case and one for the local case, once we have the factory
- # class/method.
+ self._worker_device = "/job:%s/task:%d" % (self._task_type, self._task_id)
# Define compute devices which is a list of device strings and one for each
# tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
# place operations on CPU.
- if cluster_spec is None:
- # Local mode.
- if num_gpus_per_worker > 0:
- self._compute_devices = list(
- map("/device:GPU:{}".format, range(num_gpus_per_worker)))
- else:
- self._compute_devices = [_LOCAL_CPU]
+ if num_gpus_per_worker > 0:
+ self._compute_devices = [
+ "%s/device:GPU:%d" % (self._worker_device, i)
+ for i in range(num_gpus_per_worker)
+ ]
else:
- # Distributed mode.
- if num_gpus_per_worker > 0:
- self._compute_devices = [
- "%s/device:GPU:%d" % (self._worker_device, i)
- for i in range(num_gpus_per_worker)
- ]
- else:
- self._compute_devices = [self._worker_device]
+ self._compute_devices = [self._worker_device]
self._compute_devices = list(
map(device_util.resolve, self._compute_devices))
self._canonical_compute_device_set = set(self._compute_devices)
- # Define variable device which is a device string in the local case and a
- # device function in the distributed case. It is used to open a device scope
- # where varibles are defined.
+ # In distributed mode, place variables on ps jobs in a round-robin fashion.
+ # Note that devices returned from `replica_device_setter` are not
+ # canonical and therefore we don't canonicalize all variable devices to
+ # make them consistent.
+ # TODO(yuefengz): support passing a strategy object to control variable
+ # assignment.
+ # TODO(yuefengz): merge the logic of replica_device_setter into this
+ # class.
+ num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
+ if num_ps_replicas == 0:
+ raise ValueError("The cluster spec needs to have `ps` jobs.")
+ self._variable_device = device_setter.replica_device_setter(
+ ps_tasks=num_ps_replicas,
+ worker_device=self._worker_device,
+ merge_devices=True,
+ cluster=cluster_spec)
+
# The `_parameter_devices` is needed for the `parameter_devices` property
- # and is a list of all variable devices.
- if cluster_spec is None:
- # Local mode. If there is only one GPU, put everything on that GPU.
- # Otherwise, place variables on CPU.
- if num_gpus_per_worker == 1:
- assert len(list(self._compute_devices)) == 1
- self._variable_device = _LOCAL_GPU_0
- self._parameter_devices = [_LOCAL_GPU_0]
- else:
- self._variable_device = _LOCAL_CPU
- self._parameter_devices = [_LOCAL_CPU]
+ # and is a list of all variable devices. Here parameter devices are all
+ # tasks of the "ps" job.
+ self._parameter_devices = map("/job:ps/task:{}".format,
+ range(num_ps_replicas))
+
+ # Add a default device so that ops without specified devices will not end up
+ # on other workers.
+ self._default_device = self._worker_device
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
+ self._cluster_spec = cluster_spec
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker ParameterServerStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_ps_replicas = %r, is_chief = %r, compute_devices = %r, "
+ "variable_device = %r", cluster_spec.as_dict(), task_type, task_id,
+ num_ps_replicas, self._is_chief, self._compute_devices,
+ self._variable_device)
+
+ def _initialize_local(self, num_gpus_per_worker):
+ """Initialize internal devices for local training."""
+ # Define compute devices which is a list of device strings and one for each
+ # tower. When there are GPUs, replicate operations on these GPUs. Otherwise,
+ # place operations on CPU.
+ if num_gpus_per_worker > 0:
+ self._compute_devices = list(
+ map("/device:GPU:{}".format, range(num_gpus_per_worker)))
else:
- # Distributed mode. Place variables on ps jobs in a round-robin fashion.
- # Note that devices returned from `replica_device_setter` are not
- # canonical and therefore we don't canonicalize all variable devices to
- # make them consistent.
- # TODO(yuefengz): support passing a strategy object to control variable
- # assignment.
- # TODO(yuefengz): merge the logic of replica_device_setter into this
- # class.
- num_ps_replicas = len(cluster_spec.as_dict().get("ps", []))
- if num_ps_replicas == 0:
- raise ValueError("The cluster spec needs to have `ps` jobs.")
- self._variable_device = device_setter.replica_device_setter(
- ps_tasks=num_ps_replicas,
- worker_device=self._worker_device,
- merge_devices=True,
- cluster=cluster_spec)
-
- # Parameter devices are all tasks of the "ps" job.
- self._parameter_devices = map("/job:ps/task:{}".format,
- range(num_ps_replicas))
-
- # Define the default device in cross-tower mode. In the distributed case, we
- # set the default device to the corresponding worker to prevent these ops
- # from being placed on other workers.
- if cluster_spec is None:
- self._default_device = None
+ self._compute_devices = [_LOCAL_CPU]
+
+ self._compute_devices = list(
+ map(device_util.resolve, self._compute_devices))
+ self._canonical_compute_device_set = set(self._compute_devices)
+
+ # If there is only one GPU, put everything on that GPU. Otherwise, place
+ # variables on CPU.
+ if num_gpus_per_worker == 1:
+ assert len(list(self._compute_devices)) == 1
+ self._variable_device = _LOCAL_GPU_0
+ self._parameter_devices = [_LOCAL_GPU_0]
else:
- self._default_device = self._worker_device
+ self._variable_device = _LOCAL_CPU
+ self._parameter_devices = [_LOCAL_CPU]
- self._is_chief = cluster_spec is None or multi_worker_util.is_chief(
- cluster_spec, task_type, task_id)
+ self._is_chief = True
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info(
+ "ParameterServerStrategy with compute_devices = %r, "
+ "variable_device = %r", self._compute_devices, self._variable_device)
def distribute_dataset(self, dataset_fn):
"""Distributes the dataset to each local GPU."""
@@ -227,14 +239,42 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
if aggregation not in (
vs.VariableAggregation.NONE,
vs.VariableAggregation.SUM,
- vs.VariableAggregation.MEAN
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER
):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
def var_creator(*args, **kwargs):
+ # Record what collections this variable should be added to.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # Create and wrap the variable.
v = next_creator(*args, **kwargs)
- return values.AggregatingVariable(v, aggregation)
+ wrapped = values.AggregatingVariable(v, aggregation)
+
+ # Add the wrapped variable to the requested collections.
+ # The handling of eager mode and the global step matches
+ # ResourceVariable._init_from_args().
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the contained
+ # variable to the TRAINABLE_VARIABLES collection, so we manually
+ # remove it and replace with the wrapper. We can't set "trainable"
+ # to False for next_creator() since that causes functions like
+ # implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l.remove(v)
+ g.add_to_collections(collections, wrapped)
+ elif ops.GraphKeys.GLOBAL_STEP in collections:
+ ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)
+
+ return wrapped
else:
var_creator = next_creator
@@ -267,10 +307,15 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
# pylint: disable=protected-access
return mirrored_strategy._reduce_non_distributed_value(
self, aggregation, value, destinations)
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self.broadcast(value.get(self._compute_devices[0]), destinations)
return self._cross_tower_ops.reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._compute_devices[0]), d)
+ for v, d in value_destination_pairs]
for _, destinations in value_destination_pairs:
self._verify_destinations_not_different_worker(destinations)
return self._cross_tower_ops.batch_reduce(aggregation,
@@ -345,16 +390,40 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
cluster configurations.
task_type: the current task type.
task_id: the current task id.
- """
- del session_config
- # Set the devices if cluster_spec is defined in TF_CONFIG but not passed in
- # the constructor.
+ Raises:
+ ValueError: if `cluster_spec` is given but `task_type` or `task_id` is
+ not.
+ """
if not self._cluster_spec and cluster_spec:
+ # If a `cluster_spec` is already passed in, do nothing here.
+ # TODO(yuefengz): check `cluster_spec` is the same if this object has
+ # already been initialized with a `cluster_spec`.
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, must also specify "
+ "`task_type` and `task_id`.")
self._cluster_spec = multi_worker_util.normalize_cluster_spec(
cluster_spec)
- self._initialize_devices(self._num_gpus_per_worker, self._cluster_spec,
- task_type, task_id)
+ self._task_type = task_type
+ self._task_id = task_id
+ self._initialize_multi_worker(self._num_gpus_per_worker,
+ self._cluster_spec, task_type, task_id)
+
+ if not session_config or not self._cluster_spec:
+ return
+
+ session_config.isolate_session_state = False
+
+ assert self._cluster_spec
+ assert self._task_type
+ assert self._task_id is not None
+
+ # The device filters prevent communication between workers.
+ if self._task_type not in ["chief", "worker"]:
+ return
+ del session_config.device_filters[:]
+ session_config.device_filters.extend(
+ ["/job:%s/task:%d" % (self._task_type, self._task_id), "/job:ps"])
@property
def num_towers(self):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index adfe3e8b02..12789e0bc9 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -18,12 +18,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import copy
import threading
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import multi_worker_test_base
from tensorflow.contrib.distribute.python import parameter_server_strategy
+from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
from tensorflow.python.estimator import run_config
from tensorflow.python.framework import constant_op
@@ -37,21 +41,15 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import training_util
+CHIEF = run_config.TaskType.CHIEF
+WORKER = run_config.TaskType.WORKER
+PS = run_config.TaskType.PS
-class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
- parameterized.TestCase):
- @classmethod
- def setUpClass(cls):
- cls._workers, cls._ps = multi_worker_test_base.create_in_process_cluster(
- num_workers=3, num_ps=2)
- cls._cluster_spec = {
- run_config.TaskType.WORKER: [
- 'fake_worker_0', 'fake_worker_1', 'fake_worker_2'
- ],
- run_config.TaskType.PS: ['fake_ps_0', 'fake_ps_1']
- }
+class ParameterServerStrategyTestBase(
+ multi_worker_test_base.MultiWorkerTestBase):
def setUp(self):
self._result = 0
@@ -60,23 +58,30 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._init_reached = 0
self._finish_condition = threading.Condition()
self._finish_reached = 0
- super(ParameterServerStrategyTest, self).setUp()
+ self._sess_config = config_pb2.ConfigProto(allow_soft_placement=True)
+ super(ParameterServerStrategyTestBase, self).setUp()
def _get_test_objects(self, task_type, task_id, num_gpus):
distribution = parameter_server_strategy.ParameterServerStrategy(
num_gpus_per_worker=num_gpus)
if not task_type:
- return distribution, ''
+ return distribution, '', self._sess_config
+ sess_config = copy.deepcopy(self._sess_config)
distribution.configure(
- cluster_spec=self._cluster_spec, task_type=task_type, task_id=task_id)
- return distribution, self._workers[task_id].target
+ session_config=sess_config,
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ return (distribution, 'grpc://' + self._cluster_spec[WORKER][task_id],
+ sess_config)
def _test_device_assignment_distributed(self, task_type, task_id, num_gpus):
worker_device = '/job:%s/replica:0/task:%d' % (task_type, task_id)
- d, _ = self._get_test_objects(task_type, task_id, num_gpus)
+ d, _, sess_config = self._get_test_objects(task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target,
+ config=sess_config) as sess, \
d.scope():
# Define a variable outside the call_for_each_tower scope. This is not
@@ -172,18 +177,14 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- @combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
- def testDeviceAssignmentDistributed(self, num_gpus):
- self._test_device_assignment_distributed('worker', 1, num_gpus)
-
def _test_device_assignment_local(self,
d,
compute_device='CPU',
variable_device='CPU',
num_gpus=0):
with ops.Graph().as_default(), \
- self.test_session(target=self._workers[0].target) as sess, \
+ self.test_session(target=self._default_target,
+ config=self._sess_config) as sess, \
d.scope():
def model_fn():
@@ -276,33 +277,18 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertEqual(z_val, 43.0)
self.assertEqual(f_val, 46.0)
- def testDeviceAssignmentLocalCPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=0)
- self._test_device_assignment_local(
- distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
-
- def testDeviceAssignmentLocalOneGPU(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=1)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
-
- def testDeviceAssignmentLocalTwoGPUs(self):
- distribution = parameter_server_strategy.ParameterServerStrategy(
- num_gpus_per_worker=2)
- self._test_device_assignment_local(
- distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
-
def _test_simple_increment(self, task_type, task_id, num_gpus):
- d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ d, master_target, sess_config = self._get_test_objects(
+ task_type, task_id, num_gpus)
if hasattr(d, '_cluster_spec') and d._cluster_spec:
- num_workers = len(d._cluster_spec.as_dict().get('worker',
- ['dummy_worker']))
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if 'chief' in d._cluster_spec.as_dict():
+ num_workers += 1
else:
num_workers = 1
with ops.Graph().as_default(), \
- self.test_session(target=master_target) as sess, \
+ self.test_session(target=master_target,
+ config=sess_config) as sess, \
d.scope():
def model_fn():
@@ -312,18 +298,22 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
y = variable_scope.get_variable(
'y', initializer=20.0,
aggregation=variable_scope.VariableAggregation.SUM)
+ z = variable_scope.get_variable(
+ 'z', initializer=30.0,
+ aggregation=variable_scope.VariableAggregation.ONLY_FIRST_TOWER)
# We explicitly make a constant tensor here to avoid complaints about
# summing non-distributed values.
one = constant_op.constant(1.0)
x_add = x.assign_add(one, use_locking=True)
y_add = y.assign_add(one, use_locking=True)
+ z_add = z.assign_add(one, use_locking=True)
- train_op = control_flow_ops.group([x_add, y_add])
- return x, y, train_op
+ train_op = control_flow_ops.group(x_add, y_add, z_add)
+ return x, y, z, train_op
- x, y, train_op = d.call_for_each_tower(model_fn)
- train_op = d.group(d.unwrap(train_op))
+ x, y, z, train_op = d.call_for_each_tower(model_fn)
+ train_op = d.group(train_op)
if context.num_gpus() < d._num_gpus_per_worker:
return True
@@ -349,16 +339,25 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._finish_condition.notify_all()
self._finish_condition.release()
- x_val, y_val = sess.run([x, y])
+ x_val, y_val, z_val = sess.run([x, y, z])
self.assertEqual(x_val, 10.0 + 1.0 * num_workers * d.num_towers)
self.assertEqual(y_val, 20.0 + 1.0 * num_workers * d.num_towers)
+ self.assertEqual(z_val, 30.0 + 1.0 * num_workers)
return (x_val == 10.0 + 1.0 * num_workers * d.num_towers and
- y_val == 20.0 + 1.0 * num_workers * d.num_towers)
+ y_val == 20.0 + 1.0 * num_workers * d.num_towers and
+ z_val == 30.0 + 1.0 * num_workers)
def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
- d, master_target = self._get_test_objects(task_type, task_id, num_gpus)
+ d, master_target, sess_config = self._get_test_objects(
+ task_type, task_id, num_gpus)
+ assert hasattr(d, '_cluster_spec') and d._cluster_spec
+ num_workers = len(d._cluster_spec.as_dict().get(WORKER))
+ if CHIEF in d._cluster_spec.as_dict():
+ num_workers += 1
+
with ops.Graph().as_default(), \
- self.test_session(target=master_target) as sess, \
+ self.test_session(target=master_target,
+ config=sess_config) as sess, \
d.scope():
l = core.Dense(1, use_bias=False)
@@ -405,13 +404,13 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
if context.num_gpus() < d._num_gpus_per_worker:
return True
- if task_id == 0:
+ if multi_worker_util.is_chief(d._cluster_spec, task_type, task_id):
variables.global_variables_initializer().run()
# Workers waiting for chief worker's initializing variables.
self._init_condition.acquire()
self._init_reached += 1
- while self._init_reached != 3:
+ while self._init_reached != num_workers:
self._init_condition.wait()
self._init_condition.notify_all()
self._init_condition.release()
@@ -428,9 +427,42 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self.assertLess(error_after, error_before)
return error_after < error_before
+
+class ParameterServerStrategyTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2)
+ cls._default_target = 'grpc://' + cls._cluster_spec[WORKER][0]
+
+ def testDeviceAssignmentLocalCPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=0)
+ self._test_device_assignment_local(
+ distribution, compute_device='CPU', variable_device='CPU', num_gpus=0)
+
+ def testDeviceAssignmentLocalOneGPU(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=1)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='GPU', num_gpus=1)
+
+ def testDeviceAssignmentLocalTwoGPUs(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ self._test_device_assignment_local(
+ distribution, compute_device='GPU', variable_device='CPU', num_gpus=2)
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testDeviceAssignmentDistributed(self, num_gpus):
+ self._test_device_assignment_distributed('worker', 1, num_gpus)
+
def testSimpleBetweenGraph(self):
self._run_between_graph_clients(self._test_simple_increment,
- self._cluster_spec, 0)
+ self._cluster_spec, context.num_gpus())
@combinations.generate(
combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
@@ -444,5 +476,38 @@ class ParameterServerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
self._cluster_spec, num_gpus)
+class ParameterServerStrategyWithChiefTest(ParameterServerStrategyTestBase,
+ parameterized.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
+ num_workers=3, num_ps=2, has_chief=True)
+ cls._default_target = 'grpc://' + cls._cluster_spec[CHIEF][0]
+
+ def testSimpleBetweenGraph(self):
+ self._run_between_graph_clients(self._test_simple_increment,
+ self._cluster_spec, context.num_gpus())
+
+ @combinations.generate(
+ combinations.combine(mode=['graph'], num_gpus=[0, 1, 2]))
+ def testMinimizeLossGraph(self, num_gpus):
+ self._run_between_graph_clients(self._test_minimize_loss_graph,
+ self._cluster_spec, num_gpus)
+
+ def testGlobalStepIsWrapped(self):
+ distribution = parameter_server_strategy.ParameterServerStrategy(
+ num_gpus_per_worker=2)
+ with ops.Graph().as_default(), distribution.scope():
+ created_step = training_util.create_global_step()
+ get_step = training_util.get_global_step()
+ self.assertEqual(created_step, get_step,
+ msg=('created_step %s type %s vs. get_step %s type %s' %
+ (id(created_step), created_step.__class__.__name__,
+ id(get_step), get_step.__class__.__name__)))
+ self.assertIs(values.AggregatingVariable, type(created_step))
+ self.assertIs(values.AggregatingVariable, type(get_step))
+
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
index a68dbce6c7..bb10b546a1 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -37,7 +37,7 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_one_shot_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for i in range(10):
self.assertEqual(i, sess.run(next_element))
with self.assertRaises(errors.OutOfRangeError):
@@ -55,7 +55,7 @@ class PrefetchingOpsV2Test(test.TestCase):
next_element = iterator.get_next()
output = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for _ in range(5):
result = sess.run(next_element)
self.assertEqual(2, len(result))
@@ -75,7 +75,7 @@ class PrefetchingOpsV2Test(test.TestCase):
iterator = device_dataset.make_initializable_iterator()
next_element = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(iterator.initializer)
for _ in range(5):
sess.run(next_element)
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 8605ab1f7d..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -49,7 +49,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
if context.executing_eagerly():
run_step = single_loss_step
else:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 371b97ba96..5d498fb629 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -130,7 +130,8 @@ class DistributionTestBase(test.TestCase):
# Error should go down
self.assertLess(error_after, error_before)
- def _test_minimize_loss_graph(self, d, soft_placement=False):
+ def _test_minimize_loss_graph(self, d, soft_placement=False,
+ learning_rate=0.2):
config = config_pb2.ConfigProto()
config.allow_soft_placement = soft_placement
config.gpu_options.per_process_gpu_memory_fraction = 0.3
@@ -150,7 +151,7 @@ class DistributionTestBase(test.TestCase):
grad_fn = backprop.implicit_grad(loss)
def update(v, g):
- return v.assign_sub(0.2 * g)
+ return v.assign_sub(learning_rate * g)
one = d.broadcast(constant_op.constant([[1.]]))
@@ -189,7 +190,8 @@ class DistributionTestBase(test.TestCase):
with d.scope():
map_in = [constant_op.constant(i) for i in range(10)]
map_out = d.map(map_in, lambda x, y: x * y, 2)
- observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out)
+ observed = d.reduce(variable_scope.VariableAggregation.SUM, map_out,
+ "/device:CPU:0")
expected = 90 # 2 * (0 + 1 + ... + 9)
self.assertEqual(expected, observed.numpy())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a486003076..6ba83976fc 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
- def __init__(self, tpu_cluster_resolver, steps_per_run):
+ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None):
"""Initializes the TPUStrategy object.
Args:
@@ -70,68 +70,101 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
metrics, summaries etc.
This parameter is only used when Distribution Strategy is used with
estimator or keras.
+ num_cores: Number of cores to use on the TPU. If None specified, then
+ auto-detect the cores and topology of the TPU system.
"""
- # TODO(isaprykin): Generalize the defaults. They are currently tailored for
- # the unit test.
+ # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the
+ # master node fetched from the cluster resolver.
super(TPUStrategy, self).__init__('/device:CPU:0')
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+ # TODO(sourabhbajaj): Change this from num_cores to metadata_override
+ self._num_cores_override = num_cores
- # TODO(priyag): This should not be hardcoded here.
- self._host = '/device:CPU:0'
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
- def distribute_dataset(self, dataset_fn):
- # TODO(priyag): Perhaps distribute across cores here.
- return self._call_dataset_fn(dataset_fn)
+ def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes,
+ iterations):
+ """Create an enqueue op for a single host identified using host_id.
- # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
- # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
- # a mechanism to infer the outputs of `fn`. Pending b/110550782.
- def _run_steps_on_dataset(self, fn, iterator, iterations,
- initial_loop_values=None):
+ The while_loop op returned will run `iterations` times and in each run
+ enqueue batches for each shard.
- shapes = nest.flatten(iterator.output_shapes)
- if any([not s.is_fully_defined() for s in shapes]):
- raise ValueError(
- 'TPU currently requires fully defined shapes. Either use '
- 'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
- types = nest.flatten(iterator.output_types)
+ Args:
+ host_id: integer, id of the host to run the enqueue ops on.
+ iterator: `tf.data` iterator to read the input data.
+ input_shapes: shape of inputs to be enqueue on the queue. This is same as
+ the value of `nest.flatten(iterator.output_shapes)`.
+ iterations: integer, number of iterations to be run; determines the
+ number of batches to be enqueued.
+
+ Returns:
+ while_loop_op running `iterations` times; in each run we enqueue a batch
+ on the infeed queue from the host with id `host_id` for each device shard.
+ """
+ host = self.get_host_cpu_device(host_id)
- def enqueue_ops_fn():
+ def _infeed_enqueue_ops_fn():
"""Enqueue ops for one iteration."""
control_deps = []
sharded_inputs = []
- with ops.device(self._host):
- for _ in range(self.num_towers):
+ enqueue_ops = []
+
+ with ops.device(host):
+ for _ in range(self.num_towers_per_host):
# Use control dependencies to ensure a deterministic ordering.
with ops.control_dependencies(control_deps):
inputs = nest.flatten(iterator.get_next())
control_deps.extend(inputs)
sharded_inputs.append(inputs)
- enqueue_ops = []
for core_id, shard_input in enumerate(sharded_inputs):
enqueue_ops.append(
tpu_ops.infeed_enqueue_tuple(
- inputs=shard_input, shapes=shapes, device_ordinal=core_id))
+ inputs=shard_input,
+ shapes=input_shapes,
+ device_ordinal=core_id))
return enqueue_ops
def enqueue_ops_loop_body(i):
- with ops.control_dependencies(enqueue_ops_fn()):
+ """Callable for the loop body of the while_loop instantiated below."""
+ with ops.control_dependencies(_infeed_enqueue_ops_fn()):
return i + 1
- with ops.device(self._host):
- enqueue_ops = control_flow_ops.while_loop(
+ with ops.device(host):
+ enqueue_op_per_host = control_flow_ops.while_loop(
lambda i: i < iterations,
enqueue_ops_loop_body,
[constant_op.constant(0)],
parallel_iterations=1)
+ return enqueue_op_per_host
+
+ def distribute_dataset(self, dataset_fn):
+ # TODO(priyag): Perhaps distribute across cores here.
+ return self._call_dataset_fn(dataset_fn)
+
+ # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
+ # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
+ # a mechanism to infer the outputs of `fn`. Pending b/110550782.
+ def _run_steps_on_dataset(self, fn, iterator, iterations,
+ initial_loop_values=None):
+
+ shapes = nest.flatten(iterator.output_shapes)
+ if any([not s.is_fully_defined() for s in shapes]):
+ raise ValueError(
+ 'TPU currently requires fully defined shapes. Either use '
+ 'set_shape() on the input tensors or use '
+ 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ types = nest.flatten(iterator.output_types)
+
+ enqueue_ops = [
+ self._get_enqueue_op_per_host(host_id, iterator, shapes, iterations)
+ for host_id in range(self.num_hosts)]
+
def dequeue_fn():
dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
return nest.pack_sequence_as(iterator.output_shapes, dequeued)
@@ -142,6 +175,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
initial_loop_values = nest.flatten(initial_loop_values)
ctx = values.MultiStepContext()
def run_fn(*args, **kwargs):
+ """Single step on the TPU device."""
del args, kwargs
fn_inputs = dequeue_fn()
if not isinstance(fn_inputs, tuple):
@@ -233,6 +267,9 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
if aggregation == vs.VariableAggregation.MEAN:
# TODO(jhseu): Revisit once we support model-parallelism.
value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ 'Currently only support sum & mean in TPUStrategy.')
return tpu_ops.cross_replica_sum(value)
cf_context = cf_context.outer_context
@@ -242,10 +279,12 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
- self._host)
+ self.get_host_cpu_device(0))
else:
raise ValueError('Multiple devices are not supported for TPUStrategy')
+ if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return value[0]
output = math_ops.add_n(value)
if aggregation == vs.VariableAggregation.MEAN:
return output * (1. / len(value))
@@ -258,4 +297,31 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
@property
def num_towers(self):
+ return self._num_cores_override or self._tpu_metadata.num_cores
+
+ @property
+ def num_hosts(self):
+ return self._tpu_metadata.num_hosts
+
+ @property
+ def num_towers_per_host(self):
return self._tpu_metadata.num_of_cores_per_host
+
+ def get_host_cpu_device(self, host_id):
+ if self._tpu_cluster_resolver.get_master() in ('', 'local'):
+ return '/replica:0/task:0/device:CPU:0'
+ job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker'
+ return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)
+
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del cluster_spec, task_type, task_id
+ if session_config:
+ session_config.isolate_session_state = True
+ cluster_spec = self._tpu_cluster_resolver.cluster_spec()
+ if cluster_spec:
+ session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index a58bb3a849..fafa6384a1 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -183,6 +183,14 @@ class Mirrored(DistributedDelegate):
return self._index[device]
return list(self._index.values())[0]
+ def _as_graph_element(self):
+ obj = self.get()
+ # pylint: disable=protected-access
+ conv_fn = getattr(obj, "_as_graph_element", None)
+ if conv_fn and callable(conv_fn):
+ return conv_fn()
+ return obj
+
def _assign_on_device(device, variable, tensor):
with ops.device(device):
@@ -296,6 +304,10 @@ class DistributedVariable(DistributedDelegate):
self._primary_var.op.type)
return self.get().op
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
def read_value(self):
return distribution_strategy_context.get_distribution_strategy().read_var(
self)
@@ -328,10 +340,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""Holds a map from device to variables whose values are kept in sync."""
def __init__(self, index, primary_var, aggregation):
- # Use a weakref to make it easy to map from the contained values
- # to the container without introducing a reference cycle.
- for v in six.itervalues(index):
- v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
self._aggregation = aggregation
super(MirroredVariable, self).__init__(index)
@@ -354,8 +362,19 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
- return distribution_strategy_context.get_distribution_strategy().update(
- self, f, *args, **kwargs)
+ strategy = distribution_strategy_context.get_distribution_strategy()
+ updates = strategy.update(self, f, *args, **kwargs)
+ grouped = strategy.group(updates)
+ if isinstance(updates, DistributedValues) and updates.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(mirrored_var.assign*(...)) may only update one tower.
+ index = {}
+ for d in updates.devices:
+ with ops.device(d), ops.control_dependencies([grouped]):
+ index[d] = array_ops.identity(updates.get(d))
+ return Mirrored(index)
+ else:
+ return grouped
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -500,6 +519,8 @@ class TowerLocalVariable(DistributedVariable, PerDevice,
return self._aggregation
def _get_cross_tower(self):
+ if self._aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
+ return self._primary_var
all_components = tuple(self._index.values())
# TODO(josh11b): Use a strategy-specific method.
total = math_ops.add_n(all_components)
@@ -1180,6 +1201,10 @@ class AggregatingVariable(checkpointable.CheckpointableBase):
def __repr__(self):
return repr(self._v)
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 3602f4d128..15a85a28f5 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -521,6 +521,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
return worker_device_map, devices
def testDataDistributionOneDevicePerWorker(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
@@ -528,6 +529,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testDataDistributionTwoDevicePerWorker(self):
+ self.skipTest("Temporarily disabled.")
if context.num_gpus() < 1:
self.skipTest("A GPU is not available for this test.")
worker_device_map, devices = self._cpu_and_one_gpu_devices()
@@ -537,6 +539,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 2, 1, 3], [4, 6, 5, 7]])
def testTupleDataset(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
@@ -553,6 +556,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
expected_values)
def testInitializableIterator(self):
+ self.skipTest("Temporarily disabled.")
worker_device_map, devices = self._cpu_devices()
with context.graph_mode():
dataset_fn = lambda: dataset_ops.Dataset.range(8)
@@ -570,6 +574,7 @@ class MultiWorkerDatasetTest(multi_worker_test_base.MultiWorkerTestBase):
[[0, 1], [2, 3], [4, 5], [6, 7]])
def testValueErrorForIterator(self):
+ self.skipTest("Temporarily disabled.")
# Incompatiable arguments.
with self.assertRaises(ValueError):
values.MultiWorkerDataIterator({"w1": None}, {"w1": "d1", "w2": "d2"})
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index a8d0d493ab..97c53ae2b9 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -445,7 +445,7 @@ cuda_py_test(
cuda_py_test(
name = "sinh_arcsinh_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/sinh_arcsinh_test.py"],
additional_deps = [
":distributions_py",
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
index 042c8ebd51..372b7e37b7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py
@@ -31,7 +31,7 @@ class AbsoluteValueTest(test.TestCase):
"""Tests correctness of the absolute value bijector."""
def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = AbsoluteValue(validate_args=True)
self.assertEqual("absolute_value", bijector.name)
x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3]
@@ -54,13 +54,13 @@ class AbsoluteValueTest(test.TestCase):
y, event_ndims=0)))
def testNegativeYRaisesForInverseIfValidateArgs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = AbsoluteValue(validate_args=True)
with self.assertRaisesOpError("y was negative"):
sess.run(bijector.inverse(-1.))
def testNegativeYRaisesForILDJIfValidateArgs(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = AbsoluteValue(validate_args=True)
with self.assertRaisesOpError("y was negative"):
sess.run(bijector.inverse_log_det_jacobian(-1., event_ndims=0))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
index 1e4ad724d0..a7bd51430e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class AffineLinearOperatorTest(test.TestCase):
def testIdentity(self):
- with self.test_session():
+ with self.cached_session():
affine = AffineLinearOperator(
validate_args=True)
x = np.array([[1, 0, -1], [2, 3, 4]], dtype=np.float32)
@@ -45,7 +45,7 @@ class AffineLinearOperatorTest(test.TestCase):
affine.forward_log_det_jacobian(x, event_ndims=2).eval())
def testDiag(self):
- with self.test_session():
+ with self.cached_session():
shift = np.array([-1, 0, 1], dtype=np.float32)
diag = np.array([[1, 2, 3],
[2, 5, 6]], dtype=np.float32)
@@ -67,7 +67,7 @@ class AffineLinearOperatorTest(test.TestCase):
affine.forward_log_det_jacobian(x, event_ndims=1).eval())
def testTriL(self):
- with self.test_session():
+ with self.cached_session():
shift = np.array([-1, 0, 1], dtype=np.float32)
tril = np.array([[[3, 0, 0],
[2, -1, 0],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
index d2533620be..bc6752a69d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_scalar_test.py
@@ -31,14 +31,14 @@ class AffineScalarBijectorTest(test.TestCase):
"""Tests correctness of the Y = scale @ x + shift transformation."""
def testProperties(self):
- with self.test_session():
+ with self.cached_session():
mu = -1.
# scale corresponds to 1.
bijector = AffineScalar(shift=mu)
self.assertEqual("affine_scalar", bijector.name)
def testNoBatchScalar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -60,7 +60,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testOneBatchScalarViaIdentityIn64BitUserProvidesShiftOnly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -83,7 +83,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testOneBatchScalarViaIdentityIn64BitUserProvidesScaleOnly(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -106,7 +106,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testTwoBatchScalarIdentityViaIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -129,7 +129,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testTwoBatchScalarIdentityViaScale(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def static_run(fun, x, **kwargs):
return fun(x, **kwargs).eval()
@@ -152,7 +152,7 @@ class AffineScalarBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=0))
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = AffineScalar(shift=3.6, scale=0.42)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
index 9e14b9a53e..dc18eb3df6 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py
@@ -32,14 +32,14 @@ class AffineBijectorTest(test.TestCase):
"""Tests correctness of the Y = scale @ x + shift transformation."""
def testProperties(self):
- with self.test_session():
+ with self.cached_session():
mu = -1.
# scale corresponds to 1.
bijector = Affine(shift=mu)
self.assertEqual("affine", bijector.name)
def testNoBatchMultivariateIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -71,7 +71,7 @@ class AffineBijectorTest(test.TestCase):
0., run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testNoBatchMultivariateDiag(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -114,7 +114,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testNoBatchMultivariateFullDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, name="x")
mu = array_ops.placeholder(dtypes.float32, name="mu")
scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
@@ -137,7 +137,7 @@ class AffineBijectorTest(test.TestCase):
feed_dict))
def testBatchMultivariateIdentity(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -161,7 +161,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testBatchMultivariateDiag(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -185,7 +185,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testBatchMultivariateFullDynamic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, name="x")
mu = array_ops.placeholder(dtypes.float32, name="mu")
scale_diag = array_ops.placeholder(dtypes.float32, name="scale_diag")
@@ -209,7 +209,7 @@ class AffineBijectorTest(test.TestCase):
x, event_ndims=1), feed_dict))
def testIdentityWithDiagUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -235,7 +235,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testIdentityWithTriL(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -261,7 +261,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testDiagWithTriL(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -285,7 +285,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testIdentityAndDiagWithTriL(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -312,7 +312,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector.inverse_log_det_jacobian, x, event_ndims=1))
def testIdentityWithVDVTUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -349,7 +349,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testDiagWithVDVTUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -385,7 +385,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testTriLWithVDVTUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -422,7 +422,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testTriLWithVDVTUpdateNoDiagonal(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
placeholder = array_ops.placeholder(dtypes.float32, name="x")
def static_run(fun, x, **kwargs):
@@ -459,7 +459,7 @@ class AffineBijectorTest(test.TestCase):
run(bijector_ref.inverse_log_det_jacobian, x, event_ndims=1))
def testNoBatchMultivariateRaisesWhenSingular(self):
- with self.test_session():
+ with self.cached_session():
mu = [1., -1]
bijector = Affine(
shift=mu,
@@ -531,7 +531,7 @@ class AffineBijectorTest(test.TestCase):
itertools.combinations(s, r) for r in range(len(s) + 1))
for args in _powerset(scale_params.items()):
- with self.test_session():
+ with self.cached_session():
args = dict(args)
scale_args = dict({"x": x}, **args)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
index c832fcaa68..bf61e9f2fe 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/batch_normalization_test.py
@@ -69,7 +69,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
]
for input_shape, event_dims, training in params:
x_ = np.arange(5 * 4 * 2).astype(np.float32).reshape(input_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = constant_op.constant(x_)
# When training, memorize the exact mean of the last
# minibatch that it normalized (instead of moving average assignment).
@@ -145,7 +145,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
def testMaximumLikelihoodTraining(self):
# Test Maximum Likelihood training with default bijector.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
batch_norm = BatchNormalization(training=True)
dist = transformed_distribution_lib.TransformedDistribution(
@@ -176,7 +176,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
self.assertAllClose([1., 1.], moving_var_, atol=5e-2)
def testLogProb(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
layer = normalization.BatchNormalization(epsilon=0.)
batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
base_dist = distributions.MultivariateNormalDiag(loc=[0., 0.])
@@ -196,7 +196,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
def testMutuallyConsistent(self):
# BatchNorm bijector is only mutually consistent when training=False.
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
layer = normalization.BatchNormalization(epsilon=0.)
batch_norm = BatchNormalization(batchnorm_layer=layer, training=False)
dist = transformed_distribution_lib.TransformedDistribution(
@@ -215,7 +215,7 @@ class BatchNormTest(test_util.VectorDistributionTestHelpers,
def testInvertMutuallyConsistent(self):
# BatchNorm bijector is only mutually consistent when training=False.
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
layer = normalization.BatchNormalization(epsilon=0.)
batch_norm = Invert(
BatchNormalization(batchnorm_layer=layer, training=False))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
index dc45114b1c..ada99ec9c6 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/chain_test.py
@@ -46,7 +46,7 @@ class ChainBijectorTest(test.TestCase):
"""Tests the correctness of the Y = Chain(bij1, bij2, bij3) transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain((Exp(), Softplus()))
self.assertEqual("chain_of_exp_of_softplus", chain.name)
x = np.asarray([[[1., 2.],
@@ -61,7 +61,7 @@ class ChainBijectorTest(test.TestCase):
chain.forward_log_det_jacobian(x, event_ndims=1).eval())
def testBijectorIdentity(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain()
self.assertEqual("identity", chain.name)
x = np.asarray([[[1., 2.],
@@ -74,13 +74,13 @@ class ChainBijectorTest(test.TestCase):
0., chain.forward_log_det_jacobian(x, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain((Exp(), Softplus()))
assert_scalar_congruency(
chain, lower_x=1e-3, upper_x=1.5, rtol=0.05)
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
chain = Chain([
SoftmaxCentered(validate_args=True),
SoftmaxCentered(validate_args=True),
@@ -195,7 +195,7 @@ class ChainBijectorTest(test.TestCase):
dtype=np.float32, shape=[None, 10], name="samples")
ildj = chain.inverse_log_det_jacobian(samples, event_ndims=0)
self.assertTrue(ildj is not None)
- with self.test_session():
+ with self.cached_session():
ildj.eval({samples: np.zeros([2, 10], np.float32)})
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
index d1ce273499..9681b64ced 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/cholesky_outer_product_test.py
@@ -30,7 +30,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
"""Tests the correctness of the Y = X @ X.T transformation."""
def testBijectorMatrix(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.CholeskyOuterProduct(validate_args=True)
self.assertEqual("cholesky_outer_product", bijector.name)
x = [[[1., 0], [2, 1]], [[np.sqrt(2.), 0], [np.sqrt(8.), 1]]]
@@ -75,7 +75,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
bijector = bijectors.CholeskyOuterProduct()
x_pl = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
log_det_jacobian = bijector.forward_log_det_jacobian(x_pl, event_ndims=2)
# The Jacobian matrix is 2 * tf.eye(2), which has jacobian determinant 4.
@@ -86,7 +86,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
def testNoBatchStatic(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual])
@@ -98,7 +98,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
def testNoBatchDeferred(self):
x = np.array([[1., 0], [2, 1]]) # np.linalg.cholesky(y)
y = np.array([[1., 2], [2, 5]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_pl = array_ops.placeholder(dtypes.float32)
y_pl = array_ops.placeholder(dtypes.float32)
y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
@@ -119,7 +119,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
[2, 5]],
[[9., 3],
[3, 5]]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = bijectors.CholeskyOuterProduct().forward(x=x)
x_actual = bijectors.CholeskyOuterProduct().inverse(y=y)
[y_actual_, x_actual_] = sess.run([y_actual, x_actual])
@@ -137,7 +137,7 @@ class CholeskyOuterProductBijectorTest(test.TestCase):
[2, 5]],
[[9., 3],
[3, 5]]]) # np.matmul(x, x.T)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x_pl = array_ops.placeholder(dtypes.float32)
y_pl = array_ops.placeholder(dtypes.float32)
y_actual = bijectors.CholeskyOuterProduct().forward(x=x_pl)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
index 7be939cd27..d2c00865e7 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/exp_test.py
@@ -30,7 +30,7 @@ class ExpBijectorTest(test.TestCase):
"""Tests correctness of the Y = g(X) = exp(X) transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
bijector = Exp()
self.assertEqual("exp", bijector.name)
x = [[[1.], [2.]]]
@@ -48,13 +48,13 @@ class ExpBijectorTest(test.TestCase):
x, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Exp()
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=1.5, rtol=0.05)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Exp()
x = np.linspace(-10, 10, num=10).astype(np.float32)
y = np.logspace(-10, 10, num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
index 54e54c3296..b9cdbfb823 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py
@@ -31,7 +31,7 @@ class GumbelBijectorTest(test.TestCase):
"""Tests correctness of the Gumbel bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
loc = 0.3
scale = 5.
bijector = Gumbel(loc=loc, scale=scale, validate_args=True)
@@ -52,12 +52,12 @@ class GumbelBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(
Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Gumbel(loc=0., scale=3.0, validate_args=True)
x = np.linspace(-10., 10., num=10).astype(np.float32)
y = np.linspace(0.01, 0.99, num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
index 7d3bd758cd..c9bccb36fc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/inline_test.py
@@ -32,7 +32,7 @@ class InlineBijectorTest(test.TestCase):
"""Tests correctness of the inline constructed bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
exp = Exp()
inline = Inline(
forward_fn=math_ops.exp,
@@ -55,7 +55,7 @@ class InlineBijectorTest(test.TestCase):
inline.forward_log_det_jacobian(x, event_ndims=1).eval())
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
bijector = Inline(
forward_event_shape_tensor_fn=lambda x: array_ops.concat((x, [1]), 0),
forward_event_shape_fn=lambda x: x.as_list() + [1],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
index 8b14c8327f..7e3340aeb0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/invert_test.py
@@ -31,7 +31,7 @@ class InvertBijectorTest(test.TestCase):
"""Tests the correctness of the Y = Invert(bij) transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
for fwd in [
bijectors.Identity(),
bijectors.Exp(),
@@ -53,13 +53,13 @@ class InvertBijectorTest(test.TestCase):
rev.forward_log_det_jacobian(x, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Invert(bijectors.Exp())
assert_scalar_congruency(
bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Invert(bijectors.SoftmaxCentered(validate_args=True))
x = tensor_shape.TensorShape([2])
y = tensor_shape.TensorShape([1])
@@ -73,7 +73,7 @@ class InvertBijectorTest(test.TestCase):
bijector.inverse_event_shape_tensor(y.as_list()).eval())
def testDocstringExample(self):
- with self.test_session():
+ with self.cached_session():
exp_gamma_distribution = (
transformed_distribution_lib.TransformedDistribution(
distribution=gamma_lib.Gamma(concentration=1., rate=2.),
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
index a8089881f6..b3fb50005e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py
@@ -30,7 +30,7 @@ class KumaraswamyBijectorTest(test.TestCase):
"""Tests correctness of the Kumaraswamy bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
a = 2.
b = 0.3
bijector = Kumaraswamy(
@@ -54,13 +54,13 @@ class KumaraswamyBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(
Kumaraswamy(concentration1=0.5, concentration0=1.1),
lower_x=0., upper_x=1., n=int(10e3), rtol=0.02)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
concentration1 = 1.2
concentration0 = 2.
bijector = Kumaraswamy(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
index 5ba5a2083b..ad4329d425 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py
@@ -71,7 +71,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
def testBijector(self):
x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MaskedAutoregressiveFlow(
validate_args=True,
**self._autoregressive_flow_kwargs)
@@ -102,7 +102,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
def testMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MaskedAutoregressiveFlow(
validate_args=True,
**self._autoregressive_flow_kwargs)
@@ -121,7 +121,7 @@ class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers,
def testInvertMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = Invert(MaskedAutoregressiveFlow(
validate_args=True,
**self._autoregressive_flow_kwargs))
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
index 49a9afe3f6..31ee36f024 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/matrix_inverse_tril_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
+@test_util.run_all_in_graph_and_eager_modes
class MatrixInverseTriLBijectorTest(test.TestCase):
"""Tests the correctness of the Y = inv(tril) transformation."""
@@ -40,7 +41,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
y[idx][np.triu_indices(y[idx].shape[-1], 1)] = 0
return y
- @test_util.run_in_graph_and_eager_modes
def testComputesCorrectValues(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
self.assertEqual("matrix_inverse_tril", inv.name)
@@ -62,7 +62,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testOneByOneMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[5.]], dtype=np.float32)
@@ -81,7 +80,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testZeroByZeroMatrix(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.eye(0, dtype=np.float32)
@@ -100,7 +98,6 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertNear(expected_fldj_, fldj_, err=1e-3)
self.assertNear(-expected_fldj_, ildj_, err=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testBatch(self):
# Test batch computation with input shape (2, 1, 2, 2), i.e. batch shape
# (2, 1).
@@ -125,20 +122,18 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
self.assertAllClose(expected_fldj_, fldj_, atol=0., rtol=1e-3)
self.assertAllClose(-expected_fldj_, ildj_, atol=0., rtol=1e-3)
- @test_util.run_in_graph_and_eager_modes
def testErrorOnInputRankTooLow(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([0.1], dtype=np.float32)
rank_error_msg = "must have rank at least 2"
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesWithPredicateMatch(ValueError, rank_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
# TODO(b/80481923): Figure out why these assertions fail, and fix them.
## def testErrorOnInputNonSquare(self):
@@ -146,55 +141,50 @@ class MatrixInverseTriLBijectorTest(test.TestCase):
## x_ = np.array([[1., 2., 3.],
## [4., 5., 6.]], dtype=np.float32)
## square_error_msg = "must be a square matrix"
- ## with self.test_session():
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.forward(x_).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.inverse(x_).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- ## square_error_msg):
- ## inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
-
- @test_util.run_in_graph_and_eager_modes
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.forward(x_))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.inverse(x_))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ ## with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ ## square_error_msg):
+ ## self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
+
def testErrorOnInputNotLowerTriangular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 2.],
[3., 4.]], dtype=np.float32)
triangular_error_msg = "must be lower triangular"
- with self.test_session():
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
- triangular_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
-
- @test_util.run_in_graph_and_eager_modes
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
+ triangular_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
+
def testErrorOnInputSingular(self):
inv = bijectors.MatrixInverseTriL(validate_args=True)
x_ = np.array([[1., 0.],
[0., 0.]], dtype=np.float32)
nonsingular_error_msg = "must have all diagonal entries nonzero"
- with self.test_session():
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.forward(x_).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.inverse(x_).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.forward_log_det_jacobian(x_, event_ndims=2).eval()
- with self.assertRaisesOpError(nonsingular_error_msg):
- inv.inverse_log_det_jacobian(x_, event_ndims=2).eval()
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.forward(x_))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.inverse(x_))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.forward_log_det_jacobian(x_, event_ndims=2))
+ with self.assertRaisesOpError(nonsingular_error_msg):
+ self.evaluate(inv.inverse_log_det_jacobian(x_, event_ndims=2))
if __name__ == "__main__":
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
index cb42331a21..9a88f8f1bc 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/ordered_test.py
@@ -38,26 +38,25 @@ class OrderedBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBijectorVector(self):
- with self.test_session():
- ordered = Ordered()
- self.assertEqual("ordered", ordered.name)
- x = np.asarray([[2., 3, 4], [4., 8, 13]])
- y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
- self.assertAllClose(y, self.evaluate(ordered.forward(x)))
- self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
- self.assertAllClose(
- np.sum(np.asarray(y)[..., 1:], axis=-1),
- self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
- atol=0.,
- rtol=1e-7)
- self.assertAllClose(
- self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
- self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
- atol=0.,
- rtol=1e-7)
+ ordered = Ordered()
+ self.assertEqual("ordered", ordered.name)
+ x = np.asarray([[2., 3, 4], [4., 8, 13]])
+ y = [[2., 0, 0], [4., np.log(4.), np.log(5.)]]
+ self.assertAllClose(y, self.evaluate(ordered.forward(x)))
+ self.assertAllClose(x, self.evaluate(ordered.inverse(y)))
+ self.assertAllClose(
+ np.sum(np.asarray(y)[..., 1:], axis=-1),
+ self.evaluate(ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
+ self.assertAllClose(
+ self.evaluate(-ordered.inverse_log_det_jacobian(y, event_ndims=1)),
+ self.evaluate(ordered.forward_log_det_jacobian(x, event_ndims=1)),
+ atol=0.,
+ rtol=1e-7)
def testBijectorUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
ordered = Ordered()
self.assertEqual("ordered", ordered.name)
x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
@@ -84,21 +83,20 @@ class OrderedBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testShapeGetters(self):
- with self.test_session():
- x = tensor_shape.TensorShape([4])
- y = tensor_shape.TensorShape([4])
- bijector = Ordered(validate_args=True)
- self.assertAllEqual(y, bijector.forward_event_shape(x))
- self.assertAllEqual(y.as_list(),
- self.evaluate(bijector.forward_event_shape_tensor(
- x.as_list())))
- self.assertAllEqual(x, bijector.inverse_event_shape(y))
- self.assertAllEqual(x.as_list(),
- self.evaluate(bijector.inverse_event_shape_tensor(
- y.as_list())))
+ x = tensor_shape.TensorShape([4])
+ y = tensor_shape.TensorShape([4])
+ bijector = Ordered(validate_args=True)
+ self.assertAllEqual(y, bijector.forward_event_shape(x))
+ self.assertAllEqual(y.as_list(),
+ self.evaluate(bijector.forward_event_shape_tensor(
+ x.as_list())))
+ self.assertAllEqual(x, bijector.inverse_event_shape(y))
+ self.assertAllEqual(x.as_list(),
+ self.evaluate(bijector.inverse_event_shape_tensor(
+ y.as_list())))
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
ordered = Ordered()
x = np.sort(self._rng.randn(3, 10), axis=-1).astype(np.float32)
y = (self._rng.randn(3, 10)).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
index 7eef4ab599..e2062ed55d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py
@@ -38,7 +38,7 @@ class PermuteBijectorTest(test.TestCase):
expected_x = np.random.randn(4, 2, 3)
expected_y = expected_x[..., expected_permutation]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
permutation_ph = array_ops.placeholder(dtype=dtypes.int32)
bijector = Permute(
permutation=permutation_ph,
@@ -64,7 +64,7 @@ class PermuteBijectorTest(test.TestCase):
self.assertAllClose(0., ildj, rtol=1e-6, atol=0)
def testRaisesOpError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError("Permutation over `d` must contain"):
permutation_ph = array_ops.placeholder(dtype=dtypes.int32)
bijector = Permute(
@@ -77,7 +77,7 @@ class PermuteBijectorTest(test.TestCase):
permutation = np.int32([2, 0, 1])
x = np.random.randn(4, 2, 3)
y = x[..., permutation]
- with self.test_session():
+ with self.cached_session():
bijector = Permute(permutation=permutation, validate_args=True)
assert_bijective_and_finite(
bijector, x, y, event_ndims=1, rtol=1e-6, atol=0)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
index 85d2283013..ef303ab664 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/power_transform_test.py
@@ -30,7 +30,7 @@ class PowerTransformBijectorTest(test.TestCase):
"""Tests correctness of the power transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
c = 0.2
bijector = PowerTransform(power=c, validate_args=True)
self.assertEqual("power_transform", bijector.name)
@@ -48,13 +48,13 @@ class PowerTransformBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = PowerTransform(power=0.2, validate_args=True)
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=1.5, rtol=0.05)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = PowerTransform(power=0.2, validate_args=True)
x = np.linspace(-4.999, 10, num=10).astype(np.float32)
y = np.logspace(0.001, 10, num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
index 2d52895fbe..b3b7b8535e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py
@@ -43,7 +43,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testBijector(self):
x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4 * 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
nvp = RealNVP(
num_masked=4,
validate_args=True,
@@ -78,7 +78,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
nvp = RealNVP(
num_masked=3,
validate_args=True,
@@ -98,7 +98,7 @@ class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase):
def testInvertMutuallyConsistent(self):
dims = 4
- with self.test_session() as sess:
+ with self.cached_session() as sess:
nvp = Invert(RealNVP(
num_masked=3,
validate_args=True,
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
index d44e49b487..79eadf524b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
@@ -50,7 +50,7 @@ class _ReshapeBijectorTest(object):
expected_x = np.random.randn(4, 3, 2)
expected_y = np.reshape(expected_x, [4, 6])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -84,7 +84,7 @@ class _ReshapeBijectorTest(object):
# using the _tensor methods, we should always get a fully-specified
# result since these are evaluated at graph runtime.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
(shape_out_,
shape_in_) = sess.run((
bijector.forward_event_shape_tensor(shape_in),
@@ -103,7 +103,7 @@ class _ReshapeBijectorTest(object):
expected_y_scalar = expected_x_scalar[0]
shape_in, shape_out, feed_dict = self.build_shapes([], [1,])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bijector = Reshape(
event_shape_out=shape_in,
event_shape_in=shape_out, validate_args=True)
@@ -124,7 +124,7 @@ class _ReshapeBijectorTest(object):
def testMultipleUnspecifiedDimensionsOpError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -139,7 +139,7 @@ class _ReshapeBijectorTest(object):
# pylint: disable=invalid-name
def _testInvalidDimensionsOpError(self, expected_error_message):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,])
bijector = Reshape(
@@ -155,7 +155,7 @@ class _ReshapeBijectorTest(object):
def testValidButNonMatchingInputOpError(self):
x = np.random.randn(4, 3, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -173,7 +173,7 @@ class _ReshapeBijectorTest(object):
def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
x = np.random.randn(4, 3, 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,])
bijector = Reshape(
event_shape_out=shape_out,
@@ -190,7 +190,7 @@ class _ReshapeBijectorTest(object):
x1 = np.random.randn(4, 2, 3)
x2 = np.random.randn(4, 1, 1, 5)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
[1, 1, 5])
bijector = Reshape(
@@ -208,7 +208,7 @@ class _ReshapeBijectorTest(object):
expected_x = np.random.randn(4, 6)
expected_y = np.reshape(expected_x, [4, 2, 3])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# one of input/output shapes is partially specified
shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3])
bijector = Reshape(
@@ -227,7 +227,7 @@ class _ReshapeBijectorTest(object):
def testBothShapesPartiallySpecified(self):
expected_x = np.random.randn(4, 2, 3)
expected_y = np.reshape(expected_x, [4, 3, 2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2])
bijector = Reshape(
event_shape_out=shape_out,
@@ -245,7 +245,7 @@ class _ReshapeBijectorTest(object):
def testDefaultVectorShape(self):
expected_x = np.random.randn(4, 4)
expected_y = np.reshape(expected_x, [4, 2, 2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2])
bijector = Reshape(shape_out,
validate_args=True)
@@ -292,7 +292,7 @@ class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
def testBijectiveAndFinite(self):
x = np.random.randn(4, 2, 3)
y = np.reshape(x, [4, 1, 2, 3])
- with self.test_session():
+ with self.cached_session():
bijector = Reshape(
event_shape_in=[2, 3],
event_shape_out=[1, 2, 3],
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
index cea4a62c22..a6d432753d 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py
@@ -31,7 +31,7 @@ class SigmoidBijectorTest(test.TestCase):
"""Tests correctness of the Y = g(X) = (1 + exp(-X))^-1 transformation."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
self.assertEqual("sigmoid", Sigmoid().name)
x = np.linspace(-10., 10., 100).reshape([2, 5, 10]).astype(np.float32)
y = special.expit(x)
@@ -45,11 +45,11 @@ class SigmoidBijectorTest(test.TestCase):
x, event_ndims=0).eval(), atol=0., rtol=1e-4)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(Sigmoid(), lower_x=-7., upper_x=7.)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
x = np.linspace(-7., 7., 100).astype(np.float32)
eps = 1e-3
y = np.linspace(eps, 1. - eps, 100).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
index 795f1993ba..282619a73b 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py
@@ -33,7 +33,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
"""Tests correctness of the power transformation."""
def testBijectorVersusNumpyRewriteOfBasicFunctions(self):
- with self.test_session():
+ with self.cached_session():
skewness = 0.2
tailweight = 2.0
bijector = SinhArcsinh(
@@ -58,7 +58,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
atol=0.)
def testLargerTailWeightPutsMoreWeightInTails(self):
- with self.test_session():
+ with self.cached_session():
# Will broadcast together to shape [3, 2].
x = [-1., 1.]
tailweight = [[0.5], [1.0], [2.0]]
@@ -75,7 +75,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
self.assertLess(forward_1[1], forward_1[2])
def testSkew(self):
- with self.test_session():
+ with self.cached_session():
# Will broadcast together to shape [3, 2].
x = [-1., 1.]
skewness = [[-1.], [0.], [1.]]
@@ -92,24 +92,24 @@ class SinhArcsinhBijectorTest(test.TestCase):
self.assertLess(np.abs(y[2, 0]), np.abs(y[2, 1]))
def testScalarCongruencySkewness1Tailweight0p5(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=1.0, tailweight=0.5, validate_args=True)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.0, rtol=0.05)
def testScalarCongruencySkewnessNeg1Tailweight1p5(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=-1.0, tailweight=1.5, validate_args=True)
assert_scalar_congruency(bijector, lower_x=-2., upper_x=2.0, rtol=0.05)
def testBijectiveAndFiniteSkewnessNeg1Tailweight0p5(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=-1., tailweight=0.5, validate_args=True)
x = np.concatenate((-np.logspace(-2, 10, 1000), [0], np.logspace(
-2, 10, 1000))).astype(np.float32)
assert_bijective_and_finite(bijector, x, x, event_ndims=0, rtol=1e-3)
def testBijectiveAndFiniteSkewness1Tailweight3(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh(skewness=1., tailweight=3., validate_args=True)
x = np.concatenate((-np.logspace(-2, 5, 1000), [0], np.logspace(
-2, 5, 1000))).astype(np.float32)
@@ -117,7 +117,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
bijector, x, x, event_ndims=0, rtol=1e-3)
def testBijectorEndpoints(self):
- with self.test_session():
+ with self.cached_session():
for dtype in (np.float32, np.float64):
bijector = SinhArcsinh(
skewness=dtype(0.), tailweight=dtype(1.), validate_args=True)
@@ -129,7 +129,7 @@ class SinhArcsinhBijectorTest(test.TestCase):
bijector, bounds, bounds, event_ndims=0, atol=2e-6)
def testBijectorOverRange(self):
- with self.test_session():
+ with self.cached_session():
for dtype in (np.float32, np.float64):
skewness = np.array([1.2, 5.], dtype=dtype)
tailweight = np.array([2., 10.], dtype=dtype)
@@ -176,12 +176,12 @@ class SinhArcsinhBijectorTest(test.TestCase):
atol=0.)
def testZeroTailweightRaises(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("not positive"):
SinhArcsinh(tailweight=0., validate_args=True).forward(1.0).eval()
def testDefaultDtypeIsFloat32(self):
- with self.test_session():
+ with self.cached_session():
bijector = SinhArcsinh()
self.assertEqual(bijector.tailweight.dtype, np.float32)
self.assertEqual(bijector.skewness.dtype, np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
index 0f0a2fa531..8d18400487 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softmax_centered_test.py
@@ -35,7 +35,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
"""Tests correctness of the Y = g(X) = exp(X) / sum(exp(X)) transformation."""
def testBijectorVector(self):
- with self.test_session():
+ with self.cached_session():
softmax = SoftmaxCentered()
self.assertEqual("softmax_centered", softmax.name)
x = np.log([[2., 3, 4], [4., 8, 12]])
@@ -54,7 +54,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
rtol=1e-7)
def testBijectorUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
softmax = SoftmaxCentered()
self.assertEqual("softmax_centered", softmax.name)
x = array_ops.placeholder(shape=[2, None], dtype=dtypes.float32)
@@ -80,7 +80,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
rtol=1e-7)
def testShapeGetters(self):
- with self.test_session():
+ with self.cached_session():
x = tensor_shape.TensorShape([4])
y = tensor_shape.TensorShape([5])
bijector = SoftmaxCentered(validate_args=True)
@@ -94,7 +94,7 @@ class SoftmaxCenteredBijectorTest(test.TestCase):
y.as_list()).eval())
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
softmax = SoftmaxCentered()
x = np.linspace(-50, 50, num=10).reshape(5, 2).astype(np.float32)
# Make y values on the simplex with a wide range.
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
index 3d8a0a32bb..e805619041 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softplus_test.py
@@ -42,13 +42,13 @@ class SoftplusBijectorTest(test.TestCase):
return -np.log(1 - np.exp(-y))
def testHingeSoftnessZeroRaises(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=0., validate_args=True)
with self.assertRaisesOpError("must be non-zero"):
bijector.forward([1., 1.]).eval()
def testBijectorForwardInverseEventDimsZero(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
self.assertEqual("softplus", bijector.name)
x = 2 * rng.randn(2, 10)
@@ -58,7 +58,7 @@ class SoftplusBijectorTest(test.TestCase):
self.assertAllClose(x, bijector.inverse(y).eval())
def testBijectorForwardInverseWithHingeSoftnessEventDimsZero(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=1.5)
x = 2 * rng.randn(2, 10)
y = 1.5 * self._softplus(x / 1.5)
@@ -67,7 +67,7 @@ class SoftplusBijectorTest(test.TestCase):
self.assertAllClose(x, bijector.inverse(y).eval())
def testBijectorLogDetJacobianEventDimsZero(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
y = 2 * rng.rand(2, 10)
# No reduction needed if event_dims = 0.
@@ -77,7 +77,7 @@ class SoftplusBijectorTest(test.TestCase):
y, event_ndims=0).eval())
def testBijectorForwardInverseEventDimsOne(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
self.assertEqual("softplus", bijector.name)
x = 2 * rng.randn(2, 10)
@@ -87,7 +87,7 @@ class SoftplusBijectorTest(test.TestCase):
self.assertAllClose(x, bijector.inverse(y).eval())
def testBijectorLogDetJacobianEventDimsOne(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
y = 2 * rng.rand(2, 10)
ildj_before = self._softplus_ildj_before_reduction(y)
@@ -97,25 +97,25 @@ class SoftplusBijectorTest(test.TestCase):
y, event_ndims=1).eval())
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
def testScalarCongruencyWithPositiveHingeSoftness(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=1.3)
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
def testScalarCongruencyWithNegativeHingeSoftness(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=-1.3)
assert_scalar_congruency(
bijector, lower_x=-2., upper_x=2.)
def testBijectiveAndFinite32bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.logspace(-10, 10, 100).astype(np.float32)
@@ -123,7 +123,7 @@ class SoftplusBijectorTest(test.TestCase):
bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
def testBijectiveAndFiniteWithPositiveHingeSoftness32Bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=1.23)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = np.logspace(-10, 10, 100).astype(np.float32)
@@ -131,7 +131,7 @@ class SoftplusBijectorTest(test.TestCase):
bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
def testBijectiveAndFiniteWithNegativeHingeSoftness32Bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus(hinge_softness=-0.7)
x = np.linspace(-20., 20., 100).astype(np.float32)
y = -np.logspace(-10, 10, 100).astype(np.float32)
@@ -139,7 +139,7 @@ class SoftplusBijectorTest(test.TestCase):
bijector, x, y, event_ndims=0, rtol=1e-2, atol=1e-2)
def testBijectiveAndFinite16bit(self):
- with self.test_session():
+ with self.cached_session():
bijector = Softplus()
# softplus(-20) is zero, so we can't use such a large range as in 32bit.
x = np.linspace(-10., 20., 100).astype(np.float16)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
index d0098c3c10..8dad80aa64 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/softsign_test.py
@@ -43,16 +43,15 @@ class SoftsignBijectorTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testBijectorBounds(self):
bijector = Softsign(validate_args=True)
- with self.test_session():
- with self.assertRaisesOpError("greater than -1"):
- bijector.inverse(-3.).eval()
- with self.assertRaisesOpError("greater than -1"):
- bijector.inverse_log_det_jacobian(-3., event_ndims=0).eval()
-
- with self.assertRaisesOpError("less than 1"):
- bijector.inverse(3.).eval()
- with self.assertRaisesOpError("less than 1"):
- bijector.inverse_log_det_jacobian(3., event_ndims=0).eval()
+ with self.assertRaisesOpError("greater than -1"):
+ self.evaluate(bijector.inverse(-3.))
+ with self.assertRaisesOpError("greater than -1"):
+ self.evaluate(bijector.inverse_log_det_jacobian(-3., event_ndims=0))
+
+ with self.assertRaisesOpError("less than 1"):
+ self.evaluate(bijector.inverse(3.))
+ with self.assertRaisesOpError("less than 1"):
+ self.evaluate(bijector.inverse_log_det_jacobian(3., event_ndims=0))
@test_util.run_in_graph_and_eager_modes
def testBijectorForwardInverse(self):
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
index 30c7a738c3..e5550cc830 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/square_test.py
@@ -29,7 +29,7 @@ class SquareBijectorTest(test.TestCase):
"""Tests the correctness of the Y = X ** 2 transformation."""
def testBijectorScalar(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Square(validate_args=True)
self.assertEqual("square", bijector.name)
x = [[[1., 5],
@@ -50,7 +50,7 @@ class SquareBijectorTest(test.TestCase):
rtol=1e-7)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
bijector = bijectors.Square(validate_args=True)
assert_scalar_congruency(bijector, lower_x=1e-3, upper_x=1.5, rtol=0.05)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
index f57adcda89..424eb58fa0 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/weibull_test.py
@@ -31,7 +31,7 @@ class WeibullBijectorTest(test.TestCase):
"""Tests correctness of the weibull bijector."""
def testBijector(self):
- with self.test_session():
+ with self.cached_session():
scale = 5.
concentration = 0.3
bijector = Weibull(
@@ -54,13 +54,13 @@ class WeibullBijectorTest(test.TestCase):
atol=0.)
def testScalarCongruency(self):
- with self.test_session():
+ with self.cached_session():
assert_scalar_congruency(
Weibull(scale=20., concentration=0.3),
lower_x=1., upper_x=100., rtol=0.02)
def testBijectiveAndFinite(self):
- with self.test_session():
+ with self.cached_session():
bijector = Weibull(
scale=20., concentration=2., validate_args=True)
x = np.linspace(1., 8., num=10).astype(np.float32)
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
index f7b2efa7bc..05f5d30666 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py
@@ -542,9 +542,9 @@ class PadDynamicTest(_PadTest, test.TestCase):
return False
+@test_util.run_all_in_graph_and_eager_modes
class TestMoveDimension(test.TestCase):
- @test_util.run_in_graph_and_eager_modes
def test_move_dimension_static_shape(self):
x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
@@ -561,7 +561,6 @@ class TestMoveDimension(test.TestCase):
x_perm = distribution_util.move_dimension(x, 4, 2)
self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])
- @test_util.run_in_graph_and_eager_modes
def test_move_dimension_dynamic_shape(self):
x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index fa3f1bb7ad..84517b57c7 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -14,6 +14,7 @@ py_library(
":datasets",
":metrics",
":network",
+ ":remote",
":saver",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -223,11 +224,24 @@ py_test(
],
)
+py_library(
+ name = "remote",
+ srcs = ["remote.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
py_test(
name = "remote_test",
srcs = ["remote_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":remote",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py
index 7d2274db9b..48d093e075 100644
--- a/tensorflow/contrib/eager/python/evaluator_test.py
+++ b/tensorflow/contrib/eager/python/evaluator_test.py
@@ -117,7 +117,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"].numpy())
def testDatasetGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
init_op, call_op, results_op = e.evaluate_on_dataset(ds)
@@ -126,7 +126,7 @@ class EvaluatorTest(test.TestCase):
self.assertEqual(6.0, results["mean"])
def testWriteSummariesGraph(self):
- with context.graph_mode(), ops.Graph().as_default(), self.test_session():
+ with context.graph_mode(), ops.Graph().as_default(), self.cached_session():
e = SimpleEvaluator(IdentityModel())
ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
training_util.get_or_create_global_step()
diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
index 315d7a4893..529c99b37c 100644
--- a/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
+++ b/tensorflow/contrib/eager/python/examples/generative_examples/image_captioning_with_attention.ipynb
@@ -66,7 +66,7 @@
"\n",
"[Image Source](https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg), License: Public Domain\n",
"\n",
- "Our goal is generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
+ "Our goal is to generate a caption, such as \"a surfer riding on a wave\". Here, we'll use an attention-based model. This enables us to see which parts of the image the model focuses on as it generates a caption.\n",
"\n",
"![Prediction](https://tensorflow.org/images/imcap_prediction.png)\n",
"\n",
@@ -128,7 +128,7 @@
"source": [
"## Download and prepare the MS-COCO dataset\n",
"\n",
- "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code code below will download and extract the dataset automatically. \n",
+ "We will use the [MS-COCO dataset](http://cocodataset.org/#home) to train our model. This dataset contains >82,000 images, each of which has been annotated with at least 5 different captions. The code below will download and extract the dataset automatically. \n",
"\n",
"**Caution: large download ahead**. We'll use the training set, it's a 13GB file."
]
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/README.md b/tensorflow/contrib/eager/python/examples/notebooks/README.md
index 0d5ed84894..2778b228e9 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/README.md
+++ b/tensorflow/contrib/eager/python/examples/notebooks/README.md
@@ -1,11 +1,3 @@
-## Research and experimentation
-
-Eager execution provides an imperative, define-by-run interface for advanced
-operations. Write custom layers, forward passes, and training loops with auto
-differentiation. Start with these notebooks, then read the
-[eager execution guide](https://www.tensorflow.org/guide/eager).
-
-1. [Eager execution basics](./eager_basics.ipynb)
-2. [Automatic differentiation and gradient tapes](./automatic_differentiation.ipynb)
-3. [Custom training: basics](./custom_training.ipynb)
-4. [Custom layers](./custom_layers.ipynb)
+The notebooks have been moved to the
+[tensorflow/docs](https://github.com/tensorflow/docs/tree/master/site/en/tutorials/eager)
+repository.
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
index 51b7ffc4de..8fae622e12 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -15,12 +15,7 @@
"execution_count": 0,
"metadata": {
"cellView": "form",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
+ "colab": {},
"colab_type": "code",
"id": "GCCk8_dHpuNf"
},
@@ -53,308 +48,35 @@
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "idv0bPeCp325"
- },
- "source": [
- "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
- "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\n",
- " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
- "\u003c/td\u003e\u003ctd\u003e\n",
- "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "vDJ4XzMqodTy"
- },
- "source": [
- "In the previous tutorial we introduced `Tensor`s and operations on them. In this tutorial we will cover [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation), a key technique for optimizing machine learning models."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "GQJysDM__Qb0"
- },
- "source": [
- "## Setup\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "OiMPZStlibBv"
- },
- "outputs": [],
- "source": [
- "import tensorflow as tf\n",
- "tf.enable_eager_execution()\n",
- "\n",
- "tfe = tf.contrib.eager # Shorthand for some symbols"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "1CLWJl0QliB0"
- },
- "source": [
- "## Derivatives of a function\n",
- "\n",
- "TensorFlow provides APIs for automatic differentiation - computing the derivative of a function. The way that more closely mimics the math is to encapsulate the computation in a Python function, say `f`, and use `tfe.gradients_function` to create a function that computes the derivatives of `f` with respect to its arguments. If you're familiar with [autograd](https://github.com/HIPS/autograd) for differentiating numpy functions, this will be familiar. For example: "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "9FViq92UX7P8"
- },
- "outputs": [],
- "source": [
- "from math import pi\n",
- "\n",
- "def f(x):\n",
- " return tf.square(tf.sin(x))\n",
- "\n",
- "assert f(pi/2).numpy() == 1.0\n",
- "\n",
- "\n",
- "# grad_f will return a list of derivatives of f\n",
- "# with respect to its arguments. Since f() has a single argument,\n",
- "# grad_f will return a list with a single element.\n",
- "grad_f = tfe.gradients_function(f)\n",
- "assert tf.abs(grad_f(pi/2)[0]).numpy() \u003c 1e-7"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "v9fPs8RyopCf"
- },
- "source": [
- "### Higher-order gradients\n",
- "\n",
- "The same API can be used to differentiate as many times as you like:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "3D0ZvnGYo0rW"
- },
- "outputs": [],
- "source": [
- "def f(x):\n",
- " return tf.square(tf.sin(x))\n",
- "\n",
- "def grad(f):\n",
- " return lambda x: tfe.gradients_function(f)(x)[0]\n",
- "\n",
- "x = tf.lin_space(-2*pi, 2*pi, 100) # 100 points between -2π and +2π\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "plt.plot(x, f(x), label=\"f\")\n",
- "plt.plot(x, grad(f)(x), label=\"first derivative\")\n",
- "plt.plot(x, grad(grad(f))(x), label=\"second derivative\")\n",
- "plt.plot(x, grad(grad(grad(f)))(x), label=\"third derivative\")\n",
- "plt.legend()\n",
- "plt.show()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "-39gouo7mtgu"
- },
- "source": [
- "## Gradient tapes\n",
- "\n",
- "Every differentiable TensorFlow operation has an associated gradient function. For example, the gradient function of `tf.square(x)` would be a function that returns `2.0 * x`. To compute the gradient of a user-defined function (like `f(x)` in the example above), TensorFlow first \"records\" all the operations applied to compute the output of the function. We call this record a \"tape\". It then uses that tape and the gradients functions associated with each primitive operation to compute the gradients of the user-defined function using [reverse mode differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation).\n",
- "\n",
- "Since operations are recorded as they are executed, Python control flow (using `if`s and `while`s for example) is naturally handled:\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "MH0UfjympWf7"
- },
- "outputs": [],
- "source": [
- "def f(x, y):\n",
- " output = 1\n",
- " # Must use range(int(y)) instead of range(y) in Python 3 when\n",
- " # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+\n",
- " for i in range(int(y)):\n",
- " output = tf.multiply(output, x)\n",
- " return output\n",
- "\n",
- "def g(x, y):\n",
- " # Return the gradient of `f` with respect to it's first parameter\n",
- " return tfe.gradients_function(f)(x, y)[0]\n",
- "\n",
- "assert f(3.0, 2).numpy() == 9.0 # f(x, 2) is essentially x * x\n",
- "assert g(3.0, 2).numpy() == 6.0 # And its gradient will be 2 * x\n",
- "assert f(4.0, 3).numpy() == 64.0 # f(x, 3) is essentially x * x * x\n",
- "assert g(4.0, 3).numpy() == 48.0 # And its gradient will be 3 * x * x"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "aNmR5-jhpX2t"
- },
- "source": [
- "At times it may be inconvenient to encapsulate computation of interest into a function. For example, if you want the gradient of the output with respect to intermediate values computed in the function. In such cases, the slightly more verbose but explicit [tf.GradientTape](https://www.tensorflow.org/api_docs/python/tf/GradientTape) context is useful. All computation inside the context of a `tf.GradientTape` is \"recorded\".\n",
- "\n",
- "For example:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "bAFeIE8EuVIq"
+ "id": "clNGnJ3u8Rl6"
},
- "outputs": [],
"source": [
- "x = tf.ones((2, 2))\n",
- " \n",
- "# TODO(b/78880779): Remove the 'persistent=True' argument and use\n",
- "# a single t.gradient() call when the bug is resolved.\n",
- "with tf.GradientTape(persistent=True) as t:\n",
- " # TODO(ashankar): Explain with \"watch\" argument better?\n",
- " t.watch(x)\n",
- " y = tf.reduce_sum(x)\n",
- " z = tf.multiply(y, y)\n",
- "\n",
- "# Use the same tape to compute the derivative of z with respect to the\n",
- "# intermediate value y.\n",
- "dz_dy = t.gradient(z, y)\n",
- "assert dz_dy.numpy() == 8.0\n",
- "\n",
- "# Derivative of z with respect to the original input tensor x\n",
- "dz_dx = t.gradient(z, x)\n",
- "for i in [0, 1]:\n",
- " for j in [0, 1]:\n",
- " assert dz_dx[i][j].numpy() == 8.0"
+ "This file has moved."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
- "id": "DK05KXrAAld3"
- },
- "source": [
- "### Higher-order gradients\n",
- "\n",
- "Operations inside of the `GradientTape` context manager are recorded for automatic differentiation. If gradients are computed in that context, then the gradient computation is recorded as well. As a result, the exact same API works for higher-order gradients as well. For example:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "colab_type": "code",
- "id": "cPQgthZ7ugRJ"
- },
- "outputs": [],
- "source": [
- "# TODO(ashankar): Should we use the persistent tape here instead? Follow up on Tom and Alex's discussion\n",
- "\n",
- "x = tf.constant(1.0) # Convert the Python 1.0 to a Tensor object\n",
- "\n",
- "with tf.GradientTape() as t:\n",
- " with tf.GradientTape() as t2:\n",
- " t2.watch(x)\n",
- " y = x * x * x\n",
- " # Compute the gradient inside the 't' context manager\n",
- " # which means the gradient computation is differentiable as well.\n",
- " dy_dx = t2.gradient(y, x)\n",
- "d2y_dx2 = t.gradient(dy_dx, x)\n",
- "\n",
- "assert dy_dx.numpy() == 3.0\n",
- "assert d2y_dx2.numpy() == 6.0"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text",
- "id": "4U1KKzUpNl58"
+ "id": "idv0bPeCp325"
},
"source": [
- "## Next Steps\n",
- "\n",
- "In this tutorial we covered gradient computation in TensorFlow. With that we have enough of the primitives required to build an train neural networks, which we will cover in the [next tutorial](https://github.com/tensorflow/models/tree/master/official/contrib/eager/python/examples/notebooks/3_neural_networks.ipynb)."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
- "default_view": {},
"name": "automatic_differentiation.ipynb",
"private_outputs": true,
"provenance": [],
"toc_visible": true,
- "version": "0.3.2",
- "views": {}
+ "version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
index a0bbbb6123..d89774c45e 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb
@@ -1,46 +1,25 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "custom_layers.ipynb",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "tDnwEv8FtJm7",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "tDnwEv8FtJm7"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "JlknJBWQtKkI",
+ "cellView": "form",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "form"
+ "id": "JlknJBWQtKkI"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -53,347 +32,57 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "60RdWsg1tETW",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Custom layers"
- ]
- },
- {
- "metadata": {
- "id": "BcJg7Enms86w",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_layers.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
- ]
- },
- {
- "metadata": {
- "id": "UEu3q4jmpKVT",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "We recommend using `tf.keras` as a high-level API for building neural networks. That said, most TensorFlow APIs are usable with eager execution.\n"
]
},
{
- "metadata": {
- "id": "pwX7Fii1rwsJ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "tfe = tf.contrib.eager\n",
- "\n",
- "tf.enable_eager_execution()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "zSFfVVjkrrsI",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "## Layers: common sets of useful operations\n",
- "\n",
- "Most of the time when writing code for machine learning models you want to operate at a higher level of abstraction than individual operations and manipulation of individual variables.\n",
- "\n",
- "Many machine learning models are expressible as the composition and stacking of relatively simple layers, and TensorFlow provides both a set of many common layers as a well as easy ways for you to write your own application-specific layers either from scratch or as the composition of existing layers.\n",
- "\n",
- "TensorFlow includes the full [Keras](https://keras.io) API in the tf.keras package, and the Keras layers are very useful when building your own models.\n"
- ]
- },
- {
"metadata": {
- "id": "8PyXlPl-4TzQ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "60RdWsg1tETW"
},
- "cell_type": "code",
- "source": [
- "# In the tf.keras.layers package, layers are objects. To construct a layer,\n",
- "# simply construct the object. Most layers take as a first argument the number\n",
- "# of output dimensions / channels.\n",
- "layer = tf.keras.layers.Dense(100)\n",
- "# The number of input dimensions is often unnecessary, as it can be inferred\n",
- "# the first time the layer is used, but it can be provided if you want to \n",
- "# specify it manually, which is useful in some complex models.\n",
- "layer = tf.keras.layers.Dense(10, input_shape=(None, 5))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "Fn69xxPO5Psr",
- "colab_type": "text"
- },
- "cell_type": "markdown",
"source": [
- "The full list of pre-existing layers can be seen in [the documentation](https://www.tensorflow.org/api_docs/python/tf/keras/layers). It includes Dense (a fully-connected layer),\n",
- "Conv2D, LSTM, BatchNormalization, Dropout, and many others."
+ "# Custom layers"
]
},
{
- "metadata": {
- "id": "E3XKNknP5Mhb",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# To use a layer, simply call it.\n",
- "layer(tf.zeros([10, 5]))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "Wt_Nsv-L5t2s",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# Layers have many useful methods. For example, you can inspect all variables\n",
- "# in a layer by calling layer.variables. In this case a fully-connected layer\n",
- "# will have variables for weights and biases.\n",
- "layer.variables"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "6ilvKjz8_4MQ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# The variables are also accessible through nice accessors\n",
- "layer.kernel, layer.bias"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "O0kDbE54-5VS",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "## Implementing custom layers\n",
- "The best way to implement your own layer is extending the tf.keras.Layer class and implementing:\n",
- " * `__init__` , where you can do all input-independent initialization\n",
- " * `build`, where you know the shapes of the input tensors and can do the rest of the initialization\n",
- " * `call`, where you do the forward computation\n",
- "\n",
- "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`. However, the advantage of creating them in `build` is that it enables late variable creation based on the shape of the inputs the layer will operate on. On the other hand, creating variables in `__init__` would mean that shapes required to create the variables will need to be explicitly specified."
- ]
- },
- {
- "metadata": {
- "id": "5Byl3n1k5kIy",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "class MyDenseLayer(tf.keras.layers.Layer):\n",
- " def __init__(self, num_outputs):\n",
- " super(MyDenseLayer, self).__init__()\n",
- " self.num_outputs = num_outputs\n",
- " \n",
- " def build(self, input_shape):\n",
- " self.kernel = self.add_variable(\"kernel\", \n",
- " shape=[input_shape[-1].value, \n",
- " self.num_outputs])\n",
- " \n",
- " def call(self, input):\n",
- " return tf.matmul(input, self.kernel)\n",
- " \n",
- "layer = MyDenseLayer(10)\n",
- "print(layer(tf.zeros([10, 5])))\n",
- "print(layer.variables)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "tk8E2vY0-z4Z",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "9sFn_RV_8zM-"
},
- "cell_type": "markdown",
"source": [
- "Note that you don't have to wait until `build` is called to create your variables, you can also create them in `__init__`.\n",
- "\n",
- "Overall code is easier to read and maintain if it uses standard layers whenever possible, as other readers will be familiar with the behavior of standard layers. If you want to use a layer which is not present in tf.keras.layers or tf.contrib.layers, consider filing a [github issue](http://github.com/tensorflow/tensorflow/issues/new) or, even better, sending us a pull request!"
+ "This file has moved."
]
},
{
- "metadata": {
- "id": "Qhg4KlbKrs3G",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "## Models: composing layers\n",
- "\n",
- "Many interesting layer-like things in machine learning models are implemented by composing existing layers. For example, each residual block in a resnet is a composition of convolutions, batch normalizations, and a shortcut.\n",
- "\n",
- "The main class used when creating a layer-like thing which contains other layers is tf.keras.Model. Implementing one is done by inheriting from tf.keras.Model."
- ]
- },
- {
- "metadata": {
- "id": "N30DTXiRASlb",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "class ResnetIdentityBlock(tf.keras.Model):\n",
- " def __init__(self, kernel_size, filters):\n",
- " super(ResnetIdentityBlock, self).__init__(name='')\n",
- " filters1, filters2, filters3 = filters\n",
- "\n",
- " self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))\n",
- " self.bn2a = tf.keras.layers.BatchNormalization()\n",
- "\n",
- " self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')\n",
- " self.bn2b = tf.keras.layers.BatchNormalization()\n",
- "\n",
- " self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))\n",
- " self.bn2c = tf.keras.layers.BatchNormalization()\n",
- "\n",
- " def call(self, input_tensor, training=False):\n",
- " x = self.conv2a(input_tensor)\n",
- " x = self.bn2a(x, training=training)\n",
- " x = tf.nn.relu(x)\n",
- "\n",
- " x = self.conv2b(x)\n",
- " x = self.bn2b(x, training=training)\n",
- " x = tf.nn.relu(x)\n",
- "\n",
- " x = self.conv2c(x)\n",
- " x = self.bn2c(x, training=training)\n",
- "\n",
- " x += input_tensor\n",
- " return tf.nn.relu(x)\n",
- "\n",
- " \n",
- "block = ResnetIdentityBlock(1, [1, 2, 3])\n",
- "print(block(tf.zeros([1, 2, 3, 3])))\n",
- "print([x.name for x in block.variables])"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "wYfucVw65PMj",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "BcJg7Enms86w"
},
- "cell_type": "markdown",
"source": [
- "Much of the time, however, models which compose many layers simply call one layer after the other. This can be done in very little code using tf.keras.Sequential"
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_layers.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_layers.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "custom_layers.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
},
- {
- "metadata": {
- "id": "L9frk7Ur4uvJ",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- " my_seq = tf.keras.Sequential([tf.keras.layers.Conv2D(1, (1, 1)),\n",
- " tf.keras.layers.BatchNormalization(),\n",
- " tf.keras.layers.Conv2D(2, 1, \n",
- " padding='same'),\n",
- " tf.keras.layers.BatchNormalization(),\n",
- " tf.keras.layers.Conv2D(3, (1, 1)),\n",
- " tf.keras.layers.BatchNormalization()])\n",
- "my_seq(tf.zeros([1, 2, 3, 3]))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "c5YwYcnuK-wc",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Next steps\n",
- "\n",
- "Now you can go back to the previous notebook and adapt the linear regression example to use layers and models to be better structured."
- ]
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
index 5f1b48fa0d..86dca0b423 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb
@@ -1,46 +1,25 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "Custom training: basics",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "5rmpybwysXGV",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "5rmpybwysXGV"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "m8y3rGtQsYP2",
+ "cellView": "form",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "form"
+ "id": "m8y3rGtQsYP2"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -53,425 +32,57 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "hrXv0rU9sIma",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Custom training: basics"
- ]
- },
- {
- "metadata": {
- "id": "7S0BwJ_8sLu7",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/custom_training.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
- ]
- },
- {
- "metadata": {
- "id": "k2o3TTG4TFpt",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "In the previous tutorial we covered the TensorFlow APIs for automatic differentiation, a basic building block for machine learning.\n",
- "In this tutorial we will use the TensorFlow primitives introduced in the prior tutorials to do some simple machine learning.\n",
- "\n",
- "TensorFlow also includes a higher-level neural networks API (`tf.keras`) which provides useful abstractions to reduce boilerplate. We strongly recommend those higher level APIs for people working with neural networks. However, in this short tutorial we cover neural network training from first principles to establish a strong foundation."
- ]
- },
- {
- "metadata": {
- "id": "3LXMVuV0VhDr",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Setup"
- ]
- },
- {
- "metadata": {
- "id": "PJ64L90aVir3",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "\n",
- "tf.enable_eager_execution()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "eMAWbDJFVmMk",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Variables\n",
- "\n",
- "Tensors in TensorFlow are immutable stateless objects. Machine learning models, however, need to have changing state: as your model trains, the same code to compute predictions should behave differently over time (hopefully with a lower loss!). To represent this state which needs to change over the course of your computation, you can choose to rely on the fact that Python is a stateful programming language:\n"
- ]
- },
- {
- "metadata": {
- "id": "VkJwtLS_Jbn8",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "# Using python state\n",
- "x = tf.zeros([10, 10])\n",
- "x += 2 # This is equivalent to x = x + 2, which does not mutate the original\n",
- " # value of x\n",
- "print(x)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "wfneTXy7JcUz",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "TensorFlow, however, has stateful operations built in, and these are often more pleasant to use than low-level Python representations of your state. To represent weights in a model, for example, it's often convenient and efficient to use TensorFlow variables.\n",
- "\n",
- "A Variable is an object which stores a value and, when used in a TensorFlow computation, will implicitly read from this stored value. There are operations (`tf.assign_sub`, `tf.scatter_update`, etc) which manipulate the value stored in a TensorFlow variable."
]
},
{
- "metadata": {
- "id": "itxmrMil6DQi",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "v = tf.Variable(1.0)\n",
- "assert v.numpy() == 1.0\n",
- "\n",
- "# Re-assign the value\n",
- "v.assign(3.0)\n",
- "assert v.numpy() == 3.0\n",
- "\n",
- "# Use `v` in a TensorFlow operation like tf.square() and reassign\n",
- "v.assign(tf.square(v))\n",
- "assert v.numpy() == 9.0"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "-paSaeq1JzwC",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "Computations using Variables are automatically traced when computing gradients. For Variables representing embeddings TensorFlow will do sparse updates by default, which are more computation and memory efficient.\n",
- "\n",
- "Using Variables is also a way to quickly let a reader of your code know that this piece of state is mutable."
- ]
- },
- {
"metadata": {
- "id": "BMiFcDzE7Qu3",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "hrXv0rU9sIma"
},
- "cell_type": "markdown",
"source": [
- "## Example: Fitting a linear model\n",
- "\n",
- "Let's now put the few concepts we have so far ---`Tensor`, `GradientTape`, `Variable` --- to build and train a simple model. This typically involves a few steps:\n",
- "\n",
- "1. Define the model.\n",
- "2. Define a loss function.\n",
- "3. Obtain training data.\n",
- "4. Run through the training data and use an \"optimizer\" to adjust the variables to fit the data.\n",
- "\n",
- "In this tutorial, we'll walk through a trivial example of a simple linear model: `f(x) = x * W + b`, which has two variables - `W` and `b`. Furthermore, we'll synthesize data such that a well trained model would have `W = 3.0` and `b = 2.0`."
- ]
- },
- {
- "metadata": {
- "id": "gFzH64Jn9PIm",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Define the model\n",
- "\n",
- "Let's define a simple class to encapsulate the variables and the computation."
+ "# Custom training: basics"
]
},
{
- "metadata": {
- "id": "_WRu7Pze7wk8",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "class Model(object):\n",
- " def __init__(self):\n",
- " # Initialize variable to (5.0, 0.0)\n",
- " # In practice, these should be initialized to random values.\n",
- " self.W = tf.Variable(5.0)\n",
- " self.b = tf.Variable(0.0)\n",
- " \n",
- " def __call__(self, x):\n",
- " return self.W * x + self.b\n",
- " \n",
- "model = Model()\n",
- "\n",
- "assert model(3.0).numpy() == 15.0"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "xa6j_yXa-j79",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "### Define a loss function\n",
- "\n",
- "A loss function measures how well the output of a model for a given input matches the desired output. Let's use the standard L2 loss."
- ]
- },
- {
- "metadata": {
- "id": "Y0ysUFGY924U",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "def loss(predicted_y, desired_y):\n",
- " return tf.reduce_mean(tf.square(predicted_y - desired_y))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "qutT_fkl_CBc",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "IGPZTmwn9IT4"
},
- "cell_type": "markdown",
"source": [
- "### Obtain training data\n",
- "\n",
- "Let's synthesize the training data with some noise."
+ "This file has moved."
]
},
{
- "metadata": {
- "id": "gxPTb-kt_N5m",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "TRUE_W = 3.0\n",
- "TRUE_b = 2.0\n",
- "NUM_EXAMPLES = 1000\n",
- "\n",
- "inputs = tf.random_normal(shape=[NUM_EXAMPLES])\n",
- "noise = tf.random_normal(shape=[NUM_EXAMPLES])\n",
- "outputs = inputs * TRUE_W + TRUE_b + noise"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "-50nq-wPBsAW",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "Before we train the model let's visualize where the model stands right now. We'll plot the model's predictions in red and the training data in blue."
- ]
- },
- {
"metadata": {
- "id": "_eb83LtrB4nt",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "7S0BwJ_8sLu7"
},
- "cell_type": "code",
"source": [
- "import matplotlib.pyplot as plt\n",
- "\n",
- "plt.scatter(inputs, outputs, c='b')\n",
- "plt.scatter(inputs, model(inputs), c='r')\n",
- "plt.show()\n",
- "\n",
- "print('Current loss: '),\n",
- "print(loss(model(inputs), outputs).numpy())"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "sSDP-yeq_4jE",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Define a training loop\n",
- "\n",
- "We now have our network and our training data. Let's train it, i.e., use the training data to update the model's variables (`W` and `b`) so that the loss goes down using [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent). There are many variants of the gradient descent scheme that are captured in `tf.train.Optimizer` implementations. We'd highly recommend using those implementations, but in the spirit of building from first principles, in this particular example we will implement the basic math ourselves."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_training.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/custom_training.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "Custom training: basics",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
},
- {
- "metadata": {
- "id": "MBIACgdnA55X",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "def train(model, inputs, outputs, learning_rate):\n",
- " with tf.GradientTape() as t:\n",
- " current_loss = loss(model(inputs), outputs)\n",
- " dW, db = t.gradient(current_loss, [model.W, model.b])\n",
- " model.W.assign_sub(learning_rate * dW)\n",
- " model.b.assign_sub(learning_rate * db)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "RwWPaJryD2aN",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "Finally, let's repeatedly run through the training data and see how `W` and `b` evolve."
- ]
- },
- {
- "metadata": {
- "id": "XdfkR223D9dW",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "model = Model()\n",
- "\n",
- "# Collect the history of W-values and b-values to plot later\n",
- "Ws, bs = [], []\n",
- "epochs = range(10)\n",
- "for epoch in epochs:\n",
- " Ws.append(model.W.numpy())\n",
- " bs.append(model.b.numpy())\n",
- " current_loss = loss(model(inputs), outputs)\n",
- "\n",
- " train(model, inputs, outputs, learning_rate=0.1)\n",
- " print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %\n",
- " (epoch, Ws[-1], bs[-1], current_loss))\n",
- "\n",
- "# Let's plot it all\n",
- "plt.plot(epochs, Ws, 'r',\n",
- " epochs, bs, 'b')\n",
- "plt.plot([TRUE_W] * len(epochs), 'r--',\n",
- " [TRUE_b] * len(epochs), 'b--')\n",
- "plt.legend(['W', 'b', 'true W', 'true_b'])\n",
- "plt.show()\n",
- " "
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "vPnIVuaSJwWz",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Next Steps\n",
- "\n",
- "In this tutorial we covered `Variable`s and built and trained a simple linear model using the TensorFlow primitives discussed so far.\n",
- "\n",
- "In theory, this is pretty much all you need to use TensorFlow for your machine learning research.\n",
- "In practice, particularly for neural networks, the higher level APIs like `tf.keras` will be much more convenient since it provides higher level building blocks (called \"layers\"), utilities to save and restore state, a suite of loss functions, a suite of optimization strategies etc. \n",
- "\n",
- "The [next tutorial](TODO) will cover these higher level APIs."
- ]
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
index f1e13de5de..c6d1a56604 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb
@@ -1,46 +1,25 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "name": "eager_basics.ipynb",
- "version": "0.3.2",
- "views": {},
- "default_view": {},
- "provenance": [],
- "private_outputs": true,
- "collapsed_sections": [],
- "toc_visible": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- }
- },
"cells": [
{
+ "cell_type": "markdown",
"metadata": {
- "id": "iPpI7RaYoZuE",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "iPpI7RaYoZuE"
},
- "cell_type": "markdown",
"source": [
"##### Copyright 2018 The TensorFlow Authors."
]
},
{
+ "cell_type": "code",
+ "execution_count": 0,
"metadata": {
- "id": "hro2InpHobKk",
+ "cellView": "form",
+ "colab": {},
"colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "form"
+ "id": "hro2InpHobKk"
},
- "cell_type": "code",
+ "outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
@@ -53,439 +32,47 @@
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "U9i2Dsh-ziXr",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "# Eager execution basics"
- ]
- },
- {
- "metadata": {
- "id": "Hndw-YcxoOJK",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n",
- "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\">\n",
- " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
- "</td><td>\n",
- "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/notebooks/eager_basics.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
- ]
- },
- {
- "metadata": {
- "id": "6sILUVbHoSgH",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "This is an introductory tutorial for using TensorFlow. It will cover:\n",
- "\n",
- "* Importing required packages\n",
- "* Creating and using Tensors\n",
- "* Using GPU acceleration\n",
- "* Datasets"
- ]
- },
- {
- "metadata": {
- "id": "z1JcS5iBXMRO",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Import TensorFlow\n",
- "\n",
- "To get started, import the `tensorflow` module and enable eager execution.\n",
- "Eager execution enables a more interactive frontend to TensorFlow, the details of which we will discuss much later."
- ]
- },
- {
- "metadata": {
- "id": "RlIWhyeLoYnG",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "import tensorflow as tf\n",
- "\n",
- "tf.enable_eager_execution()"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "H9UySOPLXdaw",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## Tensors\n",
- "\n",
- "A Tensor is a multi-dimensional array. Similar to NumPy `ndarray` objects, `Tensor` objects have a data type and a shape. Additionally, Tensors can reside in accelerator (like GPU) memory. TensorFlow offers a rich library of operations ([tf.add](https://www.tensorflow.org/api_docs/python/tf/add), [tf.matmul](https://www.tensorflow.org/api_docs/python/tf/matmul), [tf.linalg.inv](https://www.tensorflow.org/api_docs/python/tf/linalg/inv) etc.) that consume and produce Tensors. These operations automatically convert native Python types. For example:\n"
- ]
- },
- {
- "metadata": {
- "id": "ngUe237Wt48W",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "print(tf.add(1, 2))\n",
- "print(tf.add([1, 2], [3, 4]))\n",
- "print(tf.square(5))\n",
- "print(tf.reduce_sum([1, 2, 3]))\n",
- "print(tf.encode_base64(\"hello world\"))\n",
- "\n",
- "# Operator overloading is also supported\n",
- "print(tf.square(2) + tf.square(3))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "IDY4WsYRhP81",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "Each Tensor has a shape and a datatype"
- ]
- },
- {
- "metadata": {
- "id": "srYWH1MdJNG7",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "x = tf.matmul([[1]], [[2, 3]])\n",
- "print(x.shape)\n",
- "print(x.dtype)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "eBPw8e8vrsom",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "The most obvious differences between NumPy arrays and TensorFlow Tensors are:\n",
- "\n",
- "1. Tensors can be backed by accelerator memory (like GPU, TPU).\n",
- "2. Tensors are immutable."
- ]
- },
- {
- "metadata": {
- "id": "Dwi1tdW3JBw6",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### NumPy Compatibility\n",
- "\n",
- "Conversion between TensorFlow Tensors and NumPy ndarrays is quite simple as:\n",
- "* TensorFlow operations automatically convert NumPy ndarrays to Tensors.\n",
- "* NumPy operations automatically convert Tensors to NumPy ndarrays.\n",
- "\n",
- "Tensors can be explicitly converted to NumPy ndarrays by invoking the `.numpy()` method on them.\n",
- "These conversions are typically cheap as the array and Tensor share the underlying memory representation if possible. However, sharing the underlying representation isn't always possible since the Tensor may be hosted in GPU memory while NumPy arrays are always backed by host memory, and the conversion will thus involve a copy from GPU to host memory."
- ]
- },
- {
- "metadata": {
- "id": "lCUWzso6mbqR",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "import numpy as np\n",
- "\n",
- "ndarray = np.ones([3, 3])\n",
- "\n",
- "print(\"TensorFlow operations convert numpy arrays to Tensors automatically\")\n",
- "tensor = tf.multiply(ndarray, 42)\n",
- "print(tensor)\n",
- "\n",
- "\n",
- "print(\"And NumPy operations convert Tensors to numpy arrays automatically\")\n",
- "print(np.add(tensor, 1))\n",
- "\n",
- "print(\"The .numpy() method explicitly converts a Tensor to a numpy array\")\n",
- "print(tensor.numpy())"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "PBNP8yTRfu_X",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "## GPU acceleration\n",
- "\n",
- "Many TensorFlow operations can be accelerated by using the GPU for computation. Without any annotations, TensorFlow automatically decides whether to use the GPU or CPU for an operation (and copies the tensor between CPU and GPU memory if necessary). Tensors produced by an operation are typically backed by the memory of the device on which the operation executed. For example:"
- ]
- },
- {
- "metadata": {
- "id": "3Twf_Rw-gQFM",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- },
- "cellView": "code"
- },
- "cell_type": "code",
- "source": [
- "x = tf.random_uniform([3, 3])\n",
- "\n",
- "print(\"Is there a GPU available: \"),\n",
- "print(tf.test.is_gpu_available())\n",
- "\n",
- "print(\"Is the Tensor on GPU #0: \"),\n",
- "print(x.device.endswith('GPU:0'))"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "vpgYzgVXW2Ud",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Device Names\n",
- "\n",
- "The `Tensor.device` property provides a fully qualified string name of the device hosting the contents of the Tensor. This name encodes a bunch of details, such as an identifier of the network address of the host on which this program is executing and the device within that host. This is required for distributed execution of TensorFlow programs, but we'll skip that for now. The string will end with `GPU:<N>` if the tensor is placed on the `N`-th tensor on the host."
]
},
{
- "metadata": {
- "id": "ZWZQCimzuqyP",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "\n",
- "\n",
- "### Explicit Device Placement\n",
- "\n",
- "The term \"placement\" in TensorFlow refers to how individual operations are assigned (placed on) a device for execution. As mentioned above, when there is no explicit guidance provided, TensorFlow automatically decides which device to execute an operation, and copies Tensors to that device if needed. However, TensorFlow operations can be explicitly placed on specific devices using the `tf.device` context manager. For example:"
- ]
- },
- {
- "metadata": {
- "id": "RjkNZTuauy-Q",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "def time_matmul(x):\n",
- " %timeit tf.matmul(x, x)\n",
- "\n",
- "# Force execution on CPU\n",
- "print(\"On CPU:\")\n",
- "with tf.device(\"CPU:0\"):\n",
- " x = tf.random_uniform([1000, 1000])\n",
- " assert x.device.endswith(\"CPU:0\")\n",
- " time_matmul(x)\n",
- "\n",
- "# Force execution on GPU #0 if available\n",
- "if tf.test.is_gpu_available():\n",
- " with tf.device(\"GPU:0\"): # Or GPU:1 for the 2nd GPU, GPU:2 for the 3rd etc.\n",
- " x = tf.random_uniform([1000, 1000])\n",
- " assert x.device.endswith(\"GPU:0\")\n",
- " time_matmul(x)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
"metadata": {
- "id": "o1K4dlhhHtQj",
- "colab_type": "text"
+ "colab_type": "text",
+ "id": "U9i2Dsh-ziXr"
},
- "cell_type": "markdown",
"source": [
- "## Datasets\n",
- "\n",
- "This section demonstrates the use of the [`tf.data.Dataset` API](https://www.tensorflow.org/guide/datasets) to build pipelines to feed data to your model. It covers:\n",
- "\n",
- "* Creating a `Dataset`.\n",
- "* Iteration over a `Dataset` with eager execution enabled.\n",
- "\n",
- "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n",
- "\n",
- "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly simpler.\n",
- "You can use Python iteration over the `tf.data.Dataset` object and do not need to explicitly create an `tf.data.Iterator` object.\n",
- "As a result, the discussion on iterators in the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets) is not relevant when eager execution is enabled."
- ]
- },
- {
- "metadata": {
- "id": "zI0fmOynH-Ne",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Create a source `Dataset`\n",
- "\n",
- "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). See the [TensorFlow Guide](https://www.tensorflow.org/guide/datasets#reading_input_data) for more information."
+ "# Eager execution basics"
]
},
{
- "metadata": {
- "id": "F04fVOHQIBiG",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "ds_tensors = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n",
- "\n",
- "# Create a CSV file\n",
- "import tempfile\n",
- "_, filename = tempfile.mkstemp()\n",
- "\n",
- "with open(filename, 'w') as f:\n",
- " f.write(\"\"\"Line 1\n",
- "Line 2\n",
- "Line 3\n",
- " \"\"\")\n",
- "\n",
- "ds_file = tf.data.TextLineDataset(filename)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "vbxIhC-5IPdf",
- "colab_type": "text"
- },
"cell_type": "markdown",
- "source": [
- "### Apply transformations\n",
- "\n",
- "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for details."
- ]
- },
- {
"metadata": {
- "id": "uXSDZWE-ISsd",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
+ "colab_type": "text",
+ "id": "Hndw-YcxoOJK"
},
- "cell_type": "code",
"source": [
- "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n",
- "\n",
- "ds_file = ds_file.batch(2)"
- ],
- "execution_count": 0,
- "outputs": []
- },
- {
- "metadata": {
- "id": "A8X1GNfoIZKJ",
- "colab_type": "text"
- },
- "cell_type": "markdown",
- "source": [
- "### Iterate\n",
- "\n",
- "When eager execution is enabled `Dataset` objects support iteration.\n",
- "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that there is no need for calls to `Dataset.make_one_shot_iterator()` or `get_next()` calls."
+ "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb\"\u003e\n",
+ " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
+ "\u003c/td\u003e\u003ctd\u003e\n",
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/eager_basics.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "eager_basics.ipynb",
+ "private_outputs": true,
+ "provenance": [],
+ "toc_visible": true,
+ "version": "0.3.2"
},
- {
- "metadata": {
- "id": "ws-WKRk5Ic6-",
- "colab_type": "code",
- "colab": {
- "autoexec": {
- "startup": false,
- "wait_interval": 0
- }
- }
- },
- "cell_type": "code",
- "source": [
- "print('Elements of ds_tensors:')\n",
- "for x in ds_tensors:\n",
- " print(x)\n",
- "\n",
- "print('\\nElements in ds_file:')\n",
- "for x in ds_file:\n",
- " print(x)"
- ],
- "execution_count": 0,
- "outputs": []
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
}
- ]
-} \ No newline at end of file
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
index ee25d25b52..d60ee18586 100644
--- a/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
+++ b/tensorflow/contrib/eager/python/examples/pix2pix/pix2pix_eager.ipynb
@@ -147,11 +147,12 @@
" # random jittering\n",
" \n",
" # resizing to 286 x 286 x 3\n",
- " # method = 2 indicates using \"ResizeMethod.NEAREST_NEIGHBOR\"\n",
" input_image = tf.image.resize_images(input_image, [286, 286], \n",
- " align_corners=True, method=2)\n",
+ " align_corners=True, \n",
+ " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" real_image = tf.image.resize_images(real_image, [286, 286], \n",
- " align_corners=True, method=2)\n",
+ " align_corners=True, \n",
+ " method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)\n",
" \n",
" # randomly cropping to 256 x 256 x 3\n",
" stacked_image = tf.stack([input_image, real_image], axis=0)\n",
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
index a28bc8a43d..9d090e8429 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py
@@ -195,12 +195,12 @@ class ResNet50(tf.keras.Model):
def __init__(self,
data_format,
- name=None,
+ name='',
trainable=True,
include_top=True,
pooling=None,
classes=1000):
- super(ResNet50, self).__init__(name='')
+ super(ResNet50, self).__init__(name=name)
valid_channel_values = ('channels_first', 'channels_last')
if data_format not in valid_channel_values:
@@ -272,8 +272,8 @@ class ResNet50(tf.keras.Model):
else:
self.global_pooling = None
- def call(self, input_tensor, training):
- x = self.conv1(input_tensor)
+ def call(self, inputs, training=True):
+ x = self.conv1(inputs)
x = self.bn_conv1(x, training=training)
x = tf.nn.relu(x)
x = self.max_pool(x)
diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
index 5ee2176154..74ebb1ec77 100644
--- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
+++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py
@@ -243,8 +243,8 @@ def train_one_epoch(model, optimizer, train_data, log_interval=10):
print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss()))
-SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv"
-SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv"
+SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv"
+SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv"
def main(_):
diff --git a/tensorflow/contrib/eager/python/examples/scan/BUILD b/tensorflow/contrib/eager/python/examples/scan/BUILD
deleted file mode 100644
index 638c57d1c9..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/BUILD
+++ /dev/null
@@ -1,25 +0,0 @@
-licenses(["notice"]) # Apache 2.0
-
-package(default_visibility = ["//tensorflow:internal"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-cuda_py_test(
- name = "scan_test",
- size = "small",
- srcs = ["scan_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-cuda_py_test(
- name = "scan_graph_test",
- size = "small",
- srcs = ["scan_graph_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow:tensorflow_py",
- ],
-)
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
deleted file mode 100644
index d4b8c8941e..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_graph_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under graph mode execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- sum_op = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- with tf.Session() as sess:
- sess.run(sum_op)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/examples/scan/scan_test.py b/tensorflow/contrib/eager/python/examples/scan/scan_test.py
deleted file mode 100644
index a02fc24c79..0000000000
--- a/tensorflow/contrib/eager/python/examples/scan/scan_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Unit test for tf.scan under eager execution."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-import tensorflow as tf
-
-
-class ScanBenchmark(tf.test.Benchmark):
-
- def runScan(self, n):
- elems = np.arange(n)
- start_time = time.time()
- _ = tf.scan(lambda a, x: a + x, elems, parallel_iterations=1)
- wall_time = time.time() - start_time
-
- self.report_benchmark(
- name='scan',
- iters=n,
- wall_time=wall_time)
-
- def benchmarkScan16000(self):
- self.runScan(16000)
-
- def benchmarkScan32000(self):
- self.runScan(32000)
-
- def benchmarkScan64000(self):
- self.runScan(64000)
-
- def benchmarkScan128000(self):
- self.runScan(128000)
-
-
-if __name__ == '__main__':
- tf.enable_eager_execution()
- tf.test.main()
diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py
index 6efafccd6b..930e62b680 100644
--- a/tensorflow/contrib/eager/python/metrics_impl.py
+++ b/tensorflow/contrib/eager/python/metrics_impl.py
@@ -336,9 +336,27 @@ class Mean(Metric):
return values
return values, weights
- def result(self):
+ def result(self, write_summary=True):
+ """Returns the result of the Metric.
+
+ Args:
+ write_summary: bool indicating whether to feed the result to the summary
+ before returning.
+ Returns:
+ aggregated metric as float.
+ Raises:
+ ValueError: if the optional argument is not bool
+ """
+ # Convert the boolean to tensor for tf.cond, if it is not.
+ if not isinstance(write_summary, ops.Tensor):
+ write_summary = ops.convert_to_tensor(write_summary)
t = self.numer / self.denom
- summary_ops.scalar(name=self.name, tensor=t)
+ def write_summary_f():
+ summary_ops.scalar(name=self.name, tensor=t)
+ return t
+ control_flow_ops.cond(write_summary,
+ write_summary_f,
+ lambda: t)
return t
diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py
index 20d938d492..9d2d172752 100644
--- a/tensorflow/contrib/eager/python/metrics_test.py
+++ b/tensorflow/contrib/eager/python/metrics_test.py
@@ -25,11 +25,14 @@ from tensorflow.contrib.eager.python import metrics
from tensorflow.contrib.summary import summary_test_util
from tensorflow.python.eager import context
from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -46,6 +49,18 @@ class MetricsTest(test.TestCase):
self.assertEqual(dtypes.float64, m.dtype)
self.assertEqual(dtypes.float64, m.result().dtype)
+ def testSummaryArg(self):
+ m = metrics.Mean()
+ m([1, 10, 100])
+ m(1000)
+ m([10000.0, 100000.0])
+ self.assertEqual(111111.0/6, m.result(write_summary=True).numpy())
+ self.assertEqual(111111.0/6, m.result(write_summary=False).numpy())
+ with self.assertRaises(ValueError):
+ m.result(write_summary=5)
+ with self.assertRaises(ValueError):
+ m.result(write_summary=[True])
+
def testVariableCollections(self):
with context.graph_mode(), ops.Graph().as_default():
m = metrics.Mean()
@@ -93,6 +108,16 @@ class MetricsTest(test.TestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
+ # Get result without saving the summary.
+ logdir = tempfile.mkdtemp()
+ with summary_ops.create_file_writer(
+ logdir, max_queue=0,
+ name="t0").as_default(), summary_ops.always_record_summaries():
+ m.result(write_summary=False) # As a side-effect will write summaries.
+ # events_from_logdir(_) asserts the directory exists.
+ events = summary_test_util.events_from_logdir(logdir)
+ self.assertEqual(len(events), 1)
+
def testWeightedMean(self):
m = metrics.Mean()
m([1, 100, 100000], weights=[1, 0.2, 0.3])
@@ -191,7 +216,7 @@ class MetricsTest(test.TestCase):
self.assertEqual(m1.numer.name, "has_space/numer:0")
def testGraphWithPlaceholder(self):
- with context.graph_mode(), self.test_session() as sess:
+ with context.graph_mode(), self.cached_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
@@ -222,6 +247,48 @@ class MetricsTest(test.TestCase):
value = m.value()
self.assertEqual(self.evaluate(value), 2.5)
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorGlobalVariables(self):
+ m = metrics.Mean(use_global_variables=True)
+ inputs = ops.convert_to_tensor([1.0, 2.0])
+ accumulate = m(inputs)
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.5)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGraphAndEagerTensorWhileLoopDoubleCall(self):
+ m = metrics.Mean()
+ init_value = constant_op.constant(1)
+ cond = lambda i: math_ops.less(i, 3)
+ def body(x):
+ with ops.control_dependencies([m(x)]):
+ return math_ops.add(x, 1)
+ accumulate = control_flow_ops.while_loop(cond, body, [init_value])
+
+ result = m.result()
+ self.evaluate(m.init_variables())
+ self.evaluate(accumulate)
+ self.assertEqual(self.evaluate(result), 1.5)
+ # Second init resets all the variables.
+ self.evaluate(m.init_variables())
+ inputs = ops.convert_to_tensor([2.0, 3.0])
+ self.evaluate(m(inputs))
+ if ops.context.executing_eagerly():
+ self.evaluate(control_flow_ops.while_loop(cond, body, [init_value]))
+ else:
+ # Reuse the loop operators in graph mode
+ self.evaluate(accumulate)
+ value = m.value()
+ self.assertEqual(self.evaluate(value), 2.0)
+
def testTwoMeansGraph(self):
# Verify two metrics with the same name in the same graph raises a
# ValueError.
@@ -242,7 +309,7 @@ class MetricsTest(test.TestCase):
self.assertTrue(old_numer is m.numer)
def testMetricsChain(self):
- with context.graph_mode(), self.test_session():
+ with context.graph_mode(), self.cached_session():
m1 = metrics.Mean()
m2 = metrics.Mean(name="m2")
update_m2 = m2(3.0)
diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py
new file mode 100644
index 0000000000..b74cf394f6
--- /dev/null
+++ b/tensorflow/contrib/eager/python/remote.py
@@ -0,0 +1,73 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helpers to connect to remote servers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
+from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
+from tensorflow.python.eager import context
+
+
+def connect_to_remote_host(remote_host=None, job_name="worker"):
+ """Connects to a single machine to enable remote execution on it.
+
+ Will make devices on the remote host available to use. Note that calling this
+ more than once will work, but will invalidate any tensor handles on the old
+ remote devices.
+
+ Using the default job_name of worker, you can schedule ops to run remotely as
+ follows:
+ ```python
+ # Enable eager execution, and connect to the remote host.
+ tf.enable_eager_execution()
+ tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ # The following tensors should be resident on the remote device, and the op
+ # will also execute remotely.
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ ```
+
+ Args:
+ remote_host: The addr of the remote server in host-port format.
+ job_name: The job name under which the new server will be accessible.
+
+ Raises:
+ ValueError: if remote_host is None.
+ """
+ if remote_host is None:
+ raise ValueError("Must provide an remote_host")
+ cluster_def = ClusterDef()
+ job_def = cluster_def.job.add()
+ job_def.name = job_name
+ job_def.tasks[0] = "127.0.0.1:0"
+ job_def.tasks[1] = remote_host
+
+ server_def = ServerDef(
+ cluster=cluster_def,
+ job_name=job_name,
+ task_index=0,
+ protocol="grpc")
+
+ # TODO(nareshmodi): Make this default since it works in more situations.
+ os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
+ context.set_server_def(server_def)
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 76f48eeb1c..13029db975 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
import numpy as np
+from tensorflow.contrib.eager.python import remote
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import backprop
@@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase):
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
+ def setUp(self):
# Start the local server.
context.set_server_def(
server_def=get_server_def(
@@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase):
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+ @run_sync_and_async
+ def testConnectToRemoteServer(self):
+ """Basic server connection."""
+ remote.connect_to_remote_host(self._cached_server1_target)
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
if __name__ == "__main__":
ops.enable_eager_execution()
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 4dfd083443..f5b8d95e4f 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -74,6 +74,8 @@ To use, at program startup, call `tf.enable_eager_execution()`.
@@TensorSpec
+@@connect_to_remote_host
+
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@DEVICE_PLACEMENT_SILENT
@@ -94,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network
from tensorflow.contrib.eager.python.network import Sequential
from tensorflow.contrib.eager.python.network import save_network_checkpoint
from tensorflow.contrib.eager.python.network import restore_network_checkpoint
+from tensorflow.contrib.eager.python.remote import connect_to_remote_host
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 77f62df99d..437b3d965d 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -446,6 +446,7 @@ py_library(
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:optimizers",
+ "//tensorflow/python/ops/losses",
"@six_archive//:six",
],
)
diff --git a/tensorflow/contrib/estimator/python/estimator/baseline_test.py b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
index 505c94e971..513feb03b6 100644
--- a/tensorflow/contrib/estimator/python/estimator/baseline_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/baseline_test.py
@@ -37,13 +37,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_utils
-from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver
@@ -339,7 +339,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
self.assertEquals(0, loss.shape.ndims)
if expected_loss is None:
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
assert_loss = assert_close(
math_ops.to_float(expected_loss, name='expected'),
@@ -347,7 +347,7 @@ class BaselineEstimatorTrainingTest(test.TestCase):
name='assert_loss')
with ops.control_dependencies((assert_loss,)):
if global_step is not None:
- return distribute_lib.increment_var(global_step)
+ return state_ops.assign_add(global_step, 1).op
return control_flow_ops.no_op()
mock_optimizer = test.mock.NonCallableMock(
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py
index 26449b4651..e3c44bea66 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders.py
@@ -26,6 +26,7 @@ from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import sparse_ops
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.util import function_utils
@@ -140,7 +141,7 @@ def clip_gradients_by_norm(optimizer, clip_norm):
name='ClipByNorm' + optimizer.get_name())
-def forward_features(estimator, keys=None):
+def forward_features(estimator, keys=None, sparse_default_values=None):
"""Forward features to predictions dictionary.
In some cases, user wants to see some of the features in estimators prediction
@@ -148,39 +149,36 @@ def forward_features(estimator, keys=None):
runs inference on the users graph and returns the results. Keys are essential
because there is no order guarantee on the outputs so they need to be rejoined
to the inputs via keys or transclusion of the inputs in the outputs.
-
Example:
-
```python
def input_fn():
features, labels = ...
features['unique_example_id'] = ...
features, labels
-
estimator = tf.estimator.LinearClassifier(...)
estimator = tf.contrib.estimator.forward_features(
estimator, 'unique_example_id')
estimator.train(...)
assert 'unique_example_id' in estimator.predict(...)
```
-
Args:
estimator: A `tf.estimator.Estimator` object.
- keys: a `string` or a `list` of `string`. If it is `None`, all of the
+ keys: A `string` or a `list` of `string`. If it is `None`, all of the
`features` in `dict` is forwarded to the `predictions`. If it is a
`string`, only given key is forwarded. If it is a `list` of strings, all
the given `keys` are forwarded.
+ sparse_default_values: A dict of `str` keys mapping the name of the sparse
+ features to be converted to dense, to the default value to use. Only
+ sparse features indicated in the dictionary are converted to dense and the
+ provided default value is used.
Returns:
A new `tf.estimator.Estimator` which forwards features to predictions.
-
Raises:
ValueError:
* if `keys` is already part of `predictions`. We don't allow
override.
* if 'keys' does not exist in `features`.
- * if feature key refers to a `SparseTensor`, since we don't support
- `SparseTensor` in `predictions`. `SparseTensor` is common in `features`.
TypeError: if `keys` type is not one of `string` or list/tuple of `string`.
"""
@@ -231,11 +229,18 @@ def forward_features(estimator, keys=None):
for key in get_keys(features):
feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(
features[key])
+ if sparse_default_values and (key in sparse_default_values):
+ if not isinstance(feature, sparse_tensor_lib.SparseTensor):
+ raise ValueError(
+ 'Feature ({}) is expected to be a `SparseTensor`.'.format(key))
+ feature = sparse_ops.sparse_tensor_to_dense(
+ feature, default_value=sparse_default_values[key])
if not isinstance(feature, ops.Tensor):
raise ValueError(
- 'Forwarded feature ({}) should be a Tensor. Please use keys '
- 'argument of forward_features to filter unwanted features. Type of '
- 'features[{}] is {}.'.format(key, key, type(feature)))
+ 'Feature ({}) should be a Tensor. Please use `keys` '
+ 'argument of forward_features to filter unwanted features, or'
+ 'add key to argument `sparse_default_values`.'
+ 'Type of features[{}] is {}.'.format(key, key, type(feature)))
predictions[key] = feature
spec = spec._replace(predictions=predictions)
if spec.export_outputs:
diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
index 407af2deaf..c8fdaa8791 100644
--- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""extenders tests."""
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,6 +24,7 @@ import tempfile
import numpy as np
from tensorflow.contrib.estimator.python.estimator import extenders
+from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.predictor import from_saved_model
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator_lib
@@ -170,19 +172,53 @@ class ClipGradientsByNormTest(test.TestCase):
class ForwardFeaturesTest(test.TestCase):
"""Tests forward_features."""
- def test_forward_single_key(self):
-
- def input_fn():
- return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
+ def _export_estimator(self, estimator, serving_input_fn):
+ tmpdir = tempfile.mkdtemp()
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
+ self.assertTrue(gfile.Exists(export_dir))
+ return export_dir, tmpdir
+ def make_dummy_input_fn(self):
+ def _input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': [[3.], [5.]],
+ 'id': [[101], [102]],
+ 'sparse_id': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[1.], [2.]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+ return _input_fn
+
+ def test_forward_keys(self):
+
+ input_fn = self.make_dummy_input_fn()
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
- self.assertNotIn('id', next(estimator.predict(input_fn=input_fn)))
- estimator = extenders.forward_features(estimator, 'id')
- predictions = next(estimator.predict(input_fn=input_fn))
- self.assertIn('id', predictions)
- self.assertEqual(101, predictions['id'])
+ forwarded_keys = ['id', 'sparse_id']
+
+ for key in forwarded_keys:
+ self.assertNotIn(key, next(estimator.predict(input_fn=input_fn)))
+
+ estimator = extenders.forward_features(
+ estimator, forwarded_keys, sparse_default_values={'sparse_id': 1})
+
+ expected_results = [101, 2, 102, 5]
+ predictions = estimator.predict(input_fn=input_fn)
+ for _ in range(2):
+ prediction = next(predictions)
+ for key in forwarded_keys:
+ self.assertIn(key, prediction)
+ self.assertEqual(expected_results.pop(0), sum(prediction[key]))
def test_forward_in_exported(self):
@@ -205,11 +241,7 @@ class ForwardFeaturesTest(test.TestCase):
estimator = extenders.forward_features(estimator, 'id')
# export saved model
- tmpdir = tempfile.mkdtemp()
- export_dir_base = os.path.join(
- compat.as_bytes(tmpdir), compat.as_bytes('export'))
- export_dir = estimator.export_savedmodel(export_dir_base, serving_input_fn)
- self.assertTrue(gfile.Exists(export_dir))
+ export_dir, tmpdir = self._export_estimator(estimator, serving_input_fn)
# restore model
predict_fn = from_saved_model(export_dir, signature_def_key='predict')
@@ -222,6 +254,47 @@ class ForwardFeaturesTest(test.TestCase):
# Clean up.
gfile.DeleteRecursively(tmpdir)
+ def test_forward_in_exported_sparse(self):
+ features_columns = [fc.indicator_column(
+ fc.categorical_column_with_vocabulary_list('x', range(10)))]
+
+ classifier = linear.LinearClassifier(feature_columns=features_columns)
+
+ def train_input_fn():
+ dataset = dataset_ops.Dataset.from_tensors({
+ 'x': sparse_tensor.SparseTensor(
+ values=[1, 2, 3],
+ indices=[[0, 0], [1, 0], [1, 1]],
+ dense_shape=[2, 2]),
+ 'labels': [[0], [1]]
+ })
+ def _split(x):
+ labels = x.pop('labels')
+ return x, labels
+ dataset = dataset.map(_split)
+ return dataset
+
+ classifier.train(train_input_fn, max_steps=1)
+
+ classifier = extenders.forward_features(
+ classifier, keys=['x'], sparse_default_values={'x': 0})
+
+ def serving_input_fn():
+ features_ph = array_ops.placeholder(dtype=dtypes.int32, name='x',
+ shape=[None])
+ features = {'x': layers.dense_to_sparse(features_ph)}
+ return estimator_lib.export.ServingInputReceiver(features,
+ {'x': features_ph})
+ export_dir, tmpdir = self._export_estimator(classifier, serving_input_fn)
+ prediction_fn = from_saved_model(export_dir, signature_def_key='predict')
+
+ features = (0, 2)
+ prediction = prediction_fn({'x': features})
+
+ self.assertIn('x', prediction)
+ self.assertEqual(features, tuple(prediction['x']))
+ gfile.DeleteRecursively(tmpdir)
+
def test_forward_list(self):
def input_fn():
@@ -266,7 +339,6 @@ class ForwardFeaturesTest(test.TestCase):
extenders.forward_features(estimator, ['x', estimator])
def test_key_should_be_in_features(self):
-
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]]
@@ -279,27 +351,36 @@ class ForwardFeaturesTest(test.TestCase):
next(estimator.predict(input_fn=input_fn))
def test_forwarded_feature_should_not_be_a_sparse_tensor(self):
-
def input_fn():
return {
'x': [[3.], [5.]],
- 'id':
- sparse_tensor.SparseTensor(
- values=['1', '2'],
- indices=[[0, 0], [1, 0]],
- dense_shape=[2, 1])
- }, [[1.], [2.]]
+ 'id': sparse_tensor.SparseTensor(
+ values=['1', '2'],
+ indices=[[0, 0], [1, 0]],
+ dense_shape=[2, 1])
+ }, [[1.], [2.]]
estimator = linear.LinearRegressor([fc.numeric_column('x')])
estimator.train(input_fn=input_fn, steps=1)
estimator = extenders.forward_features(estimator)
with self.assertRaisesRegexp(ValueError,
- 'Forwarded feature.* should be a Tensor.'):
+ 'Feature .* should be a Tensor.*'):
next(estimator.predict(input_fn=input_fn))
- def test_predictions_should_be_dict(self):
+ def test_forwarded_feature_should_be_a_sparse_tensor(self):
+ input_fn = self.make_dummy_input_fn()
+
+ estimator = linear.LinearRegressor([fc.numeric_column('x')])
+ estimator.train(input_fn=input_fn, steps=1)
+ estimator = extenders.forward_features(
+ estimator, sparse_default_values={'id': 0, 'sparse_id': 0})
+ with self.assertRaisesRegexp(
+ ValueError, 'Feature .* is expected to be a `SparseTensor`.'):
+ next(estimator.predict(input_fn=input_fn))
+
+ def test_predictions_should_be_dict(self):
def input_fn():
return {'x': [[3.], [5.]], 'id': [[101], [102]]}
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 7c49cd00d1..98660bb731 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import optimizer as optimizer_lib
from tensorflow.python.training import training_util
@@ -405,6 +406,7 @@ class RNNClassifier(estimator.Estimator):
weight_column=None,
label_vocabulary=None,
optimizer='Adagrad',
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE,
input_layer_partitioner=None,
config=None):
"""Initializes a `RNNClassifier` instance.
@@ -454,6 +456,8 @@ class RNNClassifier(estimator.Estimator):
string.
optimizer: An instance of `tf.Optimizer` or string specifying optimizer
type. Defaults to Adagrad optimizer.
+ loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
+ to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.
input_layer_partitioner: Optional. Partitioner for input layer. Defaults
to `min_max_variable_partitioner` with `min_slice_size` 64 << 20.
config: `RunConfig` object to configure the runtime settings.
@@ -467,11 +471,15 @@ class RNNClassifier(estimator.Estimator):
if n_classes == 2:
head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint: disable=protected-access
weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
else:
head = head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint: disable=protected-access
- n_classes, weight_column=weight_column,
- label_vocabulary=label_vocabulary)
+ n_classes,
+ weight_column=weight_column,
+ label_vocabulary=label_vocabulary,
+ loss_reduction=loss_reduction)
+
def _model_fn(features, labels, mode, config):
return _rnn_model_fn(
features=features,
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 959b40371a..1aebed348d 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -713,7 +713,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=1.119661)
+ mock_optimizer = self._mock_optimizer(expected_loss=0.559831)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -748,7 +748,7 @@ class RNNClassifierTrainingTest(test.TestCase):
# Uses same checkpoint and examples as testMultiClassEvaluationMetrics.
# See that test for loss calculation.
- mock_optimizer = self._mock_optimizer(expected_loss=2.662932)
+ mock_optimizer = self._mock_optimizer(expected_loss=1.331465)
sequence_feature_columns = [
seq_fc.sequence_numeric_column('price', shape=(1,))]
@@ -812,20 +812,32 @@ class RNNClassifierEvaluationTest(test.TestCase):
# probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]
# loss = -label * ln(p) - (1 - label) * ln(1 - p)
# = [[0.436326], [0.683335]]
+ # sum_over_batch_size = (0.436326 + 0.683335)/2
expected_metrics = {
- ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 1.119661,
- metric_keys.MetricKeys.LOSS_MEAN: 0.559831,
- metric_keys.MetricKeys.ACCURACY: 1.0,
- metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,
- metric_keys.MetricKeys.LABEL_MEAN: 0.5,
- metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,
+ ops.GraphKeys.GLOBAL_STEP:
+ global_step,
+ metric_keys.MetricKeys.LOSS:
+ 0.559831,
+ metric_keys.MetricKeys.LOSS_MEAN:
+ 0.559831,
+ metric_keys.MetricKeys.ACCURACY:
+ 1.0,
+ metric_keys.MetricKeys.PREDICTION_MEAN:
+ 0.429262,
+ metric_keys.MetricKeys.LABEL_MEAN:
+ 0.5,
+ metric_keys.MetricKeys.ACCURACY_BASELINE:
+ 0.5,
# With default threshold of 0.5, the model is a perfect classifier.
- metric_keys.MetricKeys.RECALL: 1.0,
- metric_keys.MetricKeys.PRECISION: 1.0,
+ metric_keys.MetricKeys.RECALL:
+ 1.0,
+ metric_keys.MetricKeys.PRECISION:
+ 1.0,
# Positive example is scored above negative, so AUC = 1.0.
- metric_keys.MetricKeys.AUC: 1.0,
- metric_keys.MetricKeys.AUC_PR: 1.0,
+ metric_keys.MetricKeys.AUC:
+ 1.0,
+ metric_keys.MetricKeys.AUC_PR:
+ 1.0,
}
self.assertAllClose(
sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))
@@ -871,9 +883,10 @@ class RNNClassifierEvaluationTest(test.TestCase):
# [0.059494, 0.572639, 0.367866]]
# loss = -1. * log(softmax[label])
# = [[2.105432], [0.557500]]
+ # sum_over_batch_size = (2.105432 + 0.557500)/2
expected_metrics = {
ops.GraphKeys.GLOBAL_STEP: global_step,
- metric_keys.MetricKeys.LOSS: 2.662932,
+ metric_keys.MetricKeys.LOSS: 1.331465,
metric_keys.MetricKeys.LOSS_MEAN: 1.331466,
metric_keys.MetricKeys.ACCURACY: 0.5,
}
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index bb5140aeb3..6aa62fb82e 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase):
observed *= num_rows / 3. if test_rows else num_cols / 2.
want_weight_sum = unobserved + observed
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
wals_model = factorization_ops.WALSModel(
input_rows=num_rows,
input_cols=num_cols,
@@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input_transposed(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase):
# trigger the more efficient ALS updates.
# Here we test that those two give identical results.
def _run_test_als(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase):
atol=1e-2)
def _run_test_als_transposed(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase):
rows = 15
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase):
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase):
def keep_index(x):
return not (x[0] + x[1]) % 4
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
row_wts = 0.1 + np.random.rand(rows)
col_wts = 0.1 + np.random.rand(cols)
data = np.dot(np.random.rand(rows, 3), np.random.rand(
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
index 888c3c238c..112e4d289b 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py
@@ -99,7 +99,7 @@ class GmmOpsTest(test.TestCase):
logging.info('Numpy took %f', time.time() - start_time)
start_time = time.time()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
op = gmm_ops._covariance(
constant_op.constant(
data.T, dtype=dtypes.float32), False)
@@ -120,7 +120,7 @@ class GmmOpsTest(test.TestCase):
graph = ops.Graph()
with graph.as_default() as g:
g.seed = 5
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data = constant_op.constant(self.data, dtype=dtypes.float32)
loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm(
data, 'random', num_classes, random_seed=self.seed)
@@ -144,7 +144,7 @@ class GmmOpsTest(test.TestCase):
def testParams(self):
"""Tests that the params work as intended."""
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Experiment 1. Update weights only.
data = constant_op.constant(self.data, dtype=dtypes.float32)
gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes,
diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
index 88eb9cf692..1ab5418fe4 100644
--- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py
+++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py
@@ -232,7 +232,7 @@ class KMeansTest(KMeansTestBase):
self.assertEqual(features.shape, parsed_feature_dict.shape)
self.assertEqual(features.dtype, parsed_feature_dict.dtype)
# Then check that running the tensor yields the original list of points.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
parsed_points = sess.run(parsed_feature_dict)
self.assertAllEqual(self.points, parsed_points)
diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py
index ca46c39baa..b82bf1188f 100644
--- a/tensorflow/contrib/factorization/python/ops/wals.py
+++ b/tensorflow/contrib/factorization/python/ops/wals.py
@@ -377,64 +377,68 @@ class WALSMatrixFactorization(estimator.Estimator):
WALS (Weighted Alternating Least Squares) is an algorithm for weighted matrix
factorization. It computes a low-rank approximation of a given sparse (n x m)
- matrix A, by a product of two matrices, U * V^T, where U is a (n x k) matrix
- and V is a (m x k) matrix. Here k is the rank of the approximation, also
- called the embedding dimension. We refer to U as the row factors, and V as the
- column factors.
+ matrix `A`, by a product of two matrices, `U * V^T`, where `U` is a (n x k)
+ matrix and `V` is a (m x k) matrix. Here k is the rank of the approximation,
+ also called the embedding dimension. We refer to `U` as the row factors, and
+ `V` as the column factors.
See tensorflow/contrib/factorization/g3doc/wals.md for the precise problem
formulation.
- The training proceeds in sweeps: during a row_sweep, we fix V and solve for U.
- During a column sweep, we fix U and solve for V. Each one of these problems is
- an unconstrained quadratic minimization problem and can be solved exactly (it
- can also be solved in mini-batches, since the solution decouples nicely).
+ The training proceeds in sweeps: during a row_sweep, we fix `V` and solve for
+ `U`. During a column sweep, we fix `U` and solve for `V`. Each one of these
+ problems is an unconstrained quadratic minimization problem and can be solved
+ exactly (it can also be solved in mini-batches, since the solution decouples
+ across rows of each matrix).
The alternating between sweeps is achieved by using a hook during training,
which is responsible for keeping track of the sweeps and running preparation
ops at the beginning of each sweep. It also updates the global_step variable,
which keeps track of the number of batches processed since the beginning of
training.
The current implementation assumes that the training is run on a single
- machine, and will fail if config.num_worker_replicas is not equal to one.
- Training is done by calling self.fit(input_fn=input_fn), where input_fn
+ machine, and will fail if `config.num_worker_replicas` is not equal to one.
+ Training is done by calling `self.fit(input_fn=input_fn)`, where `input_fn`
provides two tensors: one for rows of the input matrix, and one for rows of
the transposed input matrix (i.e. columns of the original matrix). Note that
during a row sweep, only row batches are processed (ignoring column batches)
and vice-versa.
Also note that every row (respectively every column) of the input matrix
must be processed at least once for the sweep to be considered complete. In
- particular, training will not make progress if input_fn does not generate some
- rows.
-
- For prediction, given a new set of input rows A' (e.g. new rows of the A
- matrix), we compute a corresponding set of row factors U', such that U' * V^T
- is a good approximation of A'. We call this operation a row projection. A
- similar operation is defined for columns.
- Projection is done by calling self.get_projections(input_fn=input_fn), where
- input_fn satisfies the constraints given below.
-
- The input functions must satisfy the following constraints: Calling input_fn
- must return a tuple (features, labels) where labels is None, and features is
- a dict containing the following keys:
+ particular, training will not make progress if some rows are not generated by
+ the `input_fn`.
+
+ For prediction, given a new set of input rows `A'`, we compute a corresponding
+ set of row factors `U'`, such that `U' * V^T` is a good approximation of `A'`.
+ We call this operation a row projection. A similar operation is defined for
+ columns. Projection is done by calling
+ `self.get_projections(input_fn=input_fn)`, where `input_fn` satisfies the
+ constraints given below.
+
+ The input functions must satisfy the following constraints: Calling `input_fn`
+ must return a tuple `(features, labels)` where `labels` is None, and
+ `features` is a dict containing the following keys:
+
TRAIN:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows of the input matrix to process (or to project).
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns of the input matrix to process (or to project), transposed.
+
INFER:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
- - WALSMatrixFactorization.PROJECTION_WEIGHTS (Optional): float32 Tensor
+ * `WALSMatrixFactorization.PROJECTION_WEIGHTS` (Optional): float32 Tensor
(vector). The weights to use in the projection.
+
EVAL:
- - WALSMatrixFactorization.INPUT_ROWS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_ROWS`: float32 SparseTensor (matrix).
Rows to project.
- - WALSMatrixFactorization.INPUT_COLS: float32 SparseTensor (matrix).
+ * `WALSMatrixFactorization.INPUT_COLS`: float32 SparseTensor (matrix).
Columns to project.
- - WALSMatrixFactorization.PROJECT_ROW: Boolean Tensor. Whether to project
+ * `WALSMatrixFactorization.PROJECT_ROW`: Boolean Tensor. Whether to project
the rows or columns.
"""
# Keys to be used in model_fn
@@ -469,7 +473,7 @@ class WALSMatrixFactorization(estimator.Estimator):
max_sweeps=None,
model_dir=None,
config=None):
- """Creates a model for matrix factorization using the WALS method.
+ r"""Creates a model for matrix factorization using the WALS method.
Args:
num_rows: Total number of rows for input matrix.
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 36b483c6d7..9bdbd05015 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -125,11 +125,13 @@ class WALSMatrixFactorizationTest(test.TestCase):
nz_row_ids = np.arange(np.shape(np_matrix)[0])
nz_col_ids = np.arange(np.shape(np_matrix)[1])
- def extract_features(row_batch, col_batch, shape):
+ def extract_features(row_batch, col_batch, num_rows, num_cols):
row_ids = row_batch[0]
col_ids = col_batch[0]
- rows = self.remap_sparse_tensor_rows(row_batch[1], row_ids, shape)
- cols = self.remap_sparse_tensor_rows(col_batch[1], col_ids, shape)
+ rows = self.remap_sparse_tensor_rows(
+ row_batch[1], row_ids, shape=[num_rows, num_cols])
+ cols = self.remap_sparse_tensor_rows(
+ col_batch[1], col_ids, shape=[num_cols, num_rows])
features = {
wals_lib.WALSMatrixFactorization.INPUT_ROWS: rows,
wals_lib.WALSMatrixFactorization.INPUT_COLS: cols,
@@ -154,7 +156,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
capacity=10,
enqueue_many=True)
- features = extract_features(row_batch, col_batch, sp_mat.dense_shape)
+ features = extract_features(row_batch, col_batch, num_rows, num_cols)
if mode == model_fn.ModeKeys.INFER or mode == model_fn.ModeKeys.EVAL:
self.assertTrue(
@@ -334,7 +336,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = self._model.evaluate(
input_fn=eval_input_fn_row, steps=self._num_rows)['loss']
- with self.test_session():
+ with self.cached_session():
true_loss = self.calculate_loss()
self.assertNear(
@@ -352,7 +354,7 @@ class WALSMatrixFactorizationTest(test.TestCase):
loss = self._model.evaluate(
input_fn=eval_input_fn_col, steps=self._num_cols)['loss']
- with self.test_session():
+ with self.cached_session():
true_loss = self.calculate_loss()
self.assertNear(
@@ -438,7 +440,7 @@ class SweepHookTest(test.TestCase):
math_ops.logical_not(is_row_sweep_var)))
mark_sweep_done = state_ops.assign(is_sweep_done_var, True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sweep_hook = wals_lib._SweepHook(
is_row_sweep_var,
is_sweep_done_var,
@@ -489,7 +491,7 @@ class StopAtSweepHookTest(test.TestCase):
train_op = state_ops.assign_add(completed_sweeps, 1)
hook.begin()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([variables.global_variables_initializer()])
mon_sess = monitored_session._HookedSession(sess, [hook])
mon_sess.run(train_op)
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index b1b5126d9e..45a67acb5b 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -24,11 +24,13 @@ from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
+from tensorflow.python.util.deprecation import deprecated
_ffmpeg_so = loader.load_op_library(
resource_loader.get_path_to_datafile('ffmpeg.so'))
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def decode_audio(contents, file_format=None, samples_per_second=None,
channel_count=None, stream=None):
"""Create an op that decodes the contents of an audio file.
@@ -69,6 +71,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None,
ops.NotDifferentiable('DecodeAudio')
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def encode_audio(audio, file_format=None, samples_per_second=None):
"""Creates an op that encodes an audio file using sampled audio from a tensor.
@@ -95,6 +98,7 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
ops.NotDifferentiable('EncodeAudio')
+@deprecated('2018-09-04', 'This will be deleted and should not be used.')
def decode_video(contents):
"""Create an op that decodes the contents of a video file.
diff --git a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
index 4f591367fd..77a424145a 100644
--- a/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
+++ b/tensorflow/contrib/framework/python/framework/checkpoint_utils_test.py
@@ -82,7 +82,7 @@ class CheckpointsTest(test.TestCase):
def testNoTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
with self.assertRaises(errors_impl.OpError):
self.assertAllEqual(
@@ -90,7 +90,7 @@ class CheckpointsTest(test.TestCase):
def testGetTensor(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
self.assertAllEqual(
checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
@@ -103,7 +103,7 @@ class CheckpointsTest(test.TestCase):
def testGetAllVariables(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_create_checkpoints(session, checkpoint_dir)
self.assertEqual(
checkpoint_utils.list_variables(checkpoint_dir),
@@ -112,7 +112,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -146,7 +146,7 @@ class CheckpointsTest(test.TestCase):
def testInitWithScopeDoesNotCaptureSuffixes(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
with ops.Graph().as_default() as g:
@@ -165,7 +165,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -189,7 +189,7 @@ class CheckpointsTest(test.TestCase):
def testInitToRootCheckpoint(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -212,7 +212,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromPartitionVar(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
v1 = _create_partition_checkpoints(session, checkpoint_dir)
# New graph and session.
@@ -266,7 +266,7 @@ class CheckpointsTest(test.TestCase):
def testInitFromCheckpointMissing(self):
checkpoint_dir = self.get_temp_dir()
- with self.test_session() as session:
+ with self.cached_session() as session:
_, _, _, _ = _create_checkpoints(session, checkpoint_dir)
# New graph and session.
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py
index 4e6eea8884..bdf8aeb2b8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
@@ -129,10 +130,25 @@ def remove_squeezable_dimensions(predictions, labels, name=None):
return predictions, labels
-def _all_equal(tensor0, tensor1):
- with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
+def _shape_tensor_compatible(expected_shape, actual_shape):
+ """Returns whether actual_shape is compatible with expected_shape.
+
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
+ Args:
+ expected_shape: Integer list defining the expected shape, or tensor of same.
+ actual_shape: Shape of the tensor to test.
+ Returns:
+ New tensor.
+ """
+ with ops.name_scope('shape_tensor_equal',
+ values=[expected_shape, actual_shape]) as scope:
return math_ops.reduce_all(
- math_ops.equal(tensor0, tensor1, name='equal'), name=scope)
+ math_ops.logical_or(
+ math_ops.equal(expected_shape, -1),
+ math_ops.equal(expected_shape, actual_shape, 'equal'),
+ name='exclude_partial_shape'),
+ name=scope)
def _is_rank(expected_rank, actual_tensor):
@@ -153,6 +169,8 @@ def _is_rank(expected_rank, actual_tensor):
def _is_shape(expected_shape, actual_tensor, actual_shape=None):
"""Returns whether actual_tensor's shape is expected_shape.
+ Note that -1 in `expected_shape` is recognized as unknown dimension.
+
Args:
expected_shape: Integer list defining the expected shape, or tensor of same.
actual_tensor: Tensor to test.
@@ -164,15 +182,15 @@ def _is_shape(expected_shape, actual_tensor, actual_shape=None):
is_rank = _is_rank(array_ops.size(expected_shape), actual_tensor)
if actual_shape is None:
actual_shape = array_ops.shape(actual_tensor, name='actual')
- shape_equal = _all_equal(
- ops.convert_to_tensor(expected_shape, name='expected'),
- actual_shape)
+ shape_equal = _shape_tensor_compatible(expected_shape, actual_shape)
return math_ops.logical_and(is_rank, shape_equal, name=scope)
def _assert_shape_op(expected_shape, actual_tensor):
"""Asserts actual_tensor's shape is expected_shape.
+ Note that unknown dimension in `expected_shape` will be ignored.
+
Args:
expected_shape: List of integers defining the expected shape, or tensor of
same.
@@ -182,6 +200,9 @@ def _assert_shape_op(expected_shape, actual_tensor):
"""
with ops.name_scope('assert_shape', values=[actual_tensor]) as scope:
actual_shape = array_ops.shape(actual_tensor, name='actual')
+ if (isinstance(expected_shape, tensor_shape.TensorShape)
+ and not expected_shape.is_fully_defined()):
+ expected_shape = [d if d else -1 for d in expected_shape.as_list()]
is_shape = _is_shape(expected_shape, actual_tensor, actual_shape)
return control_flow_ops.Assert(
is_shape, [
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index 9db2670304..b1820c10c8 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib
@@ -39,7 +39,7 @@ from tensorflow.python.platform import test
class LocalVariabletest(test.TestCase):
def test_local_variable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEquals([], variables_lib.local_variables())
value0 = 42
variables_lib2.local_variable(value0)
@@ -55,7 +55,7 @@ class LocalVariabletest(test.TestCase):
class ReduceSumNTest(test.TestCase):
def test_reduce_sum_n(self):
- with self.test_session():
+ with self.cached_session():
a = constant_op.constant(1)
b = constant_op.constant([2])
c = constant_op.constant([[3, 4], [5, 6]])
@@ -119,13 +119,13 @@ class WithShapeTest(test.TestCase):
}))
def test_with_shape_invalid_expected_shape(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(ValueError, "Invalid rank",
tensor_util.with_shape, [[1], [2]],
constant_op.constant(1.0))
def test_with_shape_invalid_type(self):
- with self.test_session():
+ with self.cached_session():
self.assertRaisesRegexp(ValueError, "Invalid dtype",
tensor_util.with_shape, [1.1],
constant_op.constant([1.0]))
@@ -138,7 +138,7 @@ class WithShapeTest(test.TestCase):
constant_op.constant(1.0))
def test_with_shape_0(self):
- with self.test_session():
+ with self.cached_session():
value = 42
shape = [0]
unexpected_shapes = [[1], [2], [1, 1]]
@@ -150,7 +150,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_1(self):
- with self.test_session():
+ with self.cached_session():
value = [42]
shape = [1]
unexpected_shapes = [[0], [2], [1, 1]]
@@ -162,7 +162,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
shape = [2]
unexpected_shapes = [[0], [1], [2, 1]]
@@ -174,7 +174,7 @@ class WithShapeTest(test.TestCase):
unexpected_shapes)
def test_with_shape_2x2(self):
- with self.test_session():
+ with self.cached_session():
value = [[42, 43], [44, 45]]
shape = [2, 2]
unexpected_shapes = [[0], [1], [2, 1]]
@@ -185,8 +185,18 @@ class WithShapeTest(test.TestCase):
shape,
unexpected_shapes)
- def test_with_shape_none(self):
+ def test_with_shape_2x2_with_partial_expected_shape(self):
with self.test_session():
+ value = [[42, 43], [44, 45]]
+ actual_shape = [2, 2]
+ tensor = constant_op.constant(value, shape=actual_shape)
+ partial_expected_shape = tensor_shape.TensorShape([None, 2])
+ # Won't raise any exception here:
+ tensor_with_shape = tensor_util.with_shape(partial_expected_shape, tensor)
+ np.testing.assert_array_equal(value, tensor_with_shape.eval())
+
+ def test_with_shape_none(self):
+ with self.cached_session():
tensor_no_shape = array_ops.placeholder(dtypes.float32)
compatible_shape = [2, 2]
@@ -210,7 +220,7 @@ class WithShapeTest(test.TestCase):
@test_util.enable_c_shapes
def test_with_shape_partial(self):
- with self.test_session():
+ with self.cached_session():
tensor_partial_shape = array_ops.placeholder(dtypes.float32)
tensor_partial_shape.set_shape([None, 2])
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 0ccb4583ab..716bb87e38 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -174,7 +174,7 @@ class FusedConv2DBiasActivationOp : public OpKernel {
// Input bias is a 1-D tensor, with size matching output depth.
const Tensor& bias = context->input(kBias);
- OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
+ OP_REQUIRES_OK(context, CheckShape(bias, "bias"));
const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 9866fccfba..9d0e6e1335 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -105,6 +105,7 @@ py_library(
deps = [
":gan_estimator",
":head",
+ ":stargan_estimator",
"//tensorflow/python:util",
],
)
@@ -534,6 +535,57 @@ py_test(
)
py_library(
+ name = "stargan_estimator",
+ srcs = [
+ "python/estimator/python/stargan_estimator.py",
+ "python/estimator/python/stargan_estimator_impl.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":namedtuples",
+ ":summaries",
+ ":train",
+ "//tensorflow/contrib/framework:framework_py",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "stargan_estimator_test",
+ srcs = ["python/estimator/python/stargan_estimator_test.py"],
+ shard_count = 1,
+ srcs_version = "PY2AND3",
+ tags = ["notsan"],
+ deps = [
+ ":namedtuples",
+ ":stargan_estimator",
+ ":tuple_losses",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/contrib/learn",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:summary",
+ "//tensorflow/python:training",
+ "//tensorflow/python:training_util",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/estimator:estimator_py",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_library(
name = "sliced_wasserstein",
srcs = [
"python/eval/python/sliced_wasserstein.py",
diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py
index c9f7bc61b2..99d38011ba 100644
--- a/tensorflow/contrib/gan/python/estimator/__init__.py
+++ b/tensorflow/contrib/gan/python/estimator/__init__.py
@@ -26,15 +26,18 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
from tensorflow.contrib.gan.python.estimator.python import head
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
from tensorflow.contrib.gan.python.estimator.python.head import *
+from tensorflow.contrib.gan.python.estimator.python.stargan_estimator import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'gan_estimator',
+ 'stargan_estimator',
'head',
-] + gan_estimator.__all__ + head.__all__
+] + gan_estimator.__all__ + stargan_estimator.__all__ + head.__all__
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index ab9886580d..7243f150ce 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -184,7 +184,7 @@ class GANEstimator(estimator.Estimator):
return _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn)
+ get_hooks_fn, use_loss_summaries)
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
@@ -211,15 +211,17 @@ def _get_gan_model(
def _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn=None):
+ get_hooks_fn=None, use_loss_summaries=True):
"""Get the EstimatorSpec for the current mode."""
if mode == model_fn_lib.ModeKeys.PREDICT:
estimator_spec = model_fn_lib.EstimatorSpec(
mode=mode, predictions=gan_model.generated_data)
else:
gan_loss = tfgan_tuples.GANLoss(
- generator_loss=generator_loss_fn(gan_model),
- discriminator_loss=discriminator_loss_fn(gan_model))
+ generator_loss=generator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries),
+ discriminator_loss=discriminator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries))
if mode == model_fn_lib.ModeKeys.EVAL:
estimator_spec = _get_eval_estimator_spec(
gan_model, gan_loss, get_eval_metric_ops_fn)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 9ac9c6ca9c..83f8dd641f 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -116,7 +116,7 @@ def get_dummy_gan_model():
discriminator_fn=None)
-def dummy_loss_fn(gan_model):
+def dummy_loss_fn(gan_model, add_summaries=True):
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
gan_model.discriminator_gen_outputs)
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py
index 87d1866e06..341bdf9fbb 100644
--- a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator.py
@@ -12,19 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""The KFAC optimizer."""
+"""`tf.Learn` components for `GANEstimator`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.optimizer import *
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.estimator.python.stargan_estimator_impl import *
+# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-_allowed_symbols = [
- "KfacOptimizer",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
+__all__ = stargan_estimator_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py
new file mode 100644
index 0000000000..f60e16bc04
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_impl.py
@@ -0,0 +1,363 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A TFGAN-backed StarGAN Estimator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import enum
+
+from tensorflow.contrib.framework.python.ops import variables as variable_lib
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
+from tensorflow.contrib.gan.python import train as tfgan_train
+from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.util import tf_inspect as inspect
+
+__all__ = ['StarGANEstimator', 'SummaryType']
+
+
+class SummaryType(enum.IntEnum):
+ NONE = 0
+ VARIABLES = 1
+ IMAGES = 2
+ IMAGE_COMPARISON = 3
+
+
+_summary_type_map = {
+ SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries,
+ SummaryType.IMAGES: tfgan_summaries.add_stargan_image_summaries,
+}
+
+
+class StarGANEstimator(estimator.Estimator):
+ """An estimator for Generative Adversarial Networks (GANs).
+
+ This Estimator is backed by TFGAN. The network functions follow the TFGAN API
+ except for one exception: if either `generator_fn` or `discriminator_fn` have
+ an argument called `mode`, then the tf.Estimator mode is passed in for that
+ argument. This helps with operations like batch normalization, which have
+ different train and evaluation behavior.
+
+ Example:
+
+ ```python
+ import tensorflow as tf
+ tfgan = tf.contrib.gan
+
+ # See TFGAN's `train.py` for a description of the generator and
+ # discriminator API.
+ def generator_fn(generator_inputs):
+ ...
+ return generated_data
+
+ def discriminator_fn(data, conditioning):
+ ...
+ return logits
+
+ # Create GAN estimator.
+ stargan_estimator = tfgan.estimator.StarGANEstimator(
+ model_dir,
+ generator_fn=generator_fn,
+ discriminator_fn=discriminator_fn,
+ loss_fn=loss_fn,
+ generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
+ discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))
+
+ # Train estimator.
+ stargan_estimator.train(train_input_fn, steps)
+
+ # Evaluate resulting estimator.
+ stargan_estimator.evaluate(eval_input_fn)
+
+ # Generate samples from generator.
+ stargan_estimator = np.array([
+ x for x in stargan_estimator.predict(predict_input_fn)])
+ ```
+ """
+
+ def __init__(self,
+ model_dir=None,
+ generator_fn=None,
+ discriminator_fn=None,
+ loss_fn=None,
+ generator_optimizer=None,
+ discriminator_optimizer=None,
+ get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
+ add_summaries=None,
+ use_loss_summaries=True,
+ config=None):
+ """Initializes a StarGANEstimator instance.
+
+ Args:
+ model_dir: Directory to save model parameters, graph and etc. This can
+ also be used to load checkpoints from the directory into a estimator to
+ continue training a previously saved model.
+ generator_fn: A python function that takes a Tensor, Tensor list, or
+ Tensor dictionary as inputs and returns the outputs of the GAN
+ generator. See `TFGAN` for more details and examples. Additionally, if
+ it has an argument called `mode`, the Estimator's `mode` will be passed
+ in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
+ normalization.
+ discriminator_fn: A python function that takes the output of
+ `generator_fn` or real data in the GAN setup, and `input_data`. Outputs
+ a Tensor in the range [-inf, inf]. See `TFGAN` for more details and
+ examples.
+ loss_fn: The loss function on the generator. Takes a `StarGANModel`
+ namedtuple and return a `GANLoss` namedtuple.
+ generator_optimizer: The optimizer for generator updates, or a function
+ that takes no arguments and returns an optimizer. This function will be
+ called when the default graph is the `StarGANEstimator`'s graph, so
+ utilities like `tf.contrib.framework.get_or_create_global_step` will
+ work.
+ discriminator_optimizer: Same as `generator_optimizer`, but for the
+ discriminator updates.
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks. These hooks are run on the generator and discriminator
+ train ops, and can be used to implement the GAN training scheme.
+ Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
+ add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
+ use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
+ If `None`, uses defaults.
+ config: `RunConfig` object to configure the runtime settings.
+
+ Raises:
+ ValueError: If loss functions aren't callable.
+ ValueError: If `use_loss_summaries` isn't boolean or `None`.
+ ValueError: If `get_hooks_fn` isn't callable or `None`.
+ """
+ if not callable(loss_fn):
+ raise ValueError('loss_fn must be callable.')
+ if use_loss_summaries not in [True, False, None]:
+ raise ValueError('use_loss_summaries must be True, False or None.')
+ if get_hooks_fn is not None and not callable(get_hooks_fn):
+ raise TypeError('get_hooks_fn must be callable.')
+
+ def _model_fn(features, labels, mode):
+ """StarGANEstimator model function."""
+ if mode not in [
+ model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
+ model_fn_lib.ModeKeys.PREDICT
+ ]:
+ raise ValueError('Mode not recognized: %s' % mode)
+
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ input_data = features[0]
+ input_data_domain_label = features[1]
+ else:
+ input_data = features # rename inputs for clarity
+ input_data_domain_label = labels # rename inputs for clarity
+
+ # Make StarGANModel, which encapsulates the GAN model architectures.
+ gan_model = _get_gan_model(mode, generator_fn, discriminator_fn,
+ input_data, input_data_domain_label,
+ add_summaries)
+
+ # Make the EstimatorSpec, which incorporates the StarGANModel, losses,
+ # eval, metrics, and optimizers (if required).
+ return _get_estimator_spec(mode, gan_model, loss_fn,
+ get_eval_metric_ops_fn, generator_optimizer,
+ discriminator_optimizer, get_hooks_fn)
+
+ super(StarGANEstimator, self).__init__(
+ model_fn=_model_fn, model_dir=model_dir, config=config)
+
+
+def _get_gan_model(mode,
+ generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ add_summaries,
+ generator_scope='Generator'):
+ """Makes the StarGANModel tuple."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ gan_model = _make_prediction_gan_model(input_data, input_data_domain_label,
+ generator_fn, generator_scope)
+ else: # model_fn_lib.ModeKeys.TRAIN or model_fn_lib.ModeKeys.EVAL
+ gan_model = _make_gan_model(generator_fn, discriminator_fn, input_data,
+ input_data_domain_label, generator_scope,
+ add_summaries, mode)
+
+ return gan_model
+
+
+def _get_estimator_spec(mode,
+ gan_model,
+ loss_fn,
+ get_eval_metric_ops_fn,
+ generator_optimizer,
+ discriminator_optimizer,
+ get_hooks_fn=None):
+ """Get the EstimatorSpec for the current mode."""
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ estimator_spec = model_fn_lib.EstimatorSpec(
+ mode=mode, predictions=gan_model.generated_data)
+ else:
+ gan_loss = loss_fn(gan_model)
+ if mode == model_fn_lib.ModeKeys.EVAL:
+ estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss,
+ get_eval_metric_ops_fn)
+ else: # model_fn_lib.ModeKeys.TRAIN:
+ gopt = (
+ generator_optimizer()
+ if callable(generator_optimizer) else generator_optimizer)
+ dopt = (
+ discriminator_optimizer()
+ if callable(discriminator_optimizer) else discriminator_optimizer)
+ get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
+ estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt,
+ dopt, get_hooks_fn)
+
+ return estimator_spec
+
+
+def _make_gan_model(generator_fn, discriminator_fn, input_data,
+ input_data_domain_label, generator_scope, add_summaries,
+ mode):
+ """Construct a `StarGANModel`, and optionally pass in `mode`."""
+ # If network functions have an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(generator_fn, mode=mode)
+ if 'mode' in inspect.getargspec(discriminator_fn).args:
+ discriminator_fn = functools.partial(discriminator_fn, mode=mode)
+ gan_model = tfgan_train.stargan_model(
+ generator_fn,
+ discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ generator_scope=generator_scope)
+ if add_summaries:
+ if not isinstance(add_summaries, (tuple, list)):
+ add_summaries = [add_summaries]
+ with ops.name_scope(None):
+ for summary_type in add_summaries:
+ _summary_type_map[summary_type](gan_model)
+
+ return gan_model
+
+
+def _make_prediction_gan_model(input_data, input_data_domain_label,
+ generator_fn, generator_scope):
+ """Make a `StarGANModel` from just the generator."""
+ # If `generator_fn` has an argument `mode`, pass mode to it.
+ if 'mode' in inspect.getargspec(generator_fn).args:
+ generator_fn = functools.partial(
+ generator_fn, mode=model_fn_lib.ModeKeys.PREDICT)
+ with variable_scope.variable_scope(generator_scope) as gen_scope:
+ # pylint:disable=protected-access
+ input_data = tfgan_train._convert_tensor_or_l_or_d(input_data)
+ input_data_domain_label = tfgan_train._convert_tensor_or_l_or_d(
+ input_data_domain_label)
+ # pylint:enable=protected-access
+ generated_data = generator_fn(input_data, input_data_domain_label)
+ generator_variables = variable_lib.get_trainable_variables(gen_scope)
+
+ return tfgan_tuples.StarGANModel(
+ input_data=input_data,
+ input_data_domain_label=None,
+ generated_data=generated_data,
+ generated_data_domain_target=input_data_domain_label,
+ reconstructed_data=None,
+ discriminator_input_data_source_predication=None,
+ discriminator_generated_data_source_predication=None,
+ discriminator_input_data_domain_predication=None,
+ discriminator_generated_data_domain_predication=None,
+ generator_variables=generator_variables,
+ generator_scope=generator_scope,
+ generator_fn=generator_fn,
+ discriminator_variables=None,
+ discriminator_scope=None,
+ discriminator_fn=None)
+
+
+def _get_eval_estimator_spec(gan_model,
+ gan_loss,
+ get_eval_metric_ops_fn=None,
+ name=None):
+ """Return an EstimatorSpec for the eval case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss, gan_loss.discriminator_loss]):
+
+ def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+
+ eval_metric_ops = {
+ _summary_key(name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
+ return model_fn_lib.EstimatorSpec(
+ mode=model_fn_lib.ModeKeys.EVAL,
+ predictions=gan_model.generated_data,
+ loss=scalar_loss,
+ eval_metric_ops=eval_metric_ops)
+
+
+def _get_train_estimator_spec(gan_model,
+ gan_loss,
+ generator_optimizer,
+ discriminator_optimizer,
+ get_hooks_fn,
+ train_op_fn=tfgan_train.gan_train_ops):
+ """Return an EstimatorSpec for the train case."""
+ scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ train_ops = train_op_fn(gan_model, gan_loss, generator_optimizer,
+ discriminator_optimizer)
+ training_hooks = get_hooks_fn(train_ops)
+ return model_fn_lib.EstimatorSpec(
+ loss=scalar_loss,
+ mode=model_fn_lib.ModeKeys.TRAIN,
+ train_op=train_ops.global_step_inc_op,
+ training_hooks=training_hooks)
+
+
+def stargan_prediction_input_fn_wrapper(fn):
+ """StarGAN Estimator prediction input_fn wrapper.
+
+ Since estimator will disregard the "label" variable pass to the model, we will
+ use a wrapper to pack the (feature, label) tuple as feature passed to the
+ model.
+
+ Args:
+ fn: input_fn for the prediction.
+
+ Returns:
+ A tuple ((feature, label), None) where the second element is the dummy label
+ to be disregarded and the first element is the true input to the estimator.
+ """
+
+ def new_fn():
+ return fn(), None
+
+ return new_fn
diff --git a/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py
new file mode 100644
index 0000000000..2ec7938c7c
--- /dev/null
+++ b/tensorflow/contrib/gan/python/estimator/python/stargan_estimator_test.py
@@ -0,0 +1,306 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TFGAN's stargan_estimator.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import shutil
+import tempfile
+
+from absl.testing import parameterized
+import numpy as np
+import six
+
+from tensorflow.contrib import layers
+from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
+from tensorflow.contrib.gan.python.estimator.python import stargan_estimator_impl as estimator
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.inputs import numpy_io
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import metrics as metrics_lib
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import learning_rate_decay
+from tensorflow.python.training import training
+from tensorflow.python.training import training_util
+
+
+def dummy_generator_fn(input_data, input_data_domain_label, mode):
+ del input_data_domain_label, mode
+
+ return variable_scope.get_variable('dummy_g', initializer=0.5) * input_data
+
+
+def dummy_discriminator_fn(input_data, num_domains, mode):
+ del mode
+
+ hidden = layers.flatten(input_data)
+ output_src = math_ops.reduce_mean(hidden, axis=1)
+ output_cls = layers.fully_connected(
+ inputs=hidden, num_outputs=num_domains, scope='debug')
+
+ return output_src, output_cls
+
+
+class StarGetGANModelTest(test.TestCase, parameterized.TestCase):
+ """Tests that `StarGetGANModel` produces the correct model."""
+
+ @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_gan_model(self, mode):
+ with ops.Graph().as_default():
+ input_data = array_ops.ones([6, 4, 4, 3])
+ input_data_domain_label = array_ops.one_hot([0] * 6, 5)
+ gan_model = estimator._get_gan_model(
+ mode,
+ dummy_generator_fn,
+ dummy_discriminator_fn,
+ input_data,
+ input_data_domain_label,
+ add_summaries=False)
+
+ self.assertEqual(input_data, gan_model.input_data)
+ self.assertIsNotNone(gan_model.generated_data)
+ self.assertIsNotNone(gan_model.generated_data_domain_target)
+ self.assertEqual(1, len(gan_model.generator_variables))
+ self.assertIsNotNone(gan_model.generator_scope)
+ self.assertIsNotNone(gan_model.generator_fn)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertIsNone(gan_model.input_data_domain_label)
+ self.assertEqual(input_data_domain_label,
+ gan_model.generated_data_domain_target)
+ self.assertIsNone(gan_model.reconstructed_data)
+ self.assertIsNone(gan_model.discriminator_input_data_source_predication)
+ self.assertIsNone(
+ gan_model.discriminator_generated_data_source_predication)
+ self.assertIsNone(gan_model.discriminator_input_data_domain_predication)
+ self.assertIsNone(
+ gan_model.discriminator_generated_data_domain_predication)
+ self.assertIsNone(gan_model.discriminator_variables)
+ self.assertIsNone(gan_model.discriminator_scope)
+ self.assertIsNone(gan_model.discriminator_fn)
+ else:
+ self.assertEqual(input_data_domain_label,
+ gan_model.input_data_domain_label)
+ self.assertIsNotNone(gan_model.reconstructed_data.shape)
+ self.assertIsNotNone(
+ gan_model.discriminator_input_data_source_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_generated_data_source_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_input_data_domain_predication)
+ self.assertIsNotNone(
+ gan_model.discriminator_generated_data_domain_predication)
+ self.assertEqual(2, len(gan_model.discriminator_variables)) # 1 FC layer
+ self.assertIsNotNone(gan_model.discriminator_scope)
+ self.assertIsNotNone(gan_model.discriminator_fn)
+
+
+def get_dummy_gan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
+ return tfgan_tuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=array_ops.ones([1, 2, 2, 3]),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]) * dis_var,
+ discriminator_generated_data_source_predication=array_ops.ones(
+ [1]) * gen_var * dis_var,
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2
+ ]) * dis_var,
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]) *
+ gen_var * dis_var,
+ generator_variables=[gen_var],
+ generator_scope=gen_scope,
+ generator_fn=None,
+ discriminator_variables=[dis_var],
+ discriminator_scope=dis_scope,
+ discriminator_fn=None)
+
+
+def dummy_loss_fn(gan_model):
+ loss = math_ops.reduce_sum(
+ gan_model.discriminator_input_data_domain_predication -
+ gan_model.discriminator_generated_data_domain_predication)
+ loss += math_ops.reduce_sum(gan_model.input_data - gan_model.generated_data)
+ return tfgan_tuples.GANLoss(loss, loss)
+
+
+def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric':
+ metrics_lib.mean_squared_error(gan_model.input_data,
+ gan_model.generated_data)
+ }
+
+
+class GetEstimatorSpecTest(test.TestCase, parameterized.TestCase):
+ """Tests that the EstimatorSpec is constructed appropriately."""
+
+ @classmethod
+ def setUpClass(cls):
+ cls._generator_optimizer = training.GradientDescentOptimizer(1.0)
+ cls._discriminator_optimizer = training.GradientDescentOptimizer(1.0)
+
+ @parameterized.named_parameters(('train', model_fn_lib.ModeKeys.TRAIN),
+ ('eval', model_fn_lib.ModeKeys.EVAL),
+ ('predict', model_fn_lib.ModeKeys.PREDICT))
+ def test_get_estimator_spec(self, mode):
+ with ops.Graph().as_default():
+ self._gan_model = get_dummy_gan_model()
+ spec = estimator._get_estimator_spec(
+ mode,
+ self._gan_model,
+ loss_fn=dummy_loss_fn,
+ get_eval_metric_ops_fn=get_metrics,
+ generator_optimizer=self._generator_optimizer,
+ discriminator_optimizer=self._discriminator_optimizer)
+
+ self.assertEqual(mode, spec.mode)
+ if mode == model_fn_lib.ModeKeys.PREDICT:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ elif mode == model_fn_lib.ModeKeys.TRAIN:
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.train_op)
+ self.assertIsNotNone(spec.training_hooks)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ self.assertEqual(self._gan_model.generated_data, spec.predictions)
+ self.assertShapeEqual(np.array(0), spec.loss) # must be a scalar
+ self.assertIsNotNone(spec.eval_metric_ops)
+
+
+# TODO(joelshor): Add pandas test.
+class StarGANEstimatorIntegrationTest(test.TestCase):
+
+ def setUp(self):
+ self._model_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ if self._model_dir:
+ writer_cache.FileWriterCache.clear()
+ shutil.rmtree(self._model_dir)
+
+ def _test_complete_flow(self,
+ train_input_fn,
+ eval_input_fn,
+ predict_input_fn,
+ prediction_size,
+ lr_decay=False):
+
+ def make_opt():
+ gstep = training_util.get_or_create_global_step()
+ lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
+ return training.GradientDescentOptimizer(lr)
+
+ gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
+ dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
+ est = estimator.StarGANEstimator(
+ generator_fn=dummy_generator_fn,
+ discriminator_fn=dummy_discriminator_fn,
+ loss_fn=dummy_loss_fn,
+ generator_optimizer=gopt,
+ discriminator_optimizer=dopt,
+ get_eval_metric_ops_fn=get_metrics,
+ model_dir=self._model_dir)
+
+ # TRAIN
+ num_steps = 10
+ est.train(train_input_fn, steps=num_steps)
+
+ # EVALUTE
+ scores = est.evaluate(eval_input_fn)
+ self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
+ self.assertIn('loss', six.iterkeys(scores))
+ self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
+ scores['loss'])
+ self.assertIn('mse_custom_metric', six.iterkeys(scores))
+
+ # PREDICT
+ predictions = np.array([x for x in est.predict(predict_input_fn)])
+
+ self.assertAllEqual(prediction_size, predictions.shape)
+
+ @staticmethod
+ def _numpy_input_fn_wrapper(numpy_input_fn, batch_size, label_size):
+ """Wrapper to remove the dictionary in numpy_input_fn.
+
+ NOTE:
+ We create the domain_label here because the model expect a fully define
+ batch_size from the input.
+
+ Args:
+ numpy_input_fn: input_fn created from numpy_io
+ batch_size: (int) number of items for each batch
+ label_size: (int) number of domains
+
+ Returns:
+ a new input_fn
+ """
+
+ def new_input_fn():
+ features = numpy_input_fn()
+ return features['x'], array_ops.one_hot([0] * batch_size, label_size)
+
+ return new_input_fn
+
+ def test_numpy_input_fn(self):
+ """Tests complete flow with numpy_input_fn."""
+ batch_size = 5
+ img_size = 8
+ channel_size = 3
+ label_size = 3
+ image_data = np.zeros(
+ [batch_size, img_size, img_size, channel_size], dtype=np.float32)
+ train_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data},
+ batch_size=batch_size,
+ num_epochs=None,
+ shuffle=True)
+ eval_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data}, batch_size=batch_size, shuffle=False)
+ predict_input_fn = numpy_io.numpy_input_fn(
+ x={'x': image_data}, shuffle=False)
+
+ train_input_fn = self._numpy_input_fn_wrapper(train_input_fn, batch_size,
+ label_size)
+ eval_input_fn = self._numpy_input_fn_wrapper(eval_input_fn, batch_size,
+ label_size)
+ predict_input_fn = self._numpy_input_fn_wrapper(predict_input_fn,
+ batch_size, label_size)
+
+ predict_input_fn = estimator.stargan_prediction_input_fn_wrapper(
+ predict_input_fn)
+
+ self._test_complete_flow(
+ train_input_fn=train_input_fn,
+ eval_input_fn=eval_input_fn,
+ predict_input_fn=predict_input_fn,
+ prediction_size=[batch_size, img_size, img_size, channel_size])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 9f5fee4542..e3c780ac1a 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -51,7 +51,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -59,7 +59,7 @@ class _LossesTest(object):
self._discriminator_real_outputs, self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -90,7 +90,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
@@ -98,7 +98,7 @@ class _LossesTest(object):
array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -108,7 +108,7 @@ class _LossesTest(object):
loss = self._g_loss_fn(logits, weights=weights)
self.assertEqual(logits.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [[10.0, 4.4, -5.5, 3.6]],
@@ -125,7 +125,7 @@ class _LossesTest(object):
logits, logits2, real_weights=real_weights,
generated_weights=generated_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [self._discriminator_real_outputs_np],
@@ -136,7 +136,7 @@ class _LossesTest(object):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(
self._discriminator_gen_outputs, weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -144,14 +144,14 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=self._weights, generated_weights=self._weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(self._discriminator_gen_outputs,
weights=constant_op.constant(self._weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -160,7 +160,7 @@ class _LossesTest(object):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=weights, generated_weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -284,7 +284,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
@@ -292,7 +292,7 @@ class ACGANLossTest(test.TestCase):
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
@@ -319,14 +319,14 @@ class ACGANLossTest(test.TestCase):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._generator_kwargs.items()}
loss = self._g_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._discriminator_kwargs.items()}
loss = self._d_loss_fn(**patch_args)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
@@ -334,7 +334,7 @@ class ACGANLossTest(test.TestCase):
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
loss = self._g_loss_fn(gen_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -349,7 +349,7 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
@@ -360,7 +360,7 @@ class ACGANLossTest(test.TestCase):
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -368,14 +368,14 @@ class ACGANLossTest(test.TestCase):
loss = self._d_loss_fn(
real_weights=self._weights, generated_weights=self._weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(
weights=constant_op.constant(self._weights), **self._generator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
@@ -383,7 +383,7 @@ class ACGANLossTest(test.TestCase):
weights = constant_op.constant(self._weights)
loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
**self._discriminator_kwargs)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
@@ -404,7 +404,7 @@ class _PenaltyTest(object):
loss = self._penalty_fn(**self._kwargs)
self.assertEqual(self._expected_dtype, loss.dtype)
self.assertEqual(self._expected_op_name, loss.op.name)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
@@ -419,13 +419,13 @@ class _PenaltyTest(object):
def test_python_scalar_weight(self):
loss = self._penalty_fn(weights=2.3, **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
def test_scalar_tensor_weight(self):
loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
@@ -472,7 +472,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'])
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -494,7 +494,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
one_sided=True)
self.assertEqual(generated_data.dtype, loss.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
@@ -516,7 +516,7 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
self._kwargs['discriminator_scope'],
target=2.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(
loss,
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index a559bbfa11..25d74a8c23 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -118,7 +118,7 @@ def add_loss_consistency_test(test_class, loss_name_str, loss_args):
def consistency_test(self):
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(arg_loss(**loss_args).eval(),
tuple_loss(_tuple_from_dict(loss_args)).eval())
@@ -241,7 +241,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -257,7 +257,7 @@ class StarGANLossWrapperTest(test.TestCase):
self.discriminator_generated_data_source_predication)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
@@ -282,7 +282,7 @@ class StarGANLossWrapperTest(test.TestCase):
discriminator_scope=self.discriminator_scope)
wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loss_result, wrapped_loss_result = sess.run(
[loss_result_tensor, wrapped_loss_result_tensor])
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index 7e6a0f14f6..726f74c7b7 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -186,22 +186,22 @@ class GdrMemoryManager : public RemoteMemoryManager {
// TODO(byronyi): remove this class and its registration when the default
// cpu_allocator() returns visitable allocator, or cpu_allocator() is no
// longer in use.
-class BFCRdmaAllocator : public BFCAllocator {
+class BFCGdrAllocator : public BFCAllocator {
public:
- BFCRdmaAllocator()
+ BFCGdrAllocator()
: BFCAllocator(new BasicCPUAllocator(port::kNUMANoAffinity), 1LL << 36,
- true, "cpu_rdma_bfc") {}
+ true, "cpu_gdr_bfc") {}
};
-class BFCRdmaAllocatorFactory : public AllocatorFactory {
+class BFCGdrAllocatorFactory : public AllocatorFactory {
public:
- Allocator* CreateAllocator() override { return new BFCRdmaAllocator; }
+ Allocator* CreateAllocator() override { return new BFCGdrAllocator; }
virtual SubAllocator* CreateSubAllocator(int numa_node) {
return new BasicCPUAllocator(numa_node);
}
};
-REGISTER_MEM_ALLOCATOR("BFCRdmaAllocator", 101, BFCRdmaAllocatorFactory);
+REGISTER_MEM_ALLOCATOR("BFCGdrAllocator", 102, BFCGdrAllocatorFactory);
GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
: host_(host),
diff --git a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
index 80b2d3e08b..2bf6097d01 100644
--- a/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
+++ b/tensorflow/contrib/hadoop/kernels/hadoop_dataset_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
+namespace data {
namespace {
static const size_t kSyncMarkerSize = 16;
@@ -332,9 +333,10 @@ class SequenceFileDatasetOp : public DatasetOpKernel {
};
DataTypeVector output_types_;
};
-} // namespace
REGISTER_KERNEL_BUILDER(Name("SequenceFileDataset").Device(DEVICE_CPU),
SequenceFileDatasetOp);
+} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 693724b457..370a8caf6a 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -71,7 +71,6 @@ class ImageProjectiveTransform : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& images_t = ctx->input(0);
const Tensor& transform_t = ctx->input(1);
- const Tensor& shape_t = ctx->input(2);
OP_REQUIRES(ctx, images_t.shape().dims() == 4,
errors::InvalidArgument("Input images must have rank 4"));
OP_REQUIRES(ctx,
@@ -82,17 +81,28 @@ class ImageProjectiveTransform : public OpKernel {
ProjectiveGenerator<Device, T>::kNumParameters),
errors::InvalidArgument(
"Input transform should be num_images x 8 or 1 x 8"));
- OP_REQUIRES(ctx, shape_t.dims() == 1,
- errors::InvalidArgument("output shape must be 1-dimensional",
- shape_t.shape().DebugString()));
- OP_REQUIRES(ctx, shape_t.NumElements() == 2,
- errors::InvalidArgument("output shape must have two elements",
- shape_t.shape().DebugString()));
- auto shape_vec = shape_t.vec<int32>();
- int32 out_height = shape_vec(0);
- int32 out_width = shape_vec(1);
- OP_REQUIRES(ctx, out_height > 0 && out_width > 0,
- errors::InvalidArgument("output dimensions must be positive"));
+
+ int32 out_height, out_width;
+ // Kernel is shared by legacy "ImageProjectiveTransform" op with 2 args.
+ if (ctx->num_inputs() >= 3) {
+ const Tensor& shape_t = ctx->input(2);
+ OP_REQUIRES(ctx, shape_t.dims() == 1,
+ errors::InvalidArgument("output shape must be 1-dimensional",
+ shape_t.shape().DebugString()));
+ OP_REQUIRES(ctx, shape_t.NumElements() == 2,
+ errors::InvalidArgument("output shape must have two elements",
+ shape_t.shape().DebugString()));
+ auto shape_vec = shape_t.vec<int32>();
+ out_height = shape_vec(0);
+ out_width = shape_vec(1);
+ OP_REQUIRES(
+ ctx, out_height > 0 && out_width > 0,
+ errors::InvalidArgument("output dimensions must be positive"));
+ } else {
+ // Shape is N (batch size), H (height), W (width), C (channels).
+ out_height = images_t.shape().dim_size(1);
+ out_width = images_t.shape().dim_size(2);
+ }
Tensor* output_t;
OP_REQUIRES_OK(ctx, ctx->allocate_output(
@@ -109,10 +119,14 @@ class ImageProjectiveTransform : public OpKernel {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<TYPE>("dtype"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageProjectiveTransform<CPUDevice, TYPE>); \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("dtype"), \
ImageProjectiveTransform<CPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
@@ -147,11 +161,15 @@ TF_CALL_double(DECLARE_FUNCTOR);
} // end namespace functor
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
- .Device(DEVICE_GPU) \
- .TypeConstraint<TYPE>("dtype") \
- .HostMemory("output_shape"), \
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransform") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype"), \
+ ImageProjectiveTransform<GPUDevice, TYPE>); \
+ REGISTER_KERNEL_BUILDER(Name("ImageProjectiveTransformV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<TYPE>("dtype") \
+ .HostMemory("output_shape"), \
ImageProjectiveTransform<GPUDevice, TYPE>)
TF_CALL_uint8(REGISTER);
diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc
index 4969ac58f9..6f7c9bb520 100644
--- a/tensorflow/contrib/image/ops/image_ops.cc
+++ b/tensorflow/contrib/image/ops/image_ops.cc
@@ -67,19 +67,7 @@ Status ResizeShapeFn(InferenceContext* c) {
c->Dim(input, 3));
}
-} // namespace
-
-// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
-// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
-REGISTER_OP("ImageProjectiveTransform")
- .Input("images: dtype")
- .Input("transforms: float32")
- .Input("output_shape: int32")
- .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
- .Attr("interpolation: string")
- .Output("transformed_images: dtype")
- .SetShapeFn(ResizeShapeFn)
- .Doc(R"doc(
+static const char kImageProjectiveTransformDoc[] = R"doc(
Applies the given transform to each of the images.
Input `image` is a `Tensor` in NHWC format (where the axes are image in batch,
@@ -99,7 +87,35 @@ transforms: 2D `Tensor`, projective transform(s) to apply to the image(s).
transformed_images: 4D `Tensor`, image(s) in NHWC format, generated by applying
the `transforms` to the `images`. Satisfies the description above.
-)doc");
+)doc";
+
+} // namespace
+
+// TODO(ringwalt): Add a "fill_mode" attr with "constant", "mirror", etc.
+// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
+REGISTER_OP("ImageProjectiveTransform")
+ .Input("images: dtype")
+ .Input("transforms: float32")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+ .Attr("interpolation: string")
+ .Output("transformed_images: dtype")
+ // Output shape is identical to input images.
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(kImageProjectiveTransformDoc);
+
+// V2 op supports output_shape.
+REGISTER_OP("ImageProjectiveTransformV2")
+ .Input("images: dtype")
+ .Input("transforms: float32")
+ .Input("output_shape: int32")
+ .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
+ .Attr("interpolation: string")
+ .Output("transformed_images: dtype")
+ .SetShapeFn(ResizeShapeFn)
+ .Doc(kImageProjectiveTransformDoc);
REGISTER_OP("BipartiteMatch")
.Input("distance_mat: float")
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 70339d7612..376c0751ee 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.contrib.image.ops import gen_image_ops
from tensorflow.contrib.image.python.ops import image_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -262,6 +263,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
self._test_grad_different_shape([4, 12, 3], [8, 24, 3])
self._test_grad_different_shape([3, 4, 12, 3], [3, 8, 24, 3])
+ def test_projective_transform_v1(self):
+ """The original ImageProjectiveTransform op should take 2 arguments."""
+ image = constant_op.constant([[[[1], [0]], [[0], [1]]]])
+ transform = constant_op.constant([[1., 0., 0., 0., 1., 0., 0., 0.]])
+ result = gen_image_ops.image_projective_transform(
+ image, transform, interpolation="NEAREST")
+ with self.cached_session():
+ self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py
index e7a09041ad..d4fb99a017 100644
--- a/tensorflow/contrib/image/python/ops/image_ops.py
+++ b/tensorflow/contrib/image/python/ops/image_ops.py
@@ -39,6 +39,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ImageProjectiveTransformV2")(common_shapes.call_cpp_shape_fn)
# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name
@@ -290,7 +291,7 @@ def transform(images,
else:
raise TypeError("Transforms should have rank 1 or 2.")
- output = gen_image_ops.image_projective_transform(
+ output = gen_image_ops.image_projective_transform_v2(
images,
output_shape=output_shape,
transforms=transforms,
@@ -391,7 +392,7 @@ def matrices_to_flat_transforms(transform_matrices):
return transforms[:, :8]
-@ops.RegisterGradient("ImageProjectiveTransform")
+@ops.RegisterGradient("ImageProjectiveTransformV2")
def _image_projective_transform_grad(op, grad):
"""Computes the gradient for ImageProjectiveTransform."""
images = op.inputs[0]
@@ -415,7 +416,7 @@ def _image_projective_transform_grad(op, grad):
transforms = flat_transforms_to_matrices(transforms=transforms)
inverse = linalg_ops.matrix_inverse(transforms)
transforms = matrices_to_flat_transforms(inverse)
- output = gen_image_ops.image_projective_transform(
+ output = gen_image_ops.image_projective_transform_v2(
images=grad,
transforms=transforms,
output_shape=array_ops.shape(image_or_images)[1:3],
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index c7b4e2faa8..be915ef96f 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -49,7 +49,7 @@ class OdeIntTest(test.TestCase):
y_solved = odes.odeint(func, y0, t)
self.assertIn('odeint', y_solved.name)
self.assertEqual(y_solved.get_shape(), tensor_shape.TensorShape([11]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(t)
self.assertAllClose(y_true, y_solved)
@@ -62,7 +62,7 @@ class OdeIntTest(test.TestCase):
func = lambda y, t: k * y
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, 1.0 + 0.0j, t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.exp(k * t)
self.assertAllClose(y_true, y_solved)
@@ -74,7 +74,7 @@ class OdeIntTest(test.TestCase):
func = lambda t, y: (y - t)**2 + 1.0
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, np.float64(0.5), t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = 1.0 / (2.0 - t) + t
self.assertAllClose(y_true, y_solved)
@@ -96,7 +96,7 @@ class OdeIntTest(test.TestCase):
t = np.linspace(0.0, 1.0, 11)
y_solved = odes.odeint(func, y0, t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
y_true = np.zeros((len(t), 2, 1))
@@ -113,7 +113,7 @@ class OdeIntTest(test.TestCase):
y_solved = odes.odeint(func, array_ops.reshape(y0, shape), t)
self.assertEqual(y_solved.get_shape(),
tensor_shape.TensorShape(expected_shape))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
self.assertEquals(y_solved.shape, expected_shape)
@@ -126,7 +126,7 @@ class OdeIntTest(test.TestCase):
for t_dtype in [dtypes.float32, dtypes.float64]:
y0 = math_ops.cast(1.0, y0_dtype)
y_solved = odes.odeint(func, y0, math_ops.cast(t, t_dtype))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved = sess.run(y_solved)
expected = np.asarray(np.exp(t))
self.assertAllClose(y_solved, expected, rtol=1e-5)
@@ -148,13 +148,13 @@ class OdeIntTest(test.TestCase):
self.y0, [0, 1],
method='dopri5',
options={'max_num_steps': 0})
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'max_num_steps'):
sess.run(y)
y = odes.odeint(self.func, self.y0, [1, 0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
'monotonic increasing'):
sess.run(y)
@@ -164,7 +164,7 @@ class OdeIntTest(test.TestCase):
times0 = np.linspace(0, 10, num=11, dtype=float)
times1 = np.linspace(0, 10, num=101, dtype=float)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_solved_0, info_0 = sess.run(
odes.odeint(self.func, self.y0, times0, full_output=True))
y_solved_1, info_1 = sess.run(
@@ -179,7 +179,7 @@ class OdeIntTest(test.TestCase):
t = [0, 20]
kwargs = dict(
full_output=True, method='dopri5', options=dict(max_num_steps=2000))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
_, info_0 = sess.run(
odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs))
_, info_1 = sess.run(
@@ -196,7 +196,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(1.0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.9)
@@ -204,7 +204,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(0.0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 10.0)
@@ -212,7 +212,7 @@ class StepSizeTest(test.TestCase):
new_step = odes._optimal_step_size(
last_step=constant_op.constant(1.0),
error_ratio=constant_op.constant(1e6))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
new_step = sess.run(new_step)
self.assertAllClose(new_step, 0.2)
@@ -229,13 +229,13 @@ class InterpolationTest(test.TestCase):
y_fit = array_ops.stack(
[odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times])
y_expected = f(times)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_actual = sess.run(y_fit)
self.assertAllClose(y_expected, y_actual)
# attempt interpolation outside bounds
y_invalid = odes._interp_evaluate(coeffs, 0.0, 10.0, 100.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(y_invalid)
@@ -251,7 +251,7 @@ class OdeIntFixedTest(test.TestCase):
y0 = [0., 1.]
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
@@ -265,7 +265,7 @@ class OdeIntFixedTest(test.TestCase):
y0 = [1.]
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
y_grid_array = sess.run(y_grid)
np.testing.assert_allclose(
diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD
deleted file mode 100644
index b719046b37..0000000000
--- a/tensorflow/contrib/kfac/BUILD
+++ /dev/null
@@ -1,26 +0,0 @@
-# Description:
-# Contains KfacOptimizer, an implementation of the K-FAC optimization
-# algorithm in TensorFlow.
-package(default_visibility = ["//visibility:public"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
- name = "kfac",
- srcs = ["__init__.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib",
- "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib",
- "//tensorflow/contrib/kfac/python/ops:layer_collection_lib",
- "//tensorflow/contrib/kfac/python/ops:loss_functions_lib",
- "//tensorflow/contrib/kfac/python/ops:op_queue_lib",
- "//tensorflow/contrib/kfac/python/ops:utils_lib",
- "//tensorflow/python:util",
- ],
-)
diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md
index 102626925d..42b91d0313 100644
--- a/tensorflow/contrib/kfac/README.md
+++ b/tensorflow/contrib/kfac/README.md
@@ -1,94 +1,3 @@
# K-FAC: Kronecker-Factored Approximate Curvature
-# <font color="red", size=10><u>WARNING: </u></font>
-# ==third_party/tensorflow/contrib/kfac is deprecated. This will be==
-# ==removed on 15-07-2018. <!-- STY:begin_strip_and_replace -->Please import third_party/tensorflow_kfac.==
-# ==<!-- STY:end_strip_and_replace Please check https://github.com/tensorflow/kfac. -->==
-
-**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an
-approximate second-order optimization method, in TensorFlow. When applied to
-feedforward and convolutional neural networks, K-FAC can converge `>3.5x`
-faster in `>14x` fewer iterations than SGD with Momentum.
-
-[kfac-paper]: https://arxiv.org/abs/1503.05671
-
-## What is K-FAC?
-
-K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation
-to the [Natural Gradient][natural_gradient] algorithm designed specifically for
-neural networks. It maintains a block-diagonal approximation to the [Fisher
-Information matrix][fisher_information], whose inverse preconditions the
-gradient.
-
-K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations.
-Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD.
-
-Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What
-are the weights for layer i?"). As such, you must add some additional code while
-constructing your model to use K-FAC.
-
-[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746
-[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form
-
-## Why should I use K-FAC?
-
-K-FAC can take advantage of the curvature of the optimization problem, resulting
-in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same
-loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how
-training loss changes as a function of number of epochs, steps, and seconds:
-
-![autoencoder](g3doc/autoencoder.png)
-
-## Is K-FAC for me?
-
-If you have a feedforward or convolutional model for classification that is
-converging too slowly, K-FAC is for you. K-FAC can be used in your model if:
-
-* Your model defines a posterior distribution.
-* Your model uses only fully-connected or convolutional layers (residual
- connections OK).
-* You are training on CPU or GPU.
-* You can modify model code to register layers with K-FAC.
-
-## How do I use K-FAC?
-
-Using K-FAC requires three steps:
-
-1. Registering layer inputs, weights, and pre-activations with a
- `LayerCollection`.
-1. Minimizing the loss with a `KfacOptimizer`.
-1. Keeping K-FAC's preconditioner updated.
-
-```python
-# Build model.
-w = tf.get_variable("w", ...)
-b = tf.get_variable("b", ...)
-logits = tf.matmul(x, w) + b
-loss = tf.reduce_mean(
- tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
-
-# Register layers.
-layer_collection = LayerCollection()
-layer_collection.register_fully_connected((w, b), x, logits)
-layer_collection.register_categorical_predictive_distribution(logits)
-
-# Construct training ops.
-optimizer = KfacOptimizer(..., layer_collection=layer_collection)
-train_op = optimizer.minimize(loss)
-
-# Minimize loss.
-with tf.Session() as sess:
- ...
- sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op])
-```
-
-See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations.
-
-## Authors
-
-- Alok Aggarwal
-- Daniel Duckworth
-- James Martens
-- Matthew Johnson
-- Olga Wichrowska
-- Roger Grosse
+## KFAC moved to third_party/tensorflow_kfac.
diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py
deleted file mode 100644
index 1ea354e6cd..0000000000
--- a/tensorflow/contrib/kfac/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Kronecker-factored Approximate Curvature Optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products
-from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator
-from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection
-from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions
-from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue
-from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer
-from tensorflow.contrib.kfac.python.ops import utils_lib as utils
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long
-
-_allowed_symbols = [
- "curvature_matrix_vector_products",
- "estimator",
- "fisher_blocks",
- "fisher_factors",
- "layer_collection",
- "loss_functions",
- "op_queue",
- "optimizer",
- "utils",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD
deleted file mode 100644
index 8186fa1c62..0000000000
--- a/tensorflow/contrib/kfac/examples/BUILD
+++ /dev/null
@@ -1,80 +0,0 @@
-package(default_visibility = [
- "//learning/brain/contrib/kfac/examples:__subpackages__",
- "//tensorflow/contrib/kfac/examples:__subpackages__",
-])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_binary(
- name = "mlp_mnist_main",
- srcs = ["mlp_mnist_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mlp",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "mlp",
- srcs = ["mlp.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mnist",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_single_main",
- srcs = ["convnet_mnist_single_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_multi_tower_main",
- srcs = ["convnet_mnist_multi_tower_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_binary(
- name = "convnet_mnist_distributed_main",
- srcs = ["convnet_mnist_distributed_main.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":convnet",
- "//tensorflow:tensorflow_py",
- ],
-)
-
-py_library(
- name = "convnet",
- srcs = ["convnet.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":mlp",
- ":mnist",
- "//tensorflow:tensorflow_py",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "mnist",
- srcs = ["mnist.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow:tensorflow_py",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py
deleted file mode 100644
index 44e01e1aeb..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet.py
+++ /dev/null
@@ -1,667 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the
-following structure,
-
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Conv Layer: 5x5 kernel, 16 output channels.
-- Max Pool: 3x3 kernel, stride 2.
-- Linear: 10 output dims.
-
-After 3k~6k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-from tensorflow.contrib.kfac.examples import mnist
-from tensorflow.contrib.kfac.python.ops import optimizer as opt
-
-
-lc = tf.contrib.kfac.layer_collection
-oq = tf.contrib.kfac.op_queue
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
- "conv_layer",
- "max_pool_layer",
- "linear_layer",
- "build_model",
- "minimize_loss_single_machine",
- "distributed_grads_only_and_ops_chief_worker",
- "distributed_grads_and_ops_dedicated_workers",
- "train_mnist_single_machine",
- "train_mnist_distributed_sync_replicas",
- "train_mnist_multitower"
-]
-
-
-# Inverse update ops will be run every _INVERT_EVRY iterations.
-_INVERT_EVERY = 10
-
-
-def conv_layer(layer_id, inputs, kernel_size, out_channels):
- """Builds a convolutional layer with ReLU non-linearity.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- kernel_size: int. Width and height of the convolution kernel. The kernel is
- assumed to be square.
- out_channels: int. Number of output features per pixel.
-
- Returns:
- preactivations: Tensor of shape [num_examples, width, height, out_channels].
- Values of the layer immediately before the activation function.
- activations: Tensor of shape [num_examples, width, height, out_channels].
- Values of the layer immediately after the activation function.
- params: Tuple of (kernel, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- layer = tf.layers.Conv2D(
- out_channels,
- kernel_size=[kernel_size, kernel_size],
- kernel_initializer=tf.random_normal_initializer(stddev=0.01),
- padding="SAME",
- name="conv_%d" % layer_id)
- preactivations = layer(inputs)
- activations = tf.nn.relu(preactivations)
-
- # layer.weights is a list. This converts it a (hashable) tuple.
- return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def max_pool_layer(layer_id, inputs, kernel_size, stride):
- """Build a max-pooling layer.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- kernel_size: int. Width and height to pool over per input channel. The
- kernel is assumed to be square.
- stride: int. Step size between pooling operations.
-
- Returns:
- Tensor of shape [num_examples, width/stride, height/stride, out_channels].
- Result of applying max pooling to 'inputs'.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- with tf.variable_scope("pool_%d" % layer_id):
- return tf.nn.max_pool(
- inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1],
- padding="SAME",
- name="pool")
-
-
-def linear_layer(layer_id, inputs, output_size):
- """Builds the final linear layer for an MNIST classification problem.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row
- corresponds to a single example.
- output_size: int. Number of output dims per example.
-
- Returns:
- activations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately after the activation function.
- params: Tuple of (weights, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- pre, _, params = mlp.fc_layer(layer_id, inputs, output_size)
- return pre, params
-
-
-def build_model(examples, labels, num_labels, layer_collection):
- """Builds a ConvNet classification model.
-
- Args:
- examples: Tensor of shape [num_examples, num_features]. Represents inputs of
- model.
- labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
- by softmax for each example.
- num_labels: int. Number of distinct values 'labels' can take on.
- layer_collection: LayerCollection instance. Layers will be registered here.
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensor representing model's accuracy.
- """
- # Build a ConvNet. For each layer with parameters, we'll keep track of the
- # preactivations, activations, weights, and bias.
- tf.logging.info("Building model.")
- pre0, act0, params0 = conv_layer(
- layer_id=0, inputs=examples, kernel_size=5, out_channels=16)
- act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2)
- pre2, act2, params2 = conv_layer(
- layer_id=2, inputs=act1, kernel_size=5, out_channels=16)
- act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2)
- flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))])
- logits, params4 = linear_layer(
- layer_id=4, inputs=flat_act3, output_size=num_labels)
- loss = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits))
- accuracy = tf.reduce_mean(
- tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
- with tf.device("/cpu:0"):
- tf.summary.scalar("loss", loss)
- tf.summary.scalar("accuracy", accuracy)
-
- # Register parameters. K-FAC needs to know about the inputs, outputs, and
- # parameters of each conv/fully connected layer and the logits powering the
- # posterior probability over classes.
- tf.logging.info("Building LayerCollection.")
- layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples,
- pre0)
- layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2)
- layer_collection.register_fully_connected(params4, flat_act3, logits)
- layer_collection.register_categorical_predictive_distribution(
- logits, name="logits")
-
- return loss, accuracy
-
-
-def minimize_loss_single_machine(loss,
- accuracy,
- layer_collection,
- device="/gpu:0",
- session_config=None):
- """Minimize loss with K-FAC on a single machine.
-
- A single Session is responsible for running all of K-FAC's ops. The covariance
- and inverse update ops are placed on `device`. All model variables are on CPU.
-
- Args:
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
- device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
- update ops are run on this device.
- session_config: None or tf.ConfigProto. Configuration for tf.Session().
-
- Returns:
- final value for 'accuracy'.
- """
- # Train with K-FAC.
- g_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- placement_strategy="round_robin",
- cov_devices=[device],
- inv_devices=[device],
- momentum=0.9)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- with tf.device(device):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, train_op])
-
- if global_step_ % _INVERT_EVERY == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
- global_step_, loss_, accuracy_)
-
- return accuracy_
-
-
-def _is_gradient_task(task_id, num_tasks):
- """Returns True if this task should update the weights."""
- if num_tasks < 3:
- return True
- return 0 <= task_id < 0.6 * num_tasks
-
-
-def _is_cov_update_task(task_id, num_tasks):
- """Returns True if this task should update K-FAC's covariance matrices."""
- if num_tasks < 3:
- return False
- return 0.6 * num_tasks <= task_id < num_tasks - 1
-
-
-def _is_inv_update_task(task_id, num_tasks):
- """Returns True if this task should update K-FAC's preconditioner."""
- if num_tasks < 3:
- return False
- return task_id == num_tasks - 1
-
-
-def _num_gradient_tasks(num_tasks):
- """Number of tasks that will update weights."""
- if num_tasks < 3:
- return num_tasks
- return int(np.ceil(0.6 * num_tasks))
-
-
-def _make_distributed_train_op(
- task_id,
- num_worker_tasks,
- num_ps_tasks,
- layer_collection
-):
- """Creates optimizer and distributed training op.
-
- Constructs KFAC optimizer and wraps it in `sync_replicas` optimizer. Makes
- the train op.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-
- Returns:
- sync_optimizer: `tf.train.SyncReplicasOptimizer` instance which wraps KFAC
- optimizer.
- optimizer: Instance of `opt.KfacOptimizer`.
- global_step: `tensor`, Global step.
- """
- tf.logging.info("Task id : %d", task_id)
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- momentum=0.9)
- sync_optimizer = tf.train.SyncReplicasOptimizer(
- opt=optimizer,
- replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks),
- total_num_replicas=num_worker_tasks)
- return sync_optimizer, optimizer, global_step
-
-
-def distributed_grads_only_and_ops_chief_worker(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
- loss, accuracy, layer_collection, invert_every=10):
- """Minimize loss with a synchronous implementation of K-FAC.
-
- All workers perform gradient computation. Chief worker applies gradient after
- averaging the gradients obtained from all the workers. All workers block
- execution until the update is applied. Chief worker runs covariance and
- inverse update ops. Covariance and inverse matrices are placed on parameter
- servers in a round robin manner. For further details on synchronous
- distributed optimization check `tf.train.SyncReplicasOptimizer`.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- master: string. IP and port of TensorFlow runtime process. Set to empty
- string to run locally.
- checkpoint_dir: string or None. Path to store checkpoints under.
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
- run with each step.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
- invert_every: `int`, Number of steps between update the inverse.
-
- Returns:
- final value for 'accuracy'.
-
- Raises:
- ValueError: if task_id >= num_worker_tasks.
- """
-
- sync_optimizer, optimizer, global_step = _make_distributed_train_op(
- task_id, num_worker_tasks, num_ps_tasks, layer_collection)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- tf.logging.info("Starting training.")
- hooks = [sync_optimizer.make_session_run_hook(is_chief)]
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- if is_chief:
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(global_step, invert_every), 0),
- lambda: make_update_op(inv_update_thunks),
- tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
- else:
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
-
- with tf.train.MonitoredTrainingSession(
- master=master,
- is_chief=is_chief,
- checkpoint_dir=checkpoint_dir,
- hooks=hooks,
- stop_grace_period_secs=0) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, train_op])
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
- loss_, accuracy_)
- return accuracy_
-
-
-def distributed_grads_and_ops_dedicated_workers(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master, checkpoint_dir,
- loss, accuracy, layer_collection):
- """Minimize loss with a synchronous implementation of K-FAC.
-
- Different workers are responsible for different parts of K-FAC's Ops. The
- first 60% of tasks compute gradients; the next 20% accumulate covariance
- statistics; the last 20% invert the matrices used to precondition gradients.
- The chief worker applies the gradient .
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables. If 0,
- parameter servers are not used.
- master: string. IP and port of TensorFlow runtime process. Set to empty
- string to run locally.
- checkpoint_dir: string or None. Path to store checkpoints under.
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to
- run with each step.
- layer_collection: LayerCollection instance describing model architecture.
- Used by K-FAC to construct preconditioner.
-
- Returns:
- final value for 'accuracy'.
-
- Raises:
- ValueError: if task_id >= num_worker_tasks.
- """
- sync_optimizer, optimizer, global_step = _make_distributed_train_op(
- task_id, num_worker_tasks, num_ps_tasks, layer_collection)
- _, cov_update_op, inv_update_ops, _, _, _ = optimizer.make_ops_and_vars()
- train_op = sync_optimizer.minimize(loss, global_step=global_step)
- inv_update_queue = oq.OpQueue(inv_update_ops)
-
- tf.logging.info("Starting training.")
- is_chief = (task_id == 0)
- hooks = [sync_optimizer.make_session_run_hook(is_chief)]
- with tf.train.MonitoredTrainingSession(
- master=master,
- is_chief=is_chief,
- checkpoint_dir=checkpoint_dir,
- hooks=hooks,
- stop_grace_period_secs=0) as sess:
- while not sess.should_stop():
- # Choose which op this task is responsible for running.
- if _is_gradient_task(task_id, num_worker_tasks):
- learning_op = train_op
- elif _is_cov_update_task(task_id, num_worker_tasks):
- learning_op = cov_update_op
- elif _is_inv_update_task(task_id, num_worker_tasks):
- # TODO(duckworthd): Running this op before cov_update_op has been run a
- # few times can result in "InvalidArgumentError: Cholesky decomposition
- # was not successful." Delay running this op until cov_update_op has
- # been run a few times.
- learning_op = inv_update_queue.next_op(sess)
- else:
- raise ValueError("Which op should task %d do?" % task_id)
-
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, learning_op])
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_,
- loss_, accuracy_)
-
- return accuracy_
-
-
-def train_mnist_single_machine(data_dir,
- num_epochs,
- use_fake_data=False,
- device="/gpu:0"):
- """Train a ConvNet on MNIST.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
- device: string, Either '/cpu:0' or '/gpu:0'. The covariance and inverse
- update ops are run on this device.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=128,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(
- examples, labels, num_labels=10, layer_collection=layer_collection)
-
- # Fit model.
- return minimize_loss_single_machine(
- loss, accuracy, layer_collection, device=device)
-
-
-def train_mnist_multitower(data_dir, num_epochs, num_towers,
- use_fake_data=True, devices=None):
- """Train a ConvNet on MNIST.
-
- Training data is split equally among the towers. Each tower computes loss on
- its own batch of data and the loss is aggregated on the CPU. The model
- variables are placed on first tower. The covariance and inverse update ops
- and variables are placed on GPUs in a round robin manner.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- num_towers: int. Number of CPUs to split inference across.
- use_fake_data: bool. If True, generate a synthetic dataset.
- devices: string, Either list of CPU or GPU. The covariance and inverse
- update ops are run on this device.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- if devices:
- device_count = {"GPU": num_towers}
- else:
- device_count = {"CPU": num_towers}
-
- devices = devices or [
- "/cpu:{}".format(tower_id) for tower_id in range(num_towers)
- ]
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- tower_batch_size = 128
- batch_size = tower_batch_size * num_towers
- tf.logging.info(
- ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
- "tower batch size.") % (batch_size, num_towers, tower_batch_size))
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=batch_size,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Split minibatch across towers.
- examples = tf.split(examples, num_towers)
- labels = tf.split(labels, num_towers)
-
- # Build an MLP. Each tower's layers will be added to the LayerCollection.
- layer_collection = lc.LayerCollection()
- tower_results = []
- for tower_id in range(num_towers):
- with tf.device(devices[tower_id]):
- with tf.name_scope("tower%d" % tower_id):
- with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
- tf.logging.info("Building tower %d." % tower_id)
- tower_results.append(
- build_model(examples[tower_id], labels[tower_id], 10,
- layer_collection))
- losses, accuracies = zip(*tower_results)
-
- # Average across towers.
- loss = tf.reduce_mean(losses)
- accuracy = tf.reduce_mean(accuracies)
-
- # Fit model.
-
- session_config = tf.ConfigProto(
- allow_soft_placement=False,
- device_count=device_count,
- )
-
- g_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=0.0001,
- cov_ema_decay=0.95,
- damping=0.001,
- layer_collection=layer_collection,
- placement_strategy="round_robin",
- cov_devices=devices,
- inv_devices=devices,
- momentum=0.9)
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- inverse_op = tf.cond(
- tf.equal(tf.mod(g_step, _INVERT_EVERY), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=g_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [g_step, loss, accuracy, train_op])
-
- if global_step_ % _INVERT_EVERY == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
- global_step_, loss_, accuracy_)
-
-
-def train_mnist_distributed_sync_replicas(task_id,
- is_chief,
- num_worker_tasks,
- num_ps_tasks,
- master,
- data_dir,
- num_epochs,
- op_strategy,
- use_fake_data=False):
- """Train a ConvNet on MNIST using Sync replicas optimizer.
-
- Args:
- task_id: int. Integer in [0, num_worker_tasks). ID for this worker.
- is_chief: `boolean`, `True` if the worker is chief worker.
- num_worker_tasks: int. Number of workers in this distributed training setup.
- num_ps_tasks: int. Number of parameter servers holding variables.
- master: string. IP and port of TensorFlow runtime process.
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- op_strategy: `string`, Strategy to run the covariance and inverse
- ops. If op_strategy == `chief_worker` then covariance and inverse
- update ops are run on chief worker otherwise they are run on dedicated
- workers.
-
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
-
- Raises:
- ValueError: If `op_strategy` not in ["chief_worker", "dedicated_workers"].
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=128,
- use_fake_data=use_fake_data,
- flatten_images=False)
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- with tf.device(tf.train.replica_device_setter(num_ps_tasks)):
- loss, accuracy = build_model(
- examples, labels, num_labels=10, layer_collection=layer_collection)
-
- # Fit model.
- checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac")
- if op_strategy == "chief_worker":
- return distributed_grads_only_and_ops_chief_worker(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection)
- elif op_strategy == "dedicated_workers":
- return distributed_grads_and_ops_dedicated_workers(
- task_id, is_chief, num_worker_tasks, num_ps_tasks, master,
- checkpoint_dir, loss, accuracy, layer_collection)
- else:
- raise ValueError("Only supported op strategies are : {}, {}".format(
- "chief_worker", "dedicated_workers"))
-
-
-if __name__ == "__main__":
- tf.app.run()
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
deleted file mode 100644
index b4c2d4a9e9..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_distributed_main.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Distributed training with sync replicas optimizer. See
-`convnet.train_mnist_distributed_sync_replicas` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_integer("task", -1, "Task identifier")
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-flags.DEFINE_string(
- "cov_inv_op_strategy", "chief_worker",
- "In dist training mode run the cov, inv ops on chief or dedicated workers."
-)
-flags.DEFINE_string("master", "local", "Session master.")
-flags.DEFINE_integer("ps_tasks", 2,
- "Number of tasks in the parameter server job.")
-flags.DEFINE_integer("replicas_to_aggregate", 5,
- "Number of replicas to aggregate.")
-flags.DEFINE_integer("worker_replicas", 5, "Number of replicas in worker job.")
-flags.DEFINE_integer("num_epochs", None, "Number of epochs.")
-
-
-def _is_chief():
- """Determines whether a job is the chief worker."""
- if "chief_worker" in FLAGS.brain_jobs:
- return FLAGS.brain_job_name == "chief_worker"
- else:
- return FLAGS.task == 0
-
-
-def main(unused_argv):
- _ = unused_argv
- convnet.train_mnist_distributed_sync_replicas(
- FLAGS.task, _is_chief(), FLAGS.worker_replicas, FLAGS.ps_tasks,
- FLAGS.master, FLAGS.data_dir, FLAGS.num_epochs, FLAGS.cov_inv_op_strategy)
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
deleted file mode 100644
index 4249bf8a8d..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_multi_tower_main.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Multi tower training mode. See `convnet.train_mnist_multitower` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/multitower_1/mnist", "local mnist dir")
-flags.DEFINE_integer("num_towers", 2,
- "Number of towers for multi tower training.")
-
-
-def main(unused_argv):
- _ = unused_argv
- assert FLAGS.num_towers > 1
- devices = ["/gpu:{}".format(tower_id) for tower_id in range(FLAGS.num_towers)]
- convnet.train_mnist_multitower(
- FLAGS.data_dir,
- num_epochs=200,
- num_towers=FLAGS.num_towers,
- devices=devices)
-
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
deleted file mode 100644
index 2c1f099360..0000000000
--- a/tensorflow/contrib/kfac/examples/convnet_mnist_single_main.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train a ConvNet on MNIST using K-FAC.
-
-Train on single machine. See `convnet.train_mnist_single_machine` for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from absl import flags
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import convnet
-
-FLAGS = flags.FLAGS
-flags.DEFINE_string("data_dir", "/tmp/mnist", "local mnist dir")
-
-
-def main(unused_argv):
- convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
- tf.app.run(main=main)
diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py
deleted file mode 100644
index ea2b252a05..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp.py
+++ /dev/null
@@ -1,354 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After
-~25k steps, this should reach perfect accuracy on the training set.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-lc = tf.contrib.kfac.layer_collection
-opt = tf.contrib.kfac.optimizer
-
-__all__ = [
- "fc_layer",
- "train_mnist",
- "train_mnist_multitower",
-]
-
-
-def fc_layer(layer_id, inputs, output_size):
- """Builds a fully connected layer.
-
- Args:
- layer_id: int. Integer ID for this layer's variables.
- inputs: Tensor of shape [num_examples, input_size]. Each row corresponds
- to a single example.
- output_size: int. Number of output dimensions after fully connected layer.
-
- Returns:
- preactivations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately before the activation function.
- activations: Tensor of shape [num_examples, output_size]. Values of the
- layer immediately after the activation function.
- params: Tuple of (weights, bias), parameters for this layer.
- """
- # TODO(b/67004004): Delete this function and rely on tf.layers exclusively.
- layer = tf.layers.Dense(
- output_size,
- kernel_initializer=tf.random_normal_initializer(),
- name="fc_%d" % layer_id)
- preactivations = layer(inputs)
- activations = tf.nn.tanh(preactivations)
-
- # layer.weights is a list. This converts it a (hashable) tuple.
- return preactivations, activations, (layer.kernel, layer.bias)
-
-
-def build_model(examples, labels, num_labels, layer_collection):
- """Builds an MLP classification model.
-
- Args:
- examples: Tensor of shape [num_examples, num_features]. Represents inputs of
- model.
- labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted
- by softmax for each example.
- num_labels: int. Number of distinct values 'labels' can take on.
- layer_collection: LayerCollection instance describing model architecture.
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensor representing model's accuracy.
- """
- # Build an MLP. For each layer, we'll keep track of the preactivations,
- # activations, weights, and bias.
- pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128)
- pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64)
- pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32)
- logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels)
- loss = tf.reduce_mean(
- tf.nn.sparse_softmax_cross_entropy_with_logits(
- labels=labels, logits=logits))
- accuracy = tf.reduce_mean(
- tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32))
-
- # Register parameters. K-FAC needs to know about the inputs, outputs, and
- # parameters of each layer and the logits powering the posterior probability
- # over classes.
- tf.logging.info("Building LayerCollection.")
- layer_collection.register_fully_connected(params0, examples, pre0)
- layer_collection.register_fully_connected(params1, act0, pre1)
- layer_collection.register_fully_connected(params2, act1, pre2)
- layer_collection.register_fully_connected(params3, act2, logits)
- layer_collection.register_categorical_predictive_distribution(
- logits, name="logits")
-
- return loss, accuracy
-
-
-def minimize(loss, accuracy, layer_collection, num_towers, session_config=None):
- """Minimize 'loss' with KfacOptimizer.
-
- Args:
- loss: 0-D Tensor. Loss to be minimized.
- accuracy: 0-D Tensor. Accuracy of classifier on current minibatch.
- layer_collection: LayerCollection instance. Describes layers in model.
- num_towers: int. Number of CPUs to split minibatch across.
- session_config: tf.ConfigProto. Configuration for tf.Session().
-
- Returns:
- accuracy of classifier on final minibatch.
- """
- devices = tuple("/cpu:%d" % tower_id for tower_id in range(num_towers))
-
- # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2
- # every 10k iterations.
- tf.logging.info("Building KFAC Optimizer.")
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=tf.train.exponential_decay(
- 0.00002, global_step, 10000, 0.5, staircase=True),
- cov_ema_decay=0.95,
- damping=0.0005,
- layer_collection=layer_collection,
- momentum=0.99,
- placement_strategy="round_robin",
- cov_devices=devices,
- inv_devices=devices)
-
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- # TODO(b/78537047): change (some) examples to use PeriodicInvCovUpdateKfacOpt
- # once that gets moved over? Could still leave more advanced examples as they
- # are (e.g. train_mnist_estimator in this file)
-
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- # We update the inverses only every 20 iterations.
- inverse_op = tf.cond(
- tf.equal(tf.mod(global_step, 100), 0),
- lambda: make_update_op(inv_update_thunks), tf.no_op)
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- tf.logging.info("Starting training.")
- with tf.train.MonitoredTrainingSession(config=session_config) as sess:
- while not sess.should_stop():
- global_step_, loss_, accuracy_, _ = sess.run(
- [global_step, loss, accuracy, train_op])
-
- if global_step_ % 100 == 0:
- tf.logging.info("global_step: %d | loss: %f | accuracy: %f",
- global_step_, loss_, accuracy_)
-
- return accuracy_
-
-
-def train_mnist(data_dir, num_epochs, use_fake_data=False):
- """Train an MLP on MNIST.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tf.logging.info("Loading MNIST into memory.")
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=64,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- # Build an MLP. The model's layers will be added to the LayerCollection.
- tf.logging.info("Building model.")
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(examples, labels, 10, layer_collection)
-
- # Fit model.
- minimize(loss, accuracy, layer_collection, 1)
-
-
-def train_mnist_multitower(data_dir,
- num_epochs,
- num_towers,
- use_fake_data=False):
- """Train an MLP on MNIST, splitting the minibatch across multiple towers.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- num_towers: int. Number of CPUs to split minibatch across.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
- # Load a dataset.
- tower_batch_size = 64
- batch_size = tower_batch_size * num_towers
- tf.logging.info(
- ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
- "tower batch size.") % (batch_size, num_towers, tower_batch_size))
- examples, labels = mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=batch_size,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- # Split minibatch across towers.
- examples = tf.split(examples, num_towers)
- labels = tf.split(labels, num_towers)
-
- # Build an MLP. Each tower's layers will be added to the LayerCollection.
- layer_collection = lc.LayerCollection()
- tower_results = []
- for tower_id in range(num_towers):
- with tf.device("/cpu:%d" % tower_id):
- with tf.name_scope("tower%d" % tower_id):
- with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)):
- tf.logging.info("Building tower %d." % tower_id)
- tower_results.append(
- build_model(examples[tower_id], labels[tower_id], 10,
- layer_collection))
- losses, accuracies = zip(*tower_results)
-
- # Average across towers.
- loss = tf.reduce_mean(losses)
- accuracy = tf.reduce_mean(accuracies)
-
- # Fit model.
- session_config = tf.ConfigProto(
- allow_soft_placement=False, device_count={
- "CPU": num_towers
- })
- return minimize(
- loss, accuracy, layer_collection, num_towers,
- session_config=session_config)
-
-
-def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False):
- """Train an MLP on MNIST using tf.estimator.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the training set.
- use_fake_data: bool. If True, generate a synthetic dataset.
-
- Returns:
- accuracy of model on the final minibatch of training data.
- """
-
- # Load a dataset.
- def input_fn():
- tf.logging.info("Loading MNIST into memory.")
- return mnist.load_mnist(
- data_dir,
- num_epochs=num_epochs,
- batch_size=64,
- flatten_images=True,
- use_fake_data=use_fake_data)
-
- def model_fn(features, labels, mode, params):
- """Model function for MLP trained with K-FAC.
-
- Args:
- features: Tensor of shape [batch_size, input_size]. Input features.
- labels: Tensor of shape [batch_size]. Target labels for training.
- mode: tf.estimator.ModeKey. Must be TRAIN.
- params: ignored.
-
- Returns:
- EstimatorSpec for training.
-
- Raises:
- ValueError: If 'mode' is anything other than TRAIN.
- """
- del params
-
- if mode != tf.estimator.ModeKeys.TRAIN:
- raise ValueError("Only training is supposed with this API.")
-
- # Build a ConvNet.
- layer_collection = lc.LayerCollection()
- loss, accuracy = build_model(
- features, labels, num_labels=10, layer_collection=layer_collection)
-
- # Train with K-FAC.
- global_step = tf.train.get_or_create_global_step()
- optimizer = opt.KfacOptimizer(
- learning_rate=tf.train.exponential_decay(
- 0.00002, global_step, 10000, 0.5, staircase=True),
- cov_ema_decay=0.95,
- damping=0.0001,
- layer_collection=layer_collection,
- momentum=0.99)
-
- (cov_update_thunks,
- inv_update_thunks) = optimizer.make_vars_and_create_op_thunks()
-
- def make_update_op(update_thunks):
- update_ops = [thunk() for thunk in update_thunks]
- return tf.group(*update_ops)
-
- def make_batch_executed_op(update_thunks, batch_size=1):
- return tf.group(*tf.contrib.kfac.utils.batch_execute(
- global_step, update_thunks, batch_size=batch_size))
-
- # Run cov_update_op every step. Run 1 inv_update_ops per step.
- cov_update_op = make_update_op(cov_update_thunks)
- with tf.control_dependencies([cov_update_op]):
- # But make sure to execute all the inverse ops on the first step
- inverse_op = tf.cond(tf.equal(global_step, 0),
- lambda: make_update_op(inv_update_thunks),
- lambda: make_batch_executed_op(inv_update_thunks))
- with tf.control_dependencies([inverse_op]):
- train_op = optimizer.minimize(loss, global_step=global_step)
-
- # Print metrics every 5 sec.
- hooks = [
- tf.train.LoggingTensorHook(
- {
- "loss": loss,
- "accuracy": accuracy
- }, every_n_secs=5),
- ]
- return tf.estimator.EstimatorSpec(
- mode=mode, loss=loss, train_op=train_op, training_hooks=hooks)
-
- run_config = tf.estimator.RunConfig(
- model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100)
-
- # Train until input_fn() is empty with Estimator. This is a prerequisite for
- # TPU compatibility.
- estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
- estimator.train(input_fn=input_fn)
diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
deleted file mode 100644
index 9c34ade1d2..0000000000
--- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-r"""Train an MLP on MNIST using K-FAC.
-
-See mlp.py for details.
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import sys
-
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-FLAGS = None
-
-
-def main(argv):
- _ = argv
- if FLAGS.use_estimator:
- if FLAGS.num_towers != 1:
- raise ValueError("Only 1 device supported in tf.estimator example.")
- mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200)
- elif FLAGS.num_towers > 1:
- mlp.train_mnist_multitower(
- FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers)
- else:
- mlp.train_mnist(FLAGS.data_dir, num_epochs=200)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--data_dir",
- type=str,
- default="/tmp/mnist",
- help="Directory to store dataset in.")
- parser.add_argument(
- "--num_towers",
- type=int,
- default=1,
- help="Number of CPUs to split minibatch across.")
- parser.add_argument(
- "--use_estimator",
- action="store_true",
- help="Use tf.estimator API to train.")
- FLAGS, unparsed = parser.parse_known_args()
- tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py
deleted file mode 100644
index 547c4ab25d..0000000000
--- a/tensorflow/contrib/kfac/examples/mnist.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utilities for loading MNIST into TensorFlow."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-__all__ = [
- 'load_mnist',
-]
-
-
-def load_mnist(data_dir,
- num_epochs,
- batch_size,
- flatten_images=True,
- use_fake_data=False):
- """Loads MNIST dataset into memory.
-
- Args:
- data_dir: string. Directory to read MNIST examples from.
- num_epochs: int. Number of passes to make over the dataset.
- batch_size: int. Number of examples per minibatch.
- flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into
- [784]-shaped vectors.
- use_fake_data: bool. If True, generate a synthetic dataset rather than
- reading MNIST in.
-
- Returns:
- examples: Tensor of shape [batch_size, 784] if 'flatten_images' is
- True, else [batch_size, 28, 28, 1]. Each row is one example.
- Values in [0, 1].
- labels: Tensor of shape [batch_size]. Indices of integer corresponding to
- each example. Values in {0...9}.
- """
- if use_fake_data:
- rng = np.random.RandomState(42)
- num_examples = batch_size * 4
- images = rng.rand(num_examples, 28 * 28)
- if not flatten_images:
- images = np.reshape(images, [num_examples, 28, 28, 1])
- labels = rng.randint(10, size=num_examples)
- else:
- mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets(
- data_dir, reshape=flatten_images)
- num_examples = len(mnist_data.train.labels)
- images = mnist_data.train.images
- labels = mnist_data.train.labels
-
- dataset = tf.data.Dataset.from_tensor_slices((np.asarray(
- images, dtype=np.float32), np.asarray(labels, dtype=np.int64)))
- return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size)
- .make_one_shot_iterator().get_next())
diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD
deleted file mode 100644
index ede7f183fe..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/BUILD
+++ /dev/null
@@ -1,52 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "mlp_test",
- size = "large",
- srcs = ["mlp_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "notsan",
- ],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac/examples:mlp",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "convnet_test",
- size = "large",
- srcs = ["convnet_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "notsan",
- ],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac",
- "//tensorflow/contrib/kfac/examples:convnet",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "mnist_test",
- srcs = ["mnist_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow:tensorflow_py",
- "//tensorflow/contrib/kfac/examples:mnist",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py
deleted file mode 100644
index adecda7166..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py
+++ /dev/null
@@ -1,166 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for convnet.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac import layer_collection as lc
-from tensorflow.contrib.kfac.examples import convnet
-
-
-class ConvNetTest(tf.test.TestCase):
-
- def testConvLayer(self):
- with tf.Graph().as_default():
- pre, act, (w, b) = convnet.conv_layer(
- layer_id=1,
- inputs=tf.zeros([5, 3, 3, 2]),
- kernel_size=3,
- out_channels=5)
- self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre)
- self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act)
- self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("conv_1", w.op.name)
- self.assertIn("conv_1", b.op.name)
-
- def testMaxPoolLayer(self):
- with tf.Graph().as_default():
- act = convnet.max_pool_layer(
- layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3)
- self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act)
- self.assertEqual(act.op.name, "pool_1/pool")
-
- def testLinearLayer(self):
- with tf.Graph().as_default():
- act, (w, b) = convnet.linear_layer(
- layer_id=1, inputs=tf.zeros([5, 20]), output_size=5)
- self.assertShapeEqual(np.zeros([5, 5]), act)
- self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("fc_1", w.op.name)
- self.assertIn("fc_1", b.op.name)
-
- def testBuildModel(self):
- with tf.Graph().as_default():
- x = tf.placeholder(tf.float32, [None, 6, 6, 3])
- y = tf.placeholder(tf.int64, [None])
- layer_collection = lc.LayerCollection()
- loss, accuracy = convnet.build_model(
- x, y, num_labels=5, layer_collection=layer_collection)
-
- # Ensure layers and logits were registered.
- self.assertEqual(len(layer_collection.fisher_blocks), 3)
- self.assertEqual(len(layer_collection.losses), 1)
-
- # Ensure inference doesn't crash.
- with self.test_session() as sess:
- sess.run(tf.global_variables_initializer())
- feed_dict = {
- x: np.random.randn(10, 6, 6, 3).astype(np.float32),
- y: np.random.randint(5, size=10).astype(np.int64),
- }
- sess.run([loss, accuracy], feed_dict=feed_dict)
-
- def _build_toy_problem(self):
- """Construct a toy linear regression problem.
-
- Initial loss should be,
- 2.5 = 0.5 * (1^2 + 2^2)
-
- Returns:
- loss: 0-D Tensor representing loss to be minimized.
- accuracy: 0-D Tensors representing model accuracy.
- layer_collection: LayerCollection instance describing model architecture.
- """
- x = np.asarray([[1.], [2.]]).astype(np.float32)
- y = np.asarray([1., 2.]).astype(np.float32)
- x, y = (tf.data.Dataset.from_tensor_slices((x, y))
- .repeat(100).batch(2).make_one_shot_iterator().get_next())
- w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer())
- y_hat = tf.matmul(x, w)
- loss = tf.reduce_mean(0.5 * tf.square(y_hat - y))
- accuracy = loss
-
- layer_collection = lc.LayerCollection()
- layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat)
- layer_collection.register_normal_predictive_distribution(y_hat)
-
- return loss, accuracy, layer_collection
-
- def testMinimizeLossSingleMachine(self):
- with tf.Graph().as_default():
- loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.minimize_loss_single_machine(
- loss, accuracy, layer_collection, device="/cpu:0")
- self.assertLess(accuracy_, 2.0)
-
- def testMinimizeLossDistributed(self):
- with tf.Graph().as_default():
- loss, accuracy, layer_collection = self._build_toy_problem()
- accuracy_ = convnet.distributed_grads_only_and_ops_chief_worker(
- task_id=0,
- is_chief=True,
- num_worker_tasks=1,
- num_ps_tasks=0,
- master="",
- checkpoint_dir=None,
- loss=loss,
- accuracy=accuracy,
- layer_collection=layer_collection)
- self.assertLess(accuracy_, 2.0)
-
- def testTrainMnistSingleMachine(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- #
- # Ideally, we should check that accuracy increases as the model converges,
- # but there are too few parameters for the model to effectively memorize
- # the training set the way an MLP can.
- convnet.train_mnist_single_machine(
- data_dir=None, num_epochs=1, use_fake_data=True, device="/cpu:0")
-
- def testTrainMnistMultitower(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- convnet.train_mnist_multitower(
- data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
-
- def testTrainMnistDistributed(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- convnet.train_mnist_distributed_sync_replicas(
- task_id=0,
- is_chief=True,
- num_worker_tasks=1,
- num_ps_tasks=0,
- master="",
- data_dir=None,
- num_epochs=2,
- op_strategy="chief_worker",
- use_fake_data=True)
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py
deleted file mode 100644
index 22da6c29f1..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for mlp.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mlp
-
-
-class MlpTest(tf.test.TestCase):
-
- def testFcLayer(self):
- with tf.Graph().as_default():
- pre, act, (w, b) = mlp.fc_layer(
- layer_id=1, inputs=tf.zeros([5, 3]), output_size=10)
- self.assertShapeEqual(np.zeros([5, 10]), pre)
- self.assertShapeEqual(np.zeros([5, 10]), act)
- self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w))
- self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b))
- self.assertIsInstance(w, tf.Variable)
- self.assertIsInstance(b, tf.Variable)
- self.assertIn("fc_1/", w.op.name)
- self.assertIn("fc_1/", b.op.name)
-
- def testTrainMnist(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- #
- # Ideally, we should check that accuracy increases as the model converges,
- # but that takes a non-trivial amount of compute.
- mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True)
-
- def testTrainMnistMultitower(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- mlp.train_mnist_multitower(
- data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True)
-
- def testTrainMnistEstimator(self):
- with tf.Graph().as_default():
- # Ensure model training doesn't crash.
- mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True)
-
-
-if __name__ == "__main__":
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py
deleted file mode 100644
index 92f8462357..0000000000
--- a/tensorflow/contrib/kfac/examples/tests/mnist_test.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for mnist.py."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-
-from tensorflow.contrib.kfac.examples import mnist
-
-
-class MnistTest(tf.test.TestCase):
-
- def testValues(self):
- """Ensure values are in their expected range."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertTrue(np.all((0 <= examples_) & (examples_ < 1)))
- self.assertTrue(np.all((0 <= labels_) & (labels_ < 10)))
-
- def testFlattenedShapes(self):
- """Ensure images are flattened into their appropriate shape."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None,
- num_epochs=1,
- batch_size=64,
- flatten_images=True,
- use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertEqual(examples_.shape, (64, 784))
- self.assertEqual(labels_.shape, (64,))
-
- def testNotFlattenedShapes(self):
- """Ensure non-flattened images are their appropriate shape."""
- with tf.Graph().as_default():
- examples, labels = mnist.load_mnist(
- data_dir=None,
- num_epochs=1,
- batch_size=64,
- flatten_images=False,
- use_fake_data=True)
-
- with self.test_session() as sess:
- examples_, labels_ = sess.run([examples, labels])
- self.assertEqual(examples_.shape, (64, 28, 28, 1))
- self.assertEqual(labels_.shape, (64,))
-
-
-if __name__ == '__main__':
- tf.test.main()
diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png
deleted file mode 100644
index 20f93c7703..0000000000
--- a/tensorflow/contrib/kfac/g3doc/autoencoder.png
+++ /dev/null
Binary files differ
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD
deleted file mode 100644
index 6e4a8d71ba..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD
+++ /dev/null
@@ -1,160 +0,0 @@
-package(default_visibility = ["//visibility:private"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "estimator_test",
- srcs = ["estimator_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_estimator",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fisher_factors_test",
- srcs = ["fisher_factors_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fisher_blocks_test",
- srcs = ["fisher_blocks_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/contrib/kfac/python/ops:linear_operator",
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "layer_collection_test",
- srcs = ["layer_collection_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_blocks",
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variable_scope",
- ],
-)
-
-py_test(
- name = "optimizer_test",
- srcs = ["optimizer_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:fisher_factors",
- "//tensorflow/contrib/kfac/python/ops:kfac_optimizer",
- "//tensorflow/contrib/kfac/python/ops:layer_collection",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:nn",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "utils_test",
- srcs = ["utils_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_windows"], # TODO: needs investigation on Windows
- deps = [
- "//tensorflow/contrib/kfac/python/ops:utils",
- "//tensorflow/contrib/tpu",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "op_queue_test",
- srcs = ["op_queue_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:op_queue",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- ],
-)
-
-py_test(
- name = "loss_functions_test",
- srcs = ["loss_functions_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/kfac/python/ops:loss_functions",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
deleted file mode 100644
index 76b31a5730..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py
+++ /dev/null
@@ -1,310 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.estimator."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import estimator
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import training_util
-
-_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"]
-
-
-class EstimatorTest(test.TestCase):
-
- def setUp(self):
- self._graph = ops.Graph()
- with self._graph.as_default():
- self.layer_collection = lc.LayerCollection()
-
- self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
- self.weights = variable_scope.get_variable(
- "w", shape=(2, 2), dtype=dtypes.float32)
- self.bias = variable_scope.get_variable(
- "b", initializer=init_ops.zeros_initializer(), shape=(2, 1))
- self.output = math_ops.matmul(self.inputs, self.weights) + self.bias
-
- # Only register the weights.
- self.layer_collection.register_fully_connected(
- params=(self.weights,), inputs=self.inputs, outputs=self.output)
-
- self.outputs = math_ops.tanh(self.output)
- self.targets = array_ops.zeros_like(self.outputs)
- self.layer_collection.register_categorical_predictive_distribution(
- logits=self.outputs, targets=self.targets)
-
- def testEstimatorInitManualRegistration(self):
- with self._graph.as_default():
- # We should be able to build an estimator for only the registered vars.
- estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection
- )
-
- # Check that we throw an error if we try to build an estimator for vars
- # that were not manually registered.
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights, self.bias],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection
- )
- est.make_vars_and_create_op_thunks()
-
- # Check that we throw an error if we don't include registered variables,
- # i.e. self.weights
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection)
- est.make_vars_and_create_op_thunks()
-
- @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42)
- def testVariableWrongNumberOfUses(self, mock_uses):
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection)
- est.make_vars_and_create_op_thunks()
-
- def testInvalidEstimationMode(self):
- with self.assertRaises(ValueError):
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="not_a_real_mode")
- est.make_vars_and_create_op_thunks()
-
- def testGradientsModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="gradients")
- est.make_vars_and_create_op_thunks()
-
- def testEmpiricalModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="empirical")
- est.make_vars_and_create_op_thunks()
-
- def testCurvaturePropModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="curvature_prop")
- est.make_vars_and_create_op_thunks()
-
- def testExactModeBuild(self):
- with self._graph.as_default():
- est = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- cov_ema_decay=0.1,
- damping=0.2,
- layer_collection=self.layer_collection,
- estimation_mode="exact")
- est.make_vars_and_create_op_thunks()
-
- def test_cov_update_thunks(self):
- """Ensures covariance update ops run once per global_step."""
- with self._graph.as_default(), self.cached_session() as sess:
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0)
-
- # Construct an op that executes one covariance update per step.
- global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, cov_update_op_thunks, _,
- _) = fisher_estimator.create_ops_and_vars_thunks()
- for thunk in cov_variable_thunks:
- thunk()
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- cov_update_op = control_flow_ops.case(
- [(math_ops.equal(global_step, i), thunk)
- for i, thunk in enumerate(cov_update_op_thunks)])
- increment_global_step = global_step.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
- initial_cov_values = sess.run(cov_matrices)
-
- # Ensure there's one update per covariance matrix.
- self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))
-
- # Test is no-op if only 1 covariance matrix.
- assert len(cov_matrices) > 1
-
- for i in range(len(cov_matrices)):
- # Compare new and old covariance values
- new_cov_values = sess.run(cov_matrices)
- is_cov_equal = [
- np.allclose(initial_cov_value, new_cov_value)
- for (initial_cov_value,
- new_cov_value) in zip(initial_cov_values, new_cov_values)
- ]
- num_cov_equal = sum(is_cov_equal)
-
- # Ensure exactly one covariance matrix changes per step.
- self.assertEqual(num_cov_equal, len(cov_matrices) - i)
-
- # Run all covariance update ops.
- sess.run(cov_update_op)
- sess.run(increment_global_step)
-
- def test_round_robin_placement(self):
- """Check if the ops and variables are placed on devices correctly."""
- with self._graph.as_default():
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0,
- cov_devices=["/cpu:{}".format(i) for i in range(2)],
- inv_devices=["/cpu:{}".format(i) for i in range(2)])
-
- # Construct an op that executes one covariance update per step.
- (cov_update_thunks,
- inv_update_thunks) = fisher_estimator.make_vars_and_create_op_thunks(
- scope="test")
- cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
- inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
- self.assertEqual(cov_update_ops[0].device, "/device:CPU:0")
- self.assertEqual(cov_update_ops[1].device, "/device:CPU:1")
- self.assertEqual(inv_update_ops[0].device, "/device:CPU:0")
- self.assertEqual(inv_update_ops[1].device, "/device:CPU:1")
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- inv_matrices = [
- matrix
- for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._matpower_by_exp_and_damping.values()
- ]
- self.assertEqual(cov_matrices[0].device, "/device:CPU:0")
- self.assertEqual(cov_matrices[1].device, "/device:CPU:1")
- # Inverse matrices need to be explicitly placed.
- self.assertEqual(inv_matrices[0].device, "")
- self.assertEqual(inv_matrices[1].device, "")
-
- def test_inv_update_thunks(self):
- """Ensures inverse update ops run once per global_step."""
- with self._graph.as_default(), self.cached_session() as sess:
- fisher_estimator = estimator.FisherEstimatorRoundRobin(
- variables=[self.weights],
- layer_collection=self.layer_collection,
- damping=0.2,
- cov_ema_decay=0.0)
-
- # Construct op that updates one inverse per global step.
- global_step = training_util.get_or_create_global_step()
- (cov_variable_thunks, _, inv_variable_thunks,
- inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
- for thunk in cov_variable_thunks:
- thunk()
- for thunk in inv_variable_thunks:
- thunk()
- inv_matrices = [
- matrix
- for fisher_factor in self.layer_collection.get_factors()
- for matrix in fisher_factor._matpower_by_exp_and_damping.values()
- ]
- inv_update_op = control_flow_ops.case(
- [(math_ops.equal(global_step, i), thunk)
- for i, thunk in enumerate(inv_update_op_thunks)])
- increment_global_step = global_step.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
- initial_inv_values = sess.run(inv_matrices)
-
- # Ensure there's one update per inverse matrix. This is true as long as
- # there's no fan-in/fan-out or parameter re-use.
- self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))
-
- # Test is no-op if only 1 invariance matrix.
- assert len(inv_matrices) > 1
-
- # Assign each covariance matrix a value other than the identity. This
- # ensures that the inverse matrices are updated to something different as
- # well.
- cov_matrices = [
- fisher_factor.get_cov()
- for fisher_factor in self.layer_collection.get_factors()
- ]
- sess.run([
- cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
- for cov_matrix in cov_matrices
- ])
-
- for i in range(len(inv_matrices)):
- # Compare new and old inverse values
- new_inv_values = sess.run(inv_matrices)
- is_inv_equal = [
- np.allclose(initial_inv_value, new_inv_value)
- for (initial_inv_value,
- new_inv_value) in zip(initial_inv_values, new_inv_values)
- ]
- num_inv_equal = sum(is_inv_equal)
-
- # Ensure exactly one inverse matrix changes per step.
- self.assertEqual(num_inv_equal, len(inv_matrices) - i)
-
- # Run all inverse update ops.
- sess.run(inv_update_op)
- sess.run(increment_global_step)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
deleted file mode 100644
index f845def507..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py
+++ /dev/null
@@ -1,1018 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_blocks."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-# TODO(b/78538100): As far as I can tell, all the tests that say "Make sure our
-# inverse is something other than the identity" are actually broken. They never
-# run the covariance update ops and so the inverse actually is the identity
-# (possible plus the damping term, which would still make it a multiple of the
-# identity).
-
-
-def _make_psd(dim):
- """Constructs a PSD matrix of the given dimension."""
- mat = np.ones((dim, dim), dtype=np.float32)
- mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim)
- return array_ops.constant(mat)
-
-
-class UtilsTest(test.TestCase):
-
- def testComputePiTracenorm(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- diag = ops.convert_to_tensor([1., 2., 0., 1.])
- left_factor = lo.LinearOperatorDiag(diag)
- right_factor = lo.LinearOperatorFullMatrix(array_ops.ones([2, 2]))
-
- # pi is the sqrt of the left trace norm divided by the right trace norm
- pi = fb.compute_pi_tracenorm(left_factor, right_factor)
-
- pi_val = sess.run(pi)
- self.assertEqual(1., pi_val)
-
-
-class FullFBTest(test.TestCase):
-
- def testFullFBInitSingleTensor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testFullFBInitTensorTuple(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors(grads, 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(3,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.constant([[1.], [2.]])
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = params**2
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(2,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.FullFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (array_ops.constant([2., 3.]), array_ops.constant(4.))
- damping = 0.5
- block.instantiate_factors((grads,), damping)
- block._factor.instantiate_cov_variables()
- block.register_inverse()
- block._factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(state_ops.assign(block._factor._cov, _make_psd(3)))
- sess.run(block._factor.make_inverse_update_ops())
-
- v_flat = np.array([4., 5., 6.], dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class NaiveDiagonalFBTest(test.TestCase):
-
- def testNaiveDiagonalFBInitSingleTensor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testNaiveDiagonalFBInitTensorTuple(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- self.assertAllEqual(params, block.tensors_to_compute_grads())
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
-
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors(grads, 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
-
- vector = array_ops.ones(3,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.constant([[1.], [2.]])
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = params**2
- block.instantiate_factors((grads,), 0.5)
- block._factor.instantiate_cov_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_inverse_update_ops())
- vector = array_ops.ones(2,) * 2
- output = block.multiply_inverse(vector)
-
- self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = (array_ops.constant([1., 2.]), array_ops.constant(3.))
- block = fb.NaiveDiagonalFB(lc.LayerCollection(), params)
- block.register_additional_tower(32)
- grads = (params[0]**2, math_ops.sqrt(params[1]))
- damping = 0.5
- block.instantiate_factors((grads,), damping)
- block._factor.instantiate_cov_variables()
-
- cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1])
- sess.run(state_ops.assign(block._factor._cov, cov))
- sess.run(block._factor.make_inverse_update_ops())
-
- v_flat = np.array([4., 5., 6.], dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat)
- self.assertAllClose(output_flat, explicit)
-
-
-class FullyConnectedDiagonalFBTest(test.TestCase):
-
- def setUp(self):
- super(FullyConnectedDiagonalFBTest, self).setUp()
-
- self.batch_size = 4
- self.input_size = 6
- self.output_size = 3
-
- self.inputs = np.random.randn(self.batch_size, self.input_size).astype(
- np.float32)
- self.outputs = np.zeros([self.batch_size, self.output_size]).astype(
- np.float32)
- self.output_grads = np.random.randn(self.batch_size,
- self.output_size).astype(np.float32)
- self.w = np.random.randn(self.input_size, self.output_size).astype(
- np.float32)
- self.b = np.random.randn(self.output_size).astype(np.float32)
-
- def fisherApprox(self, has_bias=False):
- """Fisher approximation using default inputs."""
- if has_bias:
- inputs = np.concatenate(
- [self.inputs, np.ones([self.batch_size, 1])], axis=1)
- else:
- inputs = self.inputs
- return self.buildDiagonalFisherApproximation(inputs, self.output_grads)
-
- def buildDiagonalFisherApproximation(self, inputs, output_grads):
- """Builds explicit diagonal Fisher approximation.
-
- Fisher's diagonal is (d loss / d w)'s elements squared for
- d/dw = E[outer(input, output_grad)]
-
- where the expectation is taken over examples.
-
- Args:
- inputs: np.array of shape [batch_size, input_size].
- output_grads: np.array of shape [batch_size, output_size].
-
- Returns:
- Diagonal np.array of shape [num_params, num_params] for num_params =
- input_size * output_size.
- """
- batch_size = inputs.shape[0]
- assert output_grads.shape[0] == batch_size
- input_size = inputs.shape[1]
- output_size = output_grads.shape[1]
- fisher_diag = np.zeros((input_size, output_size))
- for i in range(batch_size):
- fisher_diag += np.square(np.outer(inputs[i], output_grads[i]))
- return np.diag(fisher_diag.flatten()) / batch_size
-
- def testMultiply(self):
- result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct Fisher-vector product.
- expected_result = self.fisherApprox().dot(self.w.flatten())
- expected_result = expected_result.reshape(
- [self.input_size, self.output_size])
-
- self.assertAllClose(expected_result, result)
-
- def testMultiplyInverse(self):
- _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct inverse Fisher-vector product.
- expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
- expected_result = expected_result.reshape(
- [self.input_size, self.output_size])
-
- self.assertAllClose(expected_result, result)
-
- def testRegisterAdditionalTower(self):
- """Ensure 1 big tower and 2 small towers are equivalent."""
- multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
- self.w, [self.inputs], [self.outputs], [self.output_grads])
- multiply_result_small, multiply_inverse_result_small = (
- self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
- np.split(self.outputs, 2),
- np.split(self.output_grads, 2)))
-
- self.assertAllClose(multiply_result_big, multiply_result_small)
- self.assertAllClose(multiply_inverse_result_big,
- multiply_inverse_result_small)
-
- def testMultiplyHasBias(self):
- result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
- [self.outputs], [self.output_grads])
- expected_result = self.fisherApprox(True).dot(
- np.concatenate([self.w.flatten(), self.b.flatten()]))
- expected_result = expected_result.reshape(
- [self.input_size + 1, self.output_size])
- expected_result = (expected_result[:-1], expected_result[-1])
-
- self.assertEqual(len(result), 2)
- self.assertAllClose(expected_result[0], result[0])
- self.assertAllClose(expected_result[1], result[1])
-
- def runFisherBlockOps(self, params, inputs, outputs, output_grads):
- """Run Ops guaranteed by FisherBlock interface.
-
- Args:
- params: Tensor or 2-tuple of Tensors. Represents weights or weights and
- bias of this layer.
- inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
- layer.
- outputs: list of Tensors of shape [batch_size, output_size].
- Preactivations produced by layer.
- output_grads: list of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to 'outputs'.
-
- Returns:
- multiply_result: Result of FisherBlock.multiply(params)
- multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
- """
- with ops.Graph().as_default(), self.cached_session() as sess:
- inputs = as_tensors(inputs)
- outputs = as_tensors(outputs)
- output_grads = as_tensors(output_grads)
- params = as_tensors(params)
-
- block = fb.FullyConnectedDiagonalFB(
- lc.LayerCollection(), has_bias=isinstance(params, (tuple, list)))
- for (i, o) in zip(inputs, outputs):
- block.register_additional_tower(i, o)
-
- block.instantiate_factors((output_grads,), damping=0.0)
- block._factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_covariance_update_op(0.0))
- multiply_result = sess.run(block.multiply(params))
- multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
- return multiply_result, multiply_inverse_result
-
-
-class EmbeddingKFACFBTest(test.TestCase):
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
-
- # Create a Fisher Block.
- vocab_size = 5
- block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
- # Add some examples.
- inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
- outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_tower(inputs, outputs)
-
- # Instantiate factor's variables. Ensure it doesn't fail.
- grads = outputs**2.
- damping = array_ops.constant(0.)
- block.instantiate_factors(((grads,),), damping)
-
- def testMultiplyInverse(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
-
- # Create a Fisher Block.
- vocab_size = 5
- block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size)
-
- # Add some examples.
- inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]])
- outputs = array_ops.constant([[0.], [1.], [2.]])
- block.register_additional_tower(inputs, outputs)
-
- # Instantiate factor's variables. Ensure it doesn't fail.
- grads = outputs**2.
- damping = array_ops.constant(0.)
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Create a sparse update.
- indices = array_ops.constant([1, 3, 4])
- values = array_ops.constant([[1.], [1.], [1.]])
- sparse_vector = ops.IndexedSlices(
- values, indices, dense_shape=[vocab_size, 1])
- dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1])
-
- # Compare Fisher-vector product against explicit result.
- result = block.multiply_inverse(sparse_vector)
- expected_result = linalg_ops.matrix_solve(block.full_fisher_block(),
- dense_vector)
-
- sess.run(tf_variables.global_variables_initializer())
- self.assertAlmostEqual(
- sess.run(expected_result[1]), sess.run(result.values[0]))
- self.assertAlmostEqual(
- sess.run(expected_result[3]), sess.run(result.values[1]))
- self.assertAlmostEqual(
- sess.run(expected_result[4]), sess.run(result.values[2]))
-
-
-class FullyConnectedKFACBasicFBTest(test.TestCase):
-
- def testFullyConnectedKFACBasicFBInit(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([1., 2.])
- outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection())
- block.register_additional_tower(inputs, outputs)
-
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
-
- def testInstantiateFactorsHasBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True)
- block.register_additional_tower(inputs, outputs)
-
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- def testInstantiateFactorsNoBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
-
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
-
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = (
- np.arange(2, 6).reshape(2, 2).astype(np.float32), #
- np.arange(1, 3).reshape(2, 1).astype(np.float32))
- output = block.multiply_inverse((array_ops.constant(vector[0]),
- array_ops.constant(vector[1])))
-
- output = sess.run(output)
- self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
- output[0])
- self.assertAllClose([0.343146, 0.686291], output[1])
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(2, 6).reshape(2, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]],
- sess.run(output))
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- input_dim, output_dim = 3, 2
- inputs = array_ops.zeros([32, input_dim])
- outputs = array_ops.zeros([32, output_dim])
- params = array_ops.zeros([input_dim, output_dim])
- block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False)
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- damping = 0. # This test is only valid without damping.
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
-
- sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3)))
- sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
-
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- v_flat = np.arange(6, dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class ConvDiagonalFBTest(test.TestCase):
-
- def setUp(self):
- super(ConvDiagonalFBTest, self).setUp()
-
- self.batch_size = 2
- self.height = 8
- self.width = 4
- self.input_channels = 6
- self.output_channels = 3
- self.kernel_size = 1
-
- self.inputs = np.random.randn(self.batch_size, self.height, self.width,
- self.input_channels).astype(np.float32)
- self.outputs = np.zeros(
- [self.batch_size, self.height, self.width,
- self.output_channels]).astype(np.float32)
- self.output_grads = np.random.randn(
- self.batch_size, self.height, self.width, self.output_channels).astype(
- np.float32)
- self.w = np.random.randn(self.kernel_size, self.kernel_size,
- self.input_channels, self.output_channels).astype(
- np.float32)
- self.b = np.random.randn(self.output_channels).astype(np.float32)
-
- def fisherApprox(self, has_bias=False):
- """Fisher approximation using default inputs."""
- if has_bias:
- inputs = np.concatenate(
- [self.inputs,
- np.ones([self.batch_size, self.height, self.width, 1])],
- axis=-1)
- else:
- inputs = self.inputs
- return self.buildDiagonalFisherApproximation(inputs, self.output_grads,
- self.kernel_size)
-
- def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size):
- r"""Builds explicit diagonal Fisher approximation.
-
- Fisher's diagonal is (d loss / d w)'s elements squared for
- d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})]
-
- where the expectation is taken over examples and the sum over (x, y)
- locations upon which the convolution is applied.
-
- Args:
- inputs: np.array of shape [batch_size, height, width, input_channels].
- output_grads: np.array of shape [batch_size, height, width,
- output_channels].
- kernel_size: int. height and width of kernel.
-
- Returns:
- Diagonal np.array of shape [num_params, num_params] for num_params =
- kernel_size^2 * input_channels * output_channels.
- """
- batch_size, height, width, input_channels = inputs.shape
- assert output_grads.shape[0] == batch_size
- assert output_grads.shape[1] == height
- assert output_grads.shape[2] == width
- output_channels = output_grads.shape[3]
-
- # If kernel_size == 1, then we don't need to worry about capturing context
- # around the pixel upon which a convolution is applied. This makes testing
- # easier.
- assert kernel_size == 1, "kernel_size != 1 isn't supported."
- num_locations = height * width
- inputs = np.reshape(inputs, [batch_size, num_locations, input_channels])
- output_grads = np.reshape(output_grads,
- [batch_size, num_locations, output_channels])
-
- fisher_diag = np.zeros((input_channels, output_channels))
- for i in range(batch_size):
- # Each example's approximation is a square(sum-of-outer-products).
- example_fisher_diag = np.zeros((input_channels, output_channels))
- for j in range(num_locations):
- example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j])
- fisher_diag += np.square(example_fisher_diag)
-
- # Normalize by batch_size (not num_locations).
- return np.diag(fisher_diag.flatten()) / batch_size
-
- def testMultiply(self):
- result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct Fisher-vector product.
- expected_result = self.fisherApprox().dot(self.w.flatten())
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels,
- self.output_channels
- ])
-
- self.assertAllClose(expected_result, result)
-
- def testMultiplyInverse(self):
- _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs],
- [self.output_grads])
-
- # Construct inverse Fisher-vector product.
- expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten())
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels,
- self.output_channels
- ])
-
- self.assertAllClose(expected_result, result, atol=1e-3)
-
- def testRegisterAdditionalTower(self):
- """Ensure 1 big tower and 2 small towers are equivalent."""
- multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps(
- self.w, [self.inputs], [self.outputs], [self.output_grads])
- multiply_result_small, multiply_inverse_result_small = (
- self.runFisherBlockOps(self.w, np.split(self.inputs, 2),
- np.split(self.outputs, 2),
- np.split(self.output_grads, 2)))
-
- self.assertAllClose(multiply_result_big, multiply_result_small)
- self.assertAllClose(multiply_inverse_result_big,
- multiply_inverse_result_small)
-
- def testMultiplyHasBias(self):
- result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs],
- [self.outputs], [self.output_grads])
- # Clone 'b' along 'input_channels' dimension.
- b_filter = np.tile(
- np.reshape(self.b, [1, 1, 1, self.output_channels]),
- [self.kernel_size, self.kernel_size, 1, 1])
- params = np.concatenate([self.w, b_filter], axis=2)
- expected_result = self.fisherApprox(True).dot(params.flatten())
-
- # Extract 'b' from concatenated parameters.
- expected_result = expected_result.reshape([
- self.kernel_size, self.kernel_size, self.input_channels + 1,
- self.output_channels
- ])
- expected_result = (expected_result[:, :, 0:-1, :],
- np.reshape(expected_result[:, :, -1, :],
- [self.output_channels]))
-
- self.assertEqual(len(result), 2)
- self.assertAllClose(expected_result[0], result[0])
- self.assertAllClose(expected_result[1], result[1])
-
- def runFisherBlockOps(self, params, inputs, outputs, output_grads):
- """Run Ops guaranteed by FisherBlock interface.
-
- Args:
- params: Tensor or 2-tuple of Tensors. Represents weights or weights and
- bias of this layer.
- inputs: list of Tensors of shape [batch_size, input_size]. Inputs to
- layer.
- outputs: list of Tensors of shape [batch_size, output_size].
- Preactivations produced by layer.
- output_grads: list of Tensors of shape [batch_size, output_size].
- Gradient of loss with respect to 'outputs'.
-
- Returns:
- multiply_result: Result of FisherBlock.multiply(params)
- multiply_inverse_result: Result of FisherBlock.multiply_inverse(params)
- """
- with ops.Graph().as_default(), self.cached_session() as sess:
- inputs = as_tensors(inputs)
- outputs = as_tensors(outputs)
- output_grads = as_tensors(output_grads)
- params = as_tensors(params)
-
- block = fb.ConvDiagonalFB(
- lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME')
- for (i, o) in zip(inputs, outputs):
- block.register_additional_tower(i, o)
-
- block.instantiate_factors((output_grads,), damping=0.0)
- block._factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._factor.make_covariance_update_op(0.0))
- multiply_result = sess.run(block.multiply(params))
- multiply_inverse_result = sess.run(block.multiply_inverse(params))
-
- return multiply_result, multiply_inverse_result
-
-
-class DepthwiseConvKFCBasicFBTest(test.TestCase):
-
- def testInstantiateFactors(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((3, 3, 8, 2))
- inputs = random_ops.random_normal((32, 5, 5, 8))
- outputs = random_ops.random_normal((32, 5, 5, 16))
- layer_collection = lc.LayerCollection()
- block = fb.DepthwiseConvKFCBasicFB(
- layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
-
- def testMultiplyInverse(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((3, 3, 8, 2))
- inputs = random_ops.random_normal((32, 5, 5, 8))
- outputs = random_ops.random_normal((32, 5, 5, 16))
- layer_collection = lc.LayerCollection()
- block = fb.DepthwiseConvKFCBasicFB(
- layer_collection, params=params, strides=[1, 1, 1, 1], padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(([grads],), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Ensure inverse update op doesn't crash.
- sess.run(tf_variables.global_variables_initializer())
- sess.run([
- factor.make_inverse_update_ops()
- for factor in layer_collection.get_factors()
- ])
-
- # Ensure inverse-vector multiply doesn't crash.
- output = block.multiply_inverse(params)
- sess.run(output)
-
- # Ensure same shape.
- self.assertAllEqual(output.shape, params.shape)
-
-
-class ConvKFCBasicFBTest(test.TestCase):
-
- def _testConvKFCBasicFBInitParams(self, params):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- if isinstance(params, (list, tuple)):
- params = [array_ops.constant(param) for param in params]
- else:
- params = array_ops.constant(params)
- inputs = random_ops.random_normal((2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
-
- self.assertAllEqual([outputs], block.tensors_to_compute_grads())
-
- def testConvKFCBasicFBInitParamsParamsTuple(self):
- self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2]), np.ones([2])])
-
- def testConvKFCBasicFBInitParamsParamsSingle(self):
- self._testConvKFCBasicFBInitParams([np.ones([1, 2, 2])])
-
- def testMultiplyInverseTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((2, 2, 2, 2))
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32),
- np.arange(2, 4).reshape(2, 1).astype(np.float32))
- output = block.multiply_inverse((array_ops.constant(vector[0]),
- array_ops.constant(vector[1])))
-
- output = sess.run(output)
- self.assertAllClose([0.136455, 0.27291], output[0][0])
- self.assertAllClose([0.27291, 0.409365], output[1])
-
- def testMultiplyInverseNotTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = random_ops.random_normal((2, 2, 2, 2))
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- self.assertFalse(block._has_bias)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(1, 17).reshape(8, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
-
- def testMultiplyInverseNotTupleWithBias(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = [random_ops.random_normal((2, 2, 2, 2))]
- inputs = random_ops.random_normal((2, 2, 2, 2))
- outputs = random_ops.random_normal((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- self.assertTrue(block._has_bias)
- grads = outputs**2
- block.instantiate_factors(((grads,),), 0.5)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- # Make sure our inverse is something other than the identity.
- sess.run(tf_variables.global_variables_initializer())
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- vector = np.arange(1, 19).reshape(9, 2).astype(np.float32)
- output = block.multiply_inverse(array_ops.constant(vector))
-
- self.assertAllClose([0.136455, 0.27291], sess.run(output)[0])
-
- def testMultiplyInverseAgainstExplicit(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- params = array_ops.zeros((2, 2, 2, 2))
- inputs = array_ops.zeros((2, 2, 2, 2))
- outputs = array_ops.zeros((2, 2, 2, 2))
- block = fb.ConvKFCBasicFB(
- lc.LayerCollection(), params=params, padding='SAME')
- block.register_additional_tower(inputs, outputs)
- grads = outputs**2
- damping = 0. # This test is only valid without damping.
- block.instantiate_factors(((grads,),), damping)
- block._input_factor.instantiate_cov_variables()
- block._output_factor.instantiate_cov_variables()
- block.register_inverse()
- block._input_factor.instantiate_inv_variables()
- block._output_factor.instantiate_inv_variables()
-
- sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8)))
- sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2)))
- sess.run(block._input_factor.make_inverse_update_ops())
- sess.run(block._output_factor.make_inverse_update_ops())
-
- v_flat = np.arange(16, dtype=np.float32)
- vector = utils.column_to_tensors(params, array_ops.constant(v_flat))
- output = block.multiply_inverse(vector)
- output_flat = sess.run(utils.tensors_to_column(output)).ravel()
-
- full = sess.run(block.full_fisher_block())
- explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat)
-
- self.assertAllClose(output_flat, explicit)
-
-
-class FullyConnectedSeriesFBTest(test.TestCase):
-
- def testFullyConnectedSeriesFBInit(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([1., 2.])
- outputs = array_ops.constant([3., 4.])
- block = fb.FullyConnectedSeriesFB(lc.LayerCollection())
- block.register_additional_tower([inputs], [outputs])
- self.assertAllEqual([[outputs]], block.tensors_to_compute_grads())
-
- def testInstantiateFactorsHasBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(),
- has_bias=True)
- block.register_additional_tower([inputs], [outputs])
- grads = outputs**2
- block.instantiate_factors((((grads,),),), 0.5)
-
- def testInstantiateFactorsNoBias(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- inputs = array_ops.constant([[1., 2.], [3., 4.]])
- outputs = array_ops.constant([[3., 4.], [5., 6.]])
- block = fb.FullyConnectedSeriesFB(
- lc.LayerCollection(),
- has_bias=False)
- block.register_additional_tower([inputs], [outputs])
- grads = outputs**2
- block.instantiate_factors((((grads,),),), 0.5)
-
-
-def as_tensors(tensor_or_tuple):
- """Converts a potentially nested tuple of np.array to Tensors."""
- if isinstance(tensor_or_tuple, (tuple, list)):
- return tuple(as_tensors(t) for t in tensor_or_tuple)
- return ops.convert_to_tensor(tensor_or_tuple)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
deleted file mode 100644
index a396ca3f85..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ /dev/null
@@ -1,955 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.fisher_factors."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-
-def make_damping_func(damping):
- return fb._package_func(lambda: damping, damping)
-
-
-class FisherFactorTestingDummy(ff.FisherFactor):
- """Dummy class to test the non-abstract methods on ff.FisherFactor."""
-
- @property
- def _var_scope(self):
- return 'dummy/a_b_c'
-
- @property
- def _cov_shape(self):
- raise NotImplementedError
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _dtype(self):
- return dtypes.float32
-
- def _compute_new_cov(self):
- raise NotImplementedError
-
- def instantiate_covariance(self):
- pass
-
- def make_inverse_update_ops(self):
- return []
-
- def get_cov(self):
- return NotImplementedError
-
- def instantiate_inv_variables(self):
- return NotImplementedError
-
- def _num_towers(self):
- raise NotImplementedError
-
- def _get_data_device(self):
- raise NotImplementedError
-
- def register_matpower(self, exp, damping_func):
- raise NotImplementedError
-
- def register_cholesky(self, damping_func):
- raise NotImplementedError
-
- def register_cholesky_inverse(self, damping_func):
- raise NotImplementedError
-
- def get_matpower(self, exp, damping_func):
- raise NotImplementedError
-
- def get_cholesky(self, damping_func):
- raise NotImplementedError
-
- def get_cholesky_inverse(self, damping_func):
- raise NotImplementedError
-
- def get_cov_as_linear_operator(self):
- raise NotImplementedError
-
-
-class DenseSquareMatrixFactorTestingDummy(ff.DenseSquareMatrixFactor):
- """Dummy class to test the non-abstract methods on ff.DenseSquareMatrixFactor.
- """
-
- def __init__(self, shape):
- self._shape = shape
- super(DenseSquareMatrixFactorTestingDummy, self).__init__()
-
- @property
- def _var_scope(self):
- return 'dummy/a_b_c'
-
- @property
- def _cov_shape(self):
- return self._shape
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _dtype(self):
- return dtypes.float32
-
- def _compute_new_cov(self):
- raise NotImplementedError
-
- def instantiate_covariance(self):
- pass
-
- def _num_towers(self):
- raise NotImplementedError
-
- def _get_data_device(self):
- raise NotImplementedError
-
-
-class NumericalUtilsTest(test.TestCase):
-
- def testComputeCovAgainstNumpy(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- npr.seed(0)
- random_seed.set_random_seed(200)
-
- x = npr.randn(100, 3)
- cov = ff.compute_cov(array_ops.constant(x))
- np_cov = np.dot(x.T, x) / x.shape[0]
-
- self.assertAllClose(sess.run(cov), np_cov)
-
- def testComputeCovAgainstNumpyWithAlternativeNormalizer(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- npr.seed(0)
- random_seed.set_random_seed(200)
-
- normalizer = 10.
- x = npr.randn(100, 3)
- cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer)
- np_cov = np.dot(x.T, x) / normalizer
-
- self.assertAllClose(sess.run(cov), np_cov)
-
- def testAppendHomog(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- npr.seed(0)
-
- m, n = 3, 4
- a = npr.randn(m, n)
- a_homog = ff.append_homog(array_ops.constant(a))
- np_result = np.hstack([a, np.ones((m, 1))])
-
- self.assertAllClose(sess.run(a_homog), np_result)
-
-
-class NameStringUtilFunctionTest(test.TestCase):
-
- def _make_tensor(self):
- x = array_ops.placeholder(dtypes.float64, (3, 1))
- w = array_ops.constant(npr.RandomState(0).randn(3, 3))
- y = math_ops.matmul(w, x)
- g = gradients_impl.gradients(y, x)[0]
- return g
-
- def testScopeStringFromParamsSingleTensor(self):
- with tf_ops.Graph().as_default():
- g = self._make_tensor()
- scope_string = ff.scope_string_from_params(g)
- self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
-
- def testScopeStringFromParamsMultipleTensors(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- scope_string = ff.scope_string_from_params((x, y))
- self.assertEqual('Const_Const_1', scope_string)
-
- def testScopeStringFromParamsMultipleTypes(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4,
- (x, y)])
- self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string)
-
- def testScopeStringFromParamsUnsupportedType(self):
- with tf_ops.Graph().as_default():
- x = array_ops.constant(1,)
- y = array_ops.constant(2,)
- unsupported = 1.2 # Floats are not supported.
- with self.assertRaises(ValueError):
- ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y),
- unsupported])
-
- def testScopeStringFromName(self):
- with tf_ops.Graph().as_default():
- g = self._make_tensor()
- scope_string = ff.scope_string_from_name(g)
- self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string)
-
- def testScalarOrTensorToString(self):
- with tf_ops.Graph().as_default():
- self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.))
-
- g = self._make_tensor()
- scope_string = ff.scope_string_from_name(g)
- self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string)
-
-
-class FisherFactorTest(test.TestCase):
-
- def testMakeInverseUpdateOps(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- factor = FisherFactorTestingDummy()
-
- self.assertEqual(0, len(factor.make_inverse_update_ops()))
-
-
-class DenseSquareMatrixFactorTest(test.TestCase):
-
- def testRegisterDampedInverse(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- shape = [2, 2]
- factor = DenseSquareMatrixFactorTestingDummy(shape)
- factor_var_scope = 'dummy/a_b_c'
-
- damping_funcs = [make_damping_func(0.1),
- make_damping_func(0.1),
- make_damping_func(1e-5),
- make_damping_func(1e-5)]
- for damping_func in damping_funcs:
- factor.register_inverse(damping_func)
-
- factor.instantiate_inv_variables()
-
- inv = factor.get_inverse(damping_funcs[0]).to_dense()
- self.assertEqual(inv, factor.get_inverse(damping_funcs[1]).to_dense())
- self.assertNotEqual(inv, factor.get_inverse(damping_funcs[2]).to_dense())
- self.assertEqual(factor.get_inverse(damping_funcs[2]).to_dense(),
- factor.get_inverse(damping_funcs[3]).to_dense())
- factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
- factor_var_scope)
- factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
-
- self.assertEqual(set([inv,
- factor.get_inverse(damping_funcs[2]).to_dense()]),
- set(factor_tensors))
- self.assertEqual(shape, inv.get_shape())
-
- def testRegisterMatpower(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- shape = [3, 3]
- factor = DenseSquareMatrixFactorTestingDummy(shape)
- factor_var_scope = 'dummy/a_b_c'
-
- # TODO(b/74201126): Change to using the same func for both once
- # Topohash is in place.
- damping_func_1 = make_damping_func(0.5)
- damping_func_2 = make_damping_func(0.5)
-
- factor.register_matpower(-0.5, damping_func_1)
- factor.register_matpower(2, damping_func_2)
-
- factor.instantiate_inv_variables()
-
- factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES,
- factor_var_scope)
-
- factor_tensors = (tf_ops.convert_to_tensor(var) for var in factor_vars)
-
- matpower1 = factor.get_matpower(-0.5, damping_func_1).to_dense()
- matpower2 = factor.get_matpower(2, damping_func_2).to_dense()
-
- self.assertEqual(set([matpower1, matpower2]), set(factor_tensors))
-
- self.assertEqual(shape, matpower1.get_shape())
- self.assertEqual(shape, matpower2.get_shape())
-
- def testMakeInverseUpdateOps(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- factor = FisherFactorTestingDummy()
-
- self.assertEqual(0, len(factor.make_inverse_update_ops()))
-
- def testMakeInverseUpdateOpsManyInversesEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[1., 2.], [3., 4.]])
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
-
- damping_funcs = []
- for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1):
- damping_funcs.append(make_damping_func(1./i))
-
- for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
- factor.register_inverse(damping_funcs[i])
-
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- new_invs = []
- sess.run(ops)
- for i in range(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD):
- # The inverse op will assign the damped inverse of cov to the inv var.
- new_invs.append(
- sess.run(factor.get_inverse(damping_funcs[i]).to_dense()))
-
- # We want to see that the new invs are all different from each other.
- for i in range(len(new_invs)):
- for j in range(i + 1, len(new_invs)):
- # Just check the first element.
- self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0])
-
- def testMakeInverseUpdateOpsMatPowerEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[6., 2.], [2., 4.]])
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
- exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power
- damping = 0.5
- damping_func = make_damping_func(damping)
-
- factor.register_matpower(exp, damping_func)
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- sess.run(ops[0])
- matpower = sess.run(factor.get_matpower(exp, damping_func).to_dense())
- matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp)
- self.assertAllClose(matpower, matpower_np)
-
- def testMakeInverseUpdateOpsNoEigenDecomp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric
- factor = DenseSquareMatrixFactorTestingDummy(cov.shape)
- factor._cov = array_ops.constant(cov, dtype=dtypes.float32)
-
- damping_func = make_damping_func(0)
-
- factor.register_inverse(damping_func)
- factor.instantiate_inv_variables()
- ops = factor.make_inverse_update_ops()
- self.assertEqual(1, len(ops))
-
- sess.run(tf_variables.global_variables_initializer())
- # The inverse op will assign the damped inverse of cov to the inv var.
- old_inv = sess.run(factor.get_inverse(damping_func).to_dense())
- self.assertAllClose(
- sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv)
-
- sess.run(ops)
- new_inv = sess.run(factor.get_inverse(damping_func).to_dense())
- self.assertAllClose(new_inv, np.linalg.inv(cov))
-
-
-class FullFactorTest(test.TestCase):
-
- def testFullFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- self.assertEqual([6, 6], factor.get_cov().get_shape().as_list())
-
- def testFullFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([6, 6], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([1., 2.], name='a/b/c')
- factor = ff.FullFactor((tensor,), 2)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov)
-
-
-class NaiveDiagonalFactorTest(test.TestCase):
-
- def testNaiveDiagonalFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- self.assertEqual([6, 1], factor.get_cov().get_shape().as_list())
-
- def testNaiveDiagonalFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 32)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([6, 1], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([1., 2.], name='a/b/c')
- factor = ff.NaiveDiagonalFactor((tensor,), 2)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[0.75], [1.5]], new_cov)
-
-
-class EmbeddingInputKroneckerFactorTest(test.TestCase):
-
- def testInitialization(self):
- with tf_ops.Graph().as_default():
- input_ids = array_ops.constant([[0], [1], [4]])
- vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.shape.as_list(), [vocab_size])
-
- def testCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default():
- input_ids = array_ops.constant([[0], [1], [4]])
- vocab_size = 5
- factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size)
- factor.instantiate_cov_variables()
- cov_update_op = factor.make_covariance_update_op(0.0)
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(cov_update_op)
- self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov)
-
-
-class ConvDiagonalFactorTest(test.TestCase):
-
- def setUp(self):
- self.batch_size = 10
- self.height = self.width = 32
- self.in_channels = 3
- self.out_channels = 1
- self.kernel_height = self.kernel_width = 3
- self.strides = [1, 2, 2, 1]
- self.data_format = 'NHWC'
- self.padding = 'SAME'
- self.kernel_shape = [
- self.kernel_height, self.kernel_width, self.in_channels,
- self.out_channels
- ]
-
- def testInit(self):
- with tf_ops.Graph().as_default():
- inputs = random_ops.random_uniform(
- [self.batch_size, self.height, self.width, self.in_channels])
- outputs_grads = [
- random_ops.random_uniform([
- self.batch_size, self.height // self.strides[1],
- self.width // self.strides[2], self.out_channels
- ]) for _ in range(3)
- ]
-
- factor = ff.ConvDiagonalFactor(
- (inputs,),
- (outputs_grads,),
- self.kernel_shape,
- self.strides,
- self.padding,
- data_format=self.data_format)
- factor.instantiate_cov_variables()
-
- # Ensure covariance matrix's shape makes sense.
- self.assertEqual([
- self.kernel_height * self.kernel_width * self.in_channels,
- self.out_channels
- ],
- factor.get_cov().shape.as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default():
- # Construct all arguments such that convolution kernel is applied in
- # exactly one spatial location.
- inputs = np.random.randn(
- 1, # batch_size
- self.kernel_height,
- self.kernel_width,
- self.in_channels) # in_channels
- outputs_grad = np.random.randn(
- 1, # batch_size
- 1, # output_height
- 1, # output_width
- self.out_channels)
-
- factor = ff.ConvDiagonalFactor(
- (constant_op.constant(inputs),),
- ((constant_op.constant(outputs_grad),),),
- self.kernel_shape,
- strides=[1, 1, 1, 1],
- padding='VALID')
- factor.instantiate_cov_variables()
-
- # Completely forget initial value on first update.
- cov_update_op = factor.make_covariance_update_op(0.0)
-
- # Ensure new covariance value is same as outer-product of inputs/outputs
- # vectorized, squared.
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- cov = sess.run(cov_update_op)
- expected_cov = np.outer(inputs.flatten(), outputs_grad.flatten())**2
- self.assertAllClose(expected_cov, cov)
-
- def testHasBias(self):
- with tf_ops.Graph().as_default():
- inputs = random_ops.random_uniform(
- [self.batch_size, self.height, self.width, self.in_channels])
- outputs_grads = [
- random_ops.random_uniform([
- self.batch_size, self.height // self.strides[1],
- self.width // self.strides[2], self.out_channels
- ]) for _ in range(3)
- ]
-
- factor = ff.ConvDiagonalFactor(
- (inputs,),
- (outputs_grads,),
- self.kernel_shape,
- self.strides,
- self.padding,
- data_format=self.data_format,
- has_bias=True)
- factor.instantiate_cov_variables()
-
- # Ensure shape accounts for bias.
- self.assertEqual([
- self.kernel_height * self.kernel_width * self.in_channels + 1,
- self.out_channels
- ],
- factor.get_cov().shape.as_list())
-
- # Ensure update op doesn't crash.
- cov_update_op = factor.make_covariance_update_op(0.0)
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(cov_update_op)
-
-
-class FullyConnectedKroneckerFactorTest(test.TestCase):
-
- def _testFullyConnectedKroneckerFactorInit(self,
- has_bias,
- final_shape,
- dtype=dtypes.float32_ref):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=has_bias)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual(final_shape, cov.get_shape().as_list())
-
- def testFullyConnectedKroneckerFactorInitNoBias(self):
- for dtype in (dtypes.float32_ref, dtypes.float64_ref):
- self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype)
-
- def testFullyConnectedKroneckerFactorInitWithBias(self):
- for dtype in (dtypes.float32_ref, dtypes.float64_ref):
- self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype)
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),), has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-
-
-class ConvFactorTestCase(test.TestCase):
-
- def assertMatrixRank(self, rank, matrix, atol=1e-5):
- assert rank <= matrix.shape[0], 'Rank cannot be larger than matrix size.'
- eigvals = np.linalg.eigvals(matrix)
- nnz_eigvals = np.sum(eigvals > atol)
- self.assertEqual(
- rank,
- nnz_eigvals,
- msg=('Found %d of %d expected non-zero eigenvalues: %s.' %
- (nnz_eigvals, rank, eigvals)))
-
-
-class ConvInputKroneckerFactorTest(ConvFactorTestCase):
-
- def test3DConvolution(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**3
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, width, in_channels), seed=0),),
- filter_shape=(width, width, width, in_channels, out_channels),
- padding='SAME',
- strides=(2, 2, 2),
- extract_patches_fn='extract_convolution_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- # Ensure shape of covariance matches input size of filter.
- input_size = in_channels * (width**3)
- self.assertEqual([input_size, input_size],
- factor.get_cov().shape.as_list())
-
- # Ensure cov_update_op doesn't crash.
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank-8, as the filter will be applied at each corner of
- # the 4-D cube.
- self.assertMatrixRank(8, cov)
-
- def testPointwiseConv2d(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(1, 1, in_channels, out_channels),
- padding='SAME',
- strides=(1, 1, 1, 1),
- extract_patches_fn='extract_pointwise_conv2d_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- # Ensure shape of covariance matches input size of filter.
- self.assertEqual([in_channels, in_channels],
- factor.get_cov().shape.as_list())
-
- # Ensure cov_update_op doesn't crash.
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank-9, as the filter will be applied at each location.
- self.assertMatrixRank(9, cov)
-
- def testStrides(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 3**2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(1, 1, in_channels, out_channels),
- padding='SAME',
- strides=(1, 2, 1, 1),
- extract_patches_fn='extract_image_patches',
- has_bias=False)
- factor.instantiate_cov_variables()
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be the sum of 3 * 2 = 6 outer products.
- self.assertMatrixRank(6, cov)
-
- def testDilationRate(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- in_channels = 2
- out_channels = 4
-
- factor = ff.ConvInputKroneckerFactor(
- inputs=(random_ops.random_uniform(
- (batch_size, width, width, in_channels), seed=0),),
- filter_shape=(3, 3, in_channels, out_channels),
- padding='SAME',
- extract_patches_fn='extract_image_patches',
- strides=(1, 1, 1, 1),
- dilation_rate=(1, width, width, 1),
- has_bias=False)
- factor.instantiate_cov_variables()
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank = in_channels, as only the center of the filter
- # receives non-zero input for each input channel.
- self.assertMatrixRank(in_channels, cov)
-
- def testConvInputKroneckerFactorInitNoBias(self):
- with tf_ops.Graph().as_default():
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
- factor = ff.ConvInputKroneckerFactor(
- inputs=(tensor,),
- filter_shape=(1, 2, 3, 4),
- padding='SAME',
- has_bias=False)
- factor.instantiate_cov_variables()
- self.assertEqual([1 * 2 * 3, 1 * 2 * 3],
- factor.get_cov().get_shape().as_list())
-
- def testConvInputKroneckerFactorInit(self):
- with tf_ops.Graph().as_default():
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c')
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
- self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
- factor.get_cov().get_shape().as_list())
-
- def testConvInputKroneckerFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- tensor = array_ops.ones((64, 1, 2, 3), name='a/b/c', dtype=dtypes.float64)
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 2, 3, 4), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1],
- cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- input_shape = (2, 1, 1, 1)
- tensor = array_ops.constant(
- np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
- np.float32))
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME', has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(0.))
- self.assertAllClose(
- [
- [(1. + 4.) / 2., (1. + 2.) / 2.], #
- [(1. + 2.) / 2., (1. + 1.) / 2.]
- ], #
- new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- input_shape = (2, 1, 1, 1)
- tensor = array_ops.constant(
- np.arange(1, 1 + np.prod(input_shape)).reshape(input_shape).astype(
- np.float32))
- factor = ff.ConvInputKroneckerFactor(
- (tensor,), filter_shape=(1, 1, 1, 1), padding='SAME')
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(0.))
- self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
-
- def testSubSample(self):
- with tf_ops.Graph().as_default():
- patches_1 = array_ops.constant(1, shape=(10, 2))
- patches_2 = array_ops.constant(1, shape=(10, 8))
- patches_3 = array_ops.constant(1, shape=(3, 3))
- patches_1_sub = ff._subsample_for_cov_computation(patches_1)
- patches_2_sub = ff._subsample_for_cov_computation(patches_2)
- patches_3_sub = ff._subsample_for_cov_computation(patches_3)
- patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
- patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
- patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
- self.assertEqual(2, patches_1_sub_batch_size)
- self.assertEqual(8, patches_2_sub_batch_size)
- self.assertEqual(3, patches_3_sub_batch_size)
-
-
-class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
-
- def test3DConvolution(self):
- with tf_ops.Graph().as_default():
- batch_size = 1
- width = 3
- out_channels = width**3
-
- factor = ff.ConvOutputKroneckerFactor(outputs_grads=([
- random_ops.random_uniform(
- (batch_size, width, width, width, out_channels), seed=0)
- ],))
- factor.instantiate_cov_variables()
-
- with self.cached_session() as sess:
- sess.run(tf_variables.global_variables_initializer())
- sess.run(factor.make_covariance_update_op(0.0))
- cov = sess.run(factor.get_cov())
-
- # Cov should be rank 3^3, as each spatial position donates a rank-1
- # update.
- self.assertMatrixRank(width**3, cov)
-
- def testConvOutputKroneckerFactorInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
- self.assertEqual([5, 5], factor.get_cov().get_shape().as_list())
-
- def testConvOutputKroneckerFactorInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c')
- factor = ff.ConvOutputKroneckerFactor(((tensor,),))
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([5, 5], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOp(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32)
- factor = ff.ConvOutputKroneckerFactor(((array_ops.constant(tensor),),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov)
-
-
-class FullyConnectedMultiKFTest(test.TestCase):
-
- def testFullyConnectedMultiKFInit(self):
- with tf_ops.Graph().as_default():
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
- factor.instantiate_cov_variables()
- self.assertEqual([3, 3], factor.get_cov().get_shape().as_list())
-
- def testFullyConnectedMultiKFInitFloat64(self):
- with tf_ops.Graph().as_default():
- dtype = dtypes.float64_ref
- random_seed.set_random_seed(200)
- tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=False)
- factor.instantiate_cov_variables()
- cov = factor.get_cov()
- self.assertEqual(cov.dtype, dtype)
- self.assertEqual([3, 3], cov.get_shape().as_list())
-
- def testMakeCovarianceUpdateOpWithBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),), has_bias=True)
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov)
-
- def testMakeCovarianceUpdateOpNoBias(self):
- with tf_ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c')
- factor = ff.FullyConnectedMultiKF(((tensor,),))
- factor.instantiate_cov_variables()
-
- sess.run(tf_variables.global_variables_initializer())
- new_cov = sess.run(factor.make_covariance_update_op(.5))
- self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
deleted file mode 100644
index 586fcd4c3c..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py
+++ /dev/null
@@ -1,597 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.layer_collection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks
-from tensorflow.contrib.kfac.python.ops import fisher_factors
-from tensorflow.contrib.kfac.python.ops import layer_collection
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class MockFisherBlock(object):
- """A fake FisherBlock."""
-
- num_registered_towers = 2
-
- def __init__(self, name='MockFisherBlock'):
- self.name = name
-
- def __eq__(self, other):
- return isinstance(other, MockFisherBlock) and other.name == self.name
-
- def __hash__(self):
- return hash(self.name)
-
-
-class LayerParametersDictTest(test.TestCase):
-
- def testSetItem(self):
- """Ensure insertion, contains, retrieval works for supported key types."""
- with ops.Graph().as_default():
- lp_dict = layer_collection.LayerParametersDict()
-
- x = array_ops.constant(0)
- y0 = array_ops.constant(0)
- y1 = array_ops.constant(0)
- z0 = array_ops.constant(0)
- z1 = array_ops.constant(0)
- keys = [x, (y0, y1), [z0, z1]]
- for key in keys:
- lp_dict[key] = key
-
- for key in keys:
- self.assertTrue(key in lp_dict)
- self.assertEqual(lp_dict[key], key)
-
- def testSetItemOverlap(self):
- """Ensure insertion fails if key overlaps with existing key."""
- with ops.Graph().as_default():
- lp_dict = layer_collection.LayerParametersDict()
-
- x = array_ops.constant(0)
- y = array_ops.constant(0)
- lp_dict[x] = 'value'
-
- with self.assertRaises(ValueError):
- lp_dict[(x, y)] = 'value'
-
- # Ensure 'y' wasn't inserted.
- self.assertTrue(x in lp_dict)
- self.assertFalse(y in lp_dict)
-
-
-class LayerCollectionTest(test.TestCase):
-
- def testLayerCollectionInit(self):
- lc = layer_collection.LayerCollection()
- self.assertEqual(0, len(lc.get_blocks()))
- self.assertEqual(0, len(lc.get_factors()))
- self.assertFalse(lc.losses)
-
- def testRegisterBlocks(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- lc.register_fully_connected(
- array_ops.constant(1), array_ops.constant(2), array_ops.constant(3))
- lc.register_fully_connected(
- array_ops.constant(1),
- array_ops.constant(2),
- array_ops.constant(3),
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_conv2d(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=array_ops.ones((1, 2, 3, 4)),
- outputs=array_ops.ones((1, 1, 1, 5)))
- lc.register_conv2d(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=array_ops.ones((1, 2, 3, 4)),
- outputs=array_ops.ones((1, 1, 1, 5)),
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_separable_conv2d(
- depthwise_params=array_ops.ones((3, 3, 1, 2)),
- pointwise_params=array_ops.ones((1, 1, 2, 4)),
- inputs=array_ops.ones((32, 5, 5, 1)),
- depthwise_outputs=array_ops.ones((32, 5, 5, 2)),
- pointwise_outputs=array_ops.ones((32, 5, 5, 4)),
- strides=[1, 1, 1, 1],
- padding='SAME')
- lc.register_convolution(
- params=array_ops.ones((3, 3, 1, 8)),
- inputs=array_ops.ones((32, 5, 5, 1)),
- outputs=array_ops.ones((32, 5, 5, 8)),
- padding='SAME')
- lc.register_generic(
- array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME)
- lc.register_generic(
- array_ops.constant(6),
- 16,
- approx=layer_collection.APPROX_DIAGONAL_NAME)
- lc.register_fully_connected_multi(
- array_ops.constant(1),
- (array_ops.constant(2), array_ops.constant(3)),
- (array_ops.constant(4), array_ops.constant(5)))
- lc.register_conv2d_multi(
- params=array_ops.ones((2, 3, 4, 5)),
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=(array_ops.ones((1, 2, 3, 4)), array_ops.ones((5, 6, 7, 8))),
- outputs=(array_ops.ones((1, 1, 1, 5)), array_ops.ones((2, 2, 2, 10))))
- lc.register_embedding_multi(
- array_ops.constant((1,)),
- (array_ops.constant(2), array_ops.constant(3)),
- (array_ops.constant(4), array_ops.constant(5)))
-
- self.assertEqual(12, len(lc.get_blocks()))
-
- def testRegisterBlocksMultipleRegistrations(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- key = array_ops.constant(1)
- lc.register_fully_connected(key, array_ops.constant(2),
- array_ops.constant(3))
- with self.assertRaises(ValueError) as cm:
- lc.register_generic(key, 16)
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterSingleParamNotRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- variable_scope.get_variable('y', initializer=array_ops.constant(1,)):
- '1'
- }
- lc.register_block(x, 'foo')
-
- def testShouldRegisterSingleParamRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {x: '1'}
- with self.assertRaises(ValueError) as cm:
- lc.register_block(x, 'foo')
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterSingleParamRegisteredInTuple(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y): '1'}
- with self.assertRaises(ValueError) as cm:
- lc.register_block(x, 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleParamNotRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {
- variable_scope.get_variable('z', initializer=array_ops.constant(1,)):
- '1'
- }
-
- lc.register_block((x, y), 'foo')
- self.assertEqual(set(['1', 'foo']), set(lc.get_blocks()))
-
- def testRegisterTupleParamRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y): '1'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('already in LayerCollection', str(cm.exception))
-
- def testRegisterTupleParamRegisteredInSuperset(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, y, z): '1'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleParamSomeRegistered(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), MockFisherBlock('foo'))
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterTupleVarSomeRegisteredInOtherTuples(self):
- x = variable_scope.get_variable('x', initializer=array_ops.constant(1,))
- y = variable_scope.get_variable('y', initializer=array_ops.constant(1,))
- z = variable_scope.get_variable('z', initializer=array_ops.constant(1,))
- w = variable_scope.get_variable('w', initializer=array_ops.constant(1,))
- lc = layer_collection.LayerCollection()
- lc.fisher_blocks = {(x, z): '1', (z, w): '2'}
-
- with self.assertRaises(ValueError) as cm:
- lc.register_block((x, y), 'foo')
- self.assertIn('was already registered', str(cm.exception))
-
- def testRegisterCategoricalPredictiveDistribution(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- logits = linalg_ops.eye(2)
-
- lc = layer_collection.LayerCollection()
- lc.register_categorical_predictive_distribution(logits, seed=200)
- single_loss = sess.run(lc.total_sampled_loss())
-
- lc2 = layer_collection.LayerCollection()
- lc2.register_categorical_predictive_distribution(logits, seed=200)
- lc2.register_categorical_predictive_distribution(logits, seed=200)
- double_loss = sess.run(lc2.total_sampled_loss())
- self.assertAlmostEqual(2 * single_loss, double_loss)
-
- def testLossFunctionByName(self):
- """Ensure loss functions can be identified by name."""
- with ops.Graph().as_default():
- logits = linalg_ops.eye(2)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function by name.
- lc.register_categorical_predictive_distribution(logits, name='loss1')
- self.assertEqual(1, len(lc.towers_by_loss))
-
- # Add logits to same loss function.
- lc.register_categorical_predictive_distribution(
- logits, name='loss1', reuse=True)
- self.assertEqual(1, len(lc.towers_by_loss))
-
- # Add another new loss function.
- lc.register_categorical_predictive_distribution(logits, name='loss2')
- self.assertEqual(2, len(lc.towers_by_loss))
-
- def testLossFunctionWithoutName(self):
- """Ensure loss functions get unique names if 'name' not specified."""
- with ops.Graph().as_default():
- logits = linalg_ops.eye(2)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function with default names.
- lc.register_categorical_predictive_distribution(logits)
- lc.register_categorical_predictive_distribution(logits)
- self.assertEqual(2, len(lc.losses))
-
- def testCategoricalPredictiveDistributionMultipleMinibatches(self):
- """Ensure multiple minibatches are registered."""
- with ops.Graph().as_default():
- batch_size = 3
- output_size = 2
- logits = array_ops.zeros([batch_size, output_size])
- targets = array_ops.ones([batch_size], dtype=dtypes.int32)
- lc = layer_collection.LayerCollection()
-
- # Create a new loss function.
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1')
-
- # Can add when reuse=True
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1', reuse=True)
-
- # Can add when reuse=VARIABLE_SCOPE and reuse=True there.
- with variable_scope.variable_scope(
- variable_scope.get_variable_scope(), reuse=True):
- lc.register_categorical_predictive_distribution(
- logits,
- targets=targets,
- name='loss1',
- reuse=layer_collection.VARIABLE_SCOPE)
-
- # Can't add when reuse=False
- with self.assertRaises(KeyError):
- lc.register_categorical_predictive_distribution(
- logits, targets=targets, name='loss1', reuse=False)
-
- # Can't add when reuse=VARIABLE_SCOPE and reuse=False there.
- with self.assertRaises(KeyError):
- lc.register_categorical_predictive_distribution(
- logits,
- targets=targets,
- name='loss1',
- reuse=layer_collection.VARIABLE_SCOPE)
-
- self.assertEqual(len(lc.towers_by_loss), 1)
- # Three successful registrations.
- self.assertEqual(len(lc.towers_by_loss[0]), 3)
-
- def testRegisterCategoricalPredictiveDistributionBatchSize1(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- logits = random_ops.random_normal((1, 2))
- lc = layer_collection.LayerCollection()
-
- lc.register_categorical_predictive_distribution(logits, seed=200)
-
- def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32)
- lc = layer_collection.LayerCollection()
- targets = array_ops.constant([0, 1], dtype=dtypes.int32)
-
- lc.register_categorical_predictive_distribution(logits, targets=targets)
- single_loss = sess.run(lc.total_loss())
- self.assertAlmostEqual(1.6265233, single_loss)
-
- def testRegisterNormalPredictiveDistribution(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- predictions = array_ops.constant(
- [[1., 2.], [3., 4]], dtype=dtypes.float32)
-
- lc = layer_collection.LayerCollection()
- lc.register_normal_predictive_distribution(predictions, 1., seed=200)
- single_loss = sess.run(lc.total_sampled_loss())
-
- lc2 = layer_collection.LayerCollection()
- lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
- lc2.register_normal_predictive_distribution(predictions, 1., seed=200)
- double_loss = sess.run(lc2.total_sampled_loss())
-
- self.assertAlmostEqual(2 * single_loss, double_loss)
-
- def testRegisterNormalPredictiveDistributionSpecifiedTargets(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- predictions = array_ops.constant(
- [[1., 2.], [3., 4.]], dtype=dtypes.float32)
- lc = layer_collection.LayerCollection()
- targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32)
-
- lc.register_normal_predictive_distribution(
- predictions, 2.**2, targets=targets)
- single_loss = sess.run(lc.total_loss())
- self.assertAlmostEqual(7.6983433, single_loss)
-
- def ensureLayerReuseWorks(self, register_fn):
- """Ensure the 'reuse' keyword argument function as intended.
-
- Args:
- register_fn: function for registering a layer. Arguments are
- layer_collection, reuse, and approx.
- """
- # Fails on second if reuse=False.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with self.assertRaises(ValueError):
- register_fn(lc, reuse=False)
-
- # Succeeds on second if reuse=True.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- register_fn(lc, reuse=True)
-
- # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with self.assertRaises(ValueError):
- register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
-
- # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse.
- lc = layer_collection.LayerCollection()
- register_fn(lc)
- with variable_scope.variable_scope(
- variable_scope.get_variable_scope(), reuse=True):
- register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE)
-
- # Fails if block type changes.
- lc = layer_collection.LayerCollection()
- register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME)
- with self.assertRaises(ValueError):
- register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True)
-
- # Fails if reuse requested but no FisherBlock exists.
- lc = layer_collection.LayerCollection()
- with self.assertRaises(KeyError):
- register_fn(lc, reuse=True)
-
- def testRegisterFullyConnectedReuse(self):
- """Ensure the 'reuse' works with register_fully_connected."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 10])
- outputs = array_ops.zeros([2, 5])
- params = (
- variable_scope.get_variable('w', [10, 5]), #
- variable_scope.get_variable('b', [5]))
-
- def register_fn(lc, **kwargs):
- lc.register_fully_connected(
- params=params, inputs=inputs, outputs=outputs, **kwargs)
-
- self.ensureLayerReuseWorks(register_fn)
-
- def testRegisterConv2dReuse(self):
- """Ensure the 'reuse' works with register_conv2d."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 5, 5, 10])
- outputs = array_ops.zeros([2, 5, 5, 3])
- params = (
- variable_scope.get_variable('w', [1, 1, 10, 3]), #
- variable_scope.get_variable('b', [3]))
-
- def register_fn(lc, **kwargs):
- lc.register_conv2d(
- params=params,
- strides=[1, 1, 1, 1],
- padding='SAME',
- inputs=inputs,
- outputs=outputs,
- **kwargs)
-
- self.ensureLayerReuseWorks(register_fn)
-
- def testReuseWithInvalidRegistration(self):
- """Invalid registrations shouldn't overwrite existing blocks."""
- with ops.Graph().as_default():
- inputs = array_ops.ones([2, 5, 5, 10])
- outputs = array_ops.zeros([2, 5, 5, 3])
- w = variable_scope.get_variable('w', [1, 1, 10, 3])
- b = variable_scope.get_variable('b', [3])
- lc = layer_collection.LayerCollection()
- lc.register_fully_connected(w, inputs, outputs)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
- with self.assertRaises(KeyError):
- lc.register_fully_connected((w, b), inputs, outputs, reuse=True)
- self.assertNotIn((w, b), lc.fisher_blocks)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 1)
- lc.register_fully_connected(w, inputs, outputs, reuse=True)
- self.assertEqual(lc.fisher_blocks[w].num_registered_towers, 2)
-
- def testMakeOrGetFactor(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- lc = layer_collection.LayerCollection()
- key = array_ops.constant(1)
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor,
- ((array_ops.constant(2),), 16))
-
- self.assertEqual(2, len(lc.get_factors()))
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertTrue(
- all([var.name.startswith('LayerCollection') for var in variables]))
-
- def testMakeOrGetFactorCustomScope(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- scope = 'Foo'
- lc = layer_collection.LayerCollection(name=scope)
- key = array_ops.constant(1)
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16))
- lc.make_or_get_factor(fisher_factors.FullFactor,
- ((array_ops.constant(2),), 16))
-
- self.assertEqual(2, len(lc.get_factors()))
- variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- self.assertTrue(all([var.name.startswith(scope) for var in variables]))
-
- def testIdentifyLinkedParametersSomeRegisteredInOtherTuples(self):
- x = variable_scope.get_variable('x', shape=())
- y = variable_scope.get_variable('y', shape=())
- z = variable_scope.get_variable('z', shape=())
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters((x, y))
-
- with self.assertRaises(ValueError):
- lc.define_linked_parameters((x, z))
-
- def testIdentifySubsetPreviouslyRegisteredTensor(self):
- x = variable_scope.get_variable('x', shape=())
- y = variable_scope.get_variable('y', shape=())
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters((x, y))
-
- with self.assertRaises(ValueError):
- lc.define_linked_parameters(x)
-
- def testSpecifyApproximation(self):
- w_0 = variable_scope.get_variable('w_0', [10, 10])
- w_1 = variable_scope.get_variable('w_1', [10, 10])
-
- b_0 = variable_scope.get_variable('b_0', [10])
- b_1 = variable_scope.get_variable('b_1', [10])
-
- x_0 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
- x_1 = array_ops.placeholder(dtypes.float32, shape=(32, 10))
-
- pre_bias_0 = math_ops.matmul(x_0, w_0)
- pre_bias_1 = math_ops.matmul(x_1, w_1)
-
- # Build the fully connected layers in the graph.
- pre_bias_0 + b_0 # pylint: disable=pointless-statement
- pre_bias_1 + b_1 # pylint: disable=pointless-statement
-
- lc = layer_collection.LayerCollection()
- lc.define_linked_parameters(
- w_0, approximation=layer_collection.APPROX_DIAGONAL_NAME)
- lc.define_linked_parameters(
- w_1, approximation=layer_collection.APPROX_DIAGONAL_NAME)
- lc.define_linked_parameters(
- b_0, approximation=layer_collection.APPROX_FULL_NAME)
- lc.define_linked_parameters(
- b_1, approximation=layer_collection.APPROX_FULL_NAME)
-
- lc.register_fully_connected(w_0, x_0, pre_bias_0)
- lc.register_fully_connected(
- w_1, x_1, pre_bias_1, approx=layer_collection.APPROX_KRONECKER_NAME)
- self.assertIsInstance(lc.fisher_blocks[w_0],
- fisher_blocks.FullyConnectedDiagonalFB)
- self.assertIsInstance(lc.fisher_blocks[w_1],
- fisher_blocks.FullyConnectedKFACBasicFB)
-
- lc.register_generic(b_0, batch_size=1)
- lc.register_generic(
- b_1, batch_size=1, approx=layer_collection.APPROX_DIAGONAL_NAME)
- self.assertIsInstance(lc.fisher_blocks[b_0], fisher_blocks.FullFB)
- self.assertIsInstance(lc.fisher_blocks[b_1], fisher_blocks.NaiveDiagonalFB)
-
- def testDefaultLayerCollection(self):
- with ops.Graph().as_default():
- # Can't get default if there isn't one set.
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
- # Can't set default twice.
- lc = layer_collection.LayerCollection()
- layer_collection.set_default_layer_collection(lc)
- with self.assertRaises(ValueError):
- layer_collection.set_default_layer_collection(lc)
-
- # Same as one set.
- self.assertTrue(lc is layer_collection.get_default_layer_collection())
-
- # Can set to None.
- layer_collection.set_default_layer_collection(None)
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
- # as_default() is the same as setting/clearing.
- with lc.as_default():
- self.assertTrue(lc is layer_collection.get_default_layer_collection())
- with self.assertRaises(ValueError):
- layer_collection.get_default_layer_collection()
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
deleted file mode 100644
index f424e02360..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py
+++ /dev/null
@@ -1,190 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.loss_functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import loss_functions
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class InsertSliceInZerosTest(test.TestCase):
-
- def testBadShape(self):
- bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1
- with self.assertRaises(ValueError):
- loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17)
-
- def test3d(self):
- input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]])
- expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]]
- op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0)
- with self.cached_session() as sess:
- actual_output_array = sess.run(op)
- self.assertAllEqual(expected_output_array, actual_output_array)
-
-
-class CategoricalLogitsNegativeLogProbLossTest(test.TestCase):
-
- def testSample(self):
- """Ensure samples can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- sample = loss.sample(42)
- sample = sess.run(sample)
- self.assertEqual(sample.shape, (2,))
-
- def testEvaluateOnTargets(self):
- """Ensure log probability can be evaluated correctly."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- targets = np.asarray([2, 1]).astype(np.int32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits), targets=array_ops.constant(targets))
- neg_log_prob = loss.evaluate()
- neg_log_prob = sess.run(neg_log_prob)
-
- # Calculate explicit log probability of targets.
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
- log_probs = np.log([
- probs[0, targets[0]], #
- probs[1, targets[1]]
- ])
- expected_log_prob = np.sum(log_probs)
-
- self.assertAllClose(neg_log_prob, -expected_log_prob)
-
- def testEvaluateOnSample(self):
- """Ensure log probability of a sample can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- neg_log_prob = loss.evaluate_on_sample(42)
-
- # Simply ensure this doesn't crash. As the output is random, it's
- # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob)
-
- def testMultiplyFisherSingleVector(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.array([1., 2., 3.])
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
-
- # the LossFunction.multiply_fisher docstring only says it supports the
- # case where the vector is the same shape as the input natural parameters
- # (i.e. the logits here), but here we also test leading dimensions
- vector = np.array([1., 2., 3.])
- vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]
-
- probs = np.exp(logits - np.logaddexp.reduce(logits))
- fisher = np.diag(probs) - np.outer(probs, probs)
-
- for vector in vectors:
- result = loss.multiply_fisher(vector)
- expected_result = np.dot(vector, fisher)
- self.assertAllClose(expected_result, sess.run(result))
-
- def testMultiplyFisherBatch(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.array([[1., 2., 3.], [4., 6., 8.]])
- loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
-
- vector = np.array([[1., 2., 3.], [5., 3., 1.]])
-
- na = np.newaxis
- probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1,
- keepdims=True))
- fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :]
-
- result = loss.multiply_fisher(vector)
- expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
- self.assertEqual(sess.run(result).shape, logits.shape)
- self.assertAllClose(expected_result, sess.run(result))
-
-
-class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase):
-
- def testSample(self):
- """Ensure samples can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- sample = loss.sample(42)
- sample = sess.run(sample)
- self.assertEqual(sample.shape, (2, 3))
-
- def testEvaluateOnTargets(self):
- """Ensure log probability can be evaluated correctly."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- targets = np.asarray([2, 1]).astype(np.int32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits), targets=array_ops.one_hot(targets, 3))
- neg_log_prob = loss.evaluate()
- neg_log_prob = sess.run(neg_log_prob)
-
- # Calculate explicit log probability of targets.
- probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
- log_probs = np.log([
- probs[0, targets[0]], #
- probs[1, targets[1]]
- ])
- expected_log_prob = np.sum(log_probs)
-
- self.assertAllClose(neg_log_prob, -expected_log_prob)
-
- def testEvaluateOnSample(self):
- """Ensure log probability of a sample can be drawn."""
- with ops.Graph().as_default(), self.cached_session() as sess:
- logits = np.asarray([
- [0., 0., 0.], #
- [1., -1., 0.]
- ]).astype(np.float32)
- loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss(
- array_ops.constant(logits))
- neg_log_prob = loss.evaluate_on_sample(42)
-
- # Simply ensure this doesn't crash. As the output is random, it's
- # difficult to say if the output is correct or not...
- neg_log_prob = sess.run(neg_log_prob)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
deleted file mode 100644
index 4fae4374e1..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.op_queue."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import op_queue
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class OpQueueTest(test.TestCase):
-
- def testNextOp(self):
- """Ensures all ops get selected eventually."""
- with tf_ops.Graph().as_default():
- ops = [
- math_ops.add(1, 2),
- math_ops.subtract(1, 2),
- math_ops.reduce_mean([1, 2]),
- ]
- queue = op_queue.OpQueue(ops, seed=0)
-
- with self.cached_session() as sess:
- # Ensure every inv update op gets selected.
- selected_ops = set([queue.next_op(sess) for _ in ops])
- self.assertEqual(set(ops), set(selected_ops))
-
- # Ensure additional calls don't create any new ops.
- selected_ops.add(queue.next_op(sess))
- self.assertEqual(set(ops), set(selected_ops))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
deleted file mode 100644
index 0b0de12ce6..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py
+++ /dev/null
@@ -1,219 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.kfac.python.ops import fisher_factors as ff
-from tensorflow.contrib.kfac.python.ops import layer_collection as lc
-from tensorflow.contrib.kfac.python.ops import optimizer
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.platform import test
-
-
-# We need to set these constants since the numerical values used in the tests
-# were chosen when these used to be the defaults.
-ff.set_global_constants(init_covariances_at_zero=False,
- zero_debias=False,
- init_inverses_at_zero=False)
-
-
-def dummy_layer_collection():
- lcoll = lc.LayerCollection()
- dummy = array_ops.constant([1., 2.])
- lcoll.register_categorical_predictive_distribution(logits=dummy)
- return lcoll
-
-
-class OptimizerTest(test.TestCase):
-
- def testOptimizerInitInvalidMomentumRegistration(self):
- with self.assertRaises(ValueError):
- optimizer.KfacOptimizer(
- 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo')
-
- def testOptimizerInit(self):
- with ops.Graph().as_default():
- layer_collection = lc.LayerCollection()
-
- inputs = array_ops.ones((2, 1)) * 2
- weights_val = np.ones((1, 1), dtype=np.float32) * 3.
- weights = variable_scope.get_variable(
- 'w', initializer=array_ops.constant(weights_val))
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
- output = math_ops.matmul(inputs, weights) + bias
-
- layer_collection.register_fully_connected((weights, bias), inputs, output)
-
- logits = math_ops.tanh(output)
- targets = array_ops.constant([[0.], [1.]])
- output = math_ops.reduce_mean(
- nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
-
- layer_collection.register_categorical_predictive_distribution(logits)
-
- optimizer.KfacOptimizer(
- 0.1,
- 0.2,
- 0.3,
- layer_collection,
- momentum=0.5,
- momentum_type='regular')
-
- def testSquaredFisherNorm(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
- (array_ops.constant([[2., 3.], [4., 5.]]), None)]
- pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
- (array_ops.constant([[7., 8.], [9., 10.]]), None)]
- opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection())
- sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
- self.assertAlmostEqual(174., sess.run(sq_norm), places=5)
-
- def testUpdateClipCoeff(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None),
- (array_ops.constant([[2., 3.], [4., 5.]]), None)]
- pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None),
- (array_ops.constant([[7., 8.], [9., 10.]]), None)]
- lrate = 0.1
-
- # Note: without rescaling, the squared Fisher norm of the update
- # is 1.74
-
- # If the update already satisfies the norm constraint, there should
- # be no rescaling.
- opt = optimizer.KfacOptimizer(
- lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.)
- coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
- self.assertAlmostEqual(1., sess.run(coeff), places=5)
-
- # If the update violates the constraint, it should be rescaled to
- # be on the constraint boundary.
- opt = optimizer.KfacOptimizer(
- lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5)
- coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars)
- sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars)
- sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad
- self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5)
-
- def testComputeUpdateStepsRegular(self):
- # TODO(olganw): implement this.
- pass
-
- def testComputeUpdateStepsAdam(self):
- # TODO(olganw): implement this.
- pass
-
- def testUpdateVelocities(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- layers = lc.LayerCollection()
- layers.register_categorical_predictive_distribution(
- array_ops.constant([1.0]))
- opt = optimizer.KfacOptimizer(
- 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular')
- x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2)))
- y = variable_scope.get_variable(
- 'y', initializer=array_ops.ones((2, 2)) * 2)
- vec1 = array_ops.ones((2, 2)) * 3
- vec2 = array_ops.ones((2, 2)) * 4
-
- model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
- opt_vars = [
- v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
- if v not in model_vars
- ]
-
- sess.run(tf_variables.global_variables_initializer())
- old_opt_vars = sess.run(opt_vars)
-
- # Optimizer vars start out at 0.
- for opt_var in old_opt_vars:
- self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var)
-
- sess.run(update_op)
- new_opt_vars = sess.run(opt_vars)
- # After one update, the velocities are equal to the vectors.
- for vec, opt_var in zip([vec1, vec2], new_opt_vars):
- self.assertAllEqual(sess.run(vec), opt_var)
-
- sess.run(update_op)
- final_opt_vars = sess.run(opt_vars)
- for first, second in zip(new_opt_vars, final_opt_vars):
- self.assertFalse(np.equal(first, second).all())
-
- def testApplyGradients(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- layer_collection = lc.LayerCollection()
-
- inputs = array_ops.ones((2, 1)) * 2
- weights_val = np.ones((1, 1), dtype=np.float32) * 3.
- weights = variable_scope.get_variable(
- 'w', initializer=array_ops.constant(weights_val))
- bias = variable_scope.get_variable(
- 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1))
- output = math_ops.matmul(inputs, weights) + bias
-
- layer_collection.register_fully_connected((weights, bias), inputs, output)
-
- logits = math_ops.tanh(output)
- targets = array_ops.constant([[0.], [1.]])
- output = math_ops.reduce_mean(
- nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets))
-
- layer_collection.register_categorical_predictive_distribution(logits)
-
- opt = optimizer.KfacOptimizer(
- 0.1,
- 0.2,
- 0.3,
- layer_collection,
- momentum=0.5,
- momentum_type='regular')
- (cov_update_thunks,
- inv_update_thunks) = opt.make_vars_and_create_op_thunks()
- cov_update_ops = tuple(thunk() for thunk in cov_update_thunks)
- inv_update_ops = tuple(thunk() for thunk in inv_update_thunks)
-
- grads_and_vars = opt.compute_gradients(output, [weights, bias])
- all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars]
-
- op = opt.apply_gradients(grads_and_vars)
-
- sess.run(tf_variables.global_variables_initializer())
- old_vars = sess.run(all_vars)
- sess.run(cov_update_ops)
- sess.run(inv_update_ops)
- sess.run(op)
- new_vars = sess.run(all_vars)
-
- for old_var, new_var in zip(old_vars, new_vars):
- self.assertNotEqual(old_var, new_var)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
deleted file mode 100644
index 7df79a3c7f..0000000000
--- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py
+++ /dev/null
@@ -1,410 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tf.contrib.kfac.utils."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import numpy.random as npr
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.contrib.tpu.python.tpu import tpu_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-class SequenceDictTest(test.TestCase):
-
- def testSequenceDictInit(self):
- seq_dict = utils.SequenceDict()
- self.assertFalse(seq_dict._dict)
-
- def testSequenceDictInitWithIterable(self):
- reg_dict = {'a': 'foo', 'b': 'bar'}
- itr = zip(reg_dict.keys(), reg_dict.values())
- seq_dict = utils.SequenceDict(itr)
- self.assertEqual(reg_dict, seq_dict._dict)
-
- def testGetItemSingleKey(self):
- seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
- self.assertEqual('foo', seq_dict['a'])
-
- def testGetItemMultipleKeys(self):
- seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'})
- self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')])
-
- def testSetItemSingleKey(self):
- seq_dict = utils.SequenceDict()
- seq_dict['a'] = 'foo'
- self.assertEqual([('a', 'foo')], seq_dict.items())
-
- def testSetItemMultipleKeys(self):
- seq_dict = utils.SequenceDict()
- keys = ('a', 'b', 'c')
- values = ('foo', 'bar', 'baz')
- seq_dict[keys] = values
- self.assertItemsEqual(list(zip(keys, values)), seq_dict.items())
-
-
-class SubGraphTest(test.TestCase):
-
- def testBasicGraph(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b
- d = a * b
- sub_graph = utils.SubGraph((c,))
- self.assertTrue(sub_graph.is_member(a))
- self.assertTrue(sub_graph.is_member(b))
- self.assertTrue(sub_graph.is_member(c))
- self.assertFalse(sub_graph.is_member(d))
-
- def testRepeatedAdds(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b + a # note that a appears twice in this graph
- sub_graph = utils.SubGraph((c,))
- self.assertTrue(sub_graph.is_member(a))
- self.assertTrue(sub_graph.is_member(b))
- self.assertTrue(sub_graph.is_member(c))
-
- def testFilterList(self):
- a = array_ops.constant([[1., 2.], [3., 4.]])
- b = array_ops.constant([[5., 6.], [7., 8.]])
- c = a + b
- d = a * b
- sub_graph = utils.SubGraph((c,))
- input_list = [b, d]
- filtered_list = sub_graph.filter_list(input_list)
- self.assertEqual(filtered_list, [b])
-
- def testVariableUses(self):
- with ops.Graph().as_default():
- var = variable_scope.get_variable('var', shape=[10, 10])
- resource_var = variable_scope.get_variable(
- 'resource_var', shape=[10, 10], use_resource=True)
- x = array_ops.zeros([3, 10])
- z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var)
- z1 = math_ops.matmul(x, resource_var)
- sub_graph = utils.SubGraph((z0, z1))
- self.assertEqual(2, sub_graph.variable_uses(var))
- self.assertEqual(1, sub_graph.variable_uses(resource_var))
-
-
-class UtilsTest(test.TestCase):
-
- def _fully_connected_layer_params(self):
- weights_part = array_ops.constant([[1., 2.], [4., 3.]])
- bias_part = array_ops.constant([1., 2.])
- return (weights_part, bias_part)
-
- def _conv_layer_params(self):
- weights_shape = 2, 2, 3, 4
- biases_shape = weights_shape[-1:]
- weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape))
- biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape))
- return (weights, biases)
-
- def testFullyConnectedLayerParamsTupleToMat2d(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- layer_params = self._fully_connected_layer_params()
- output = utils.layer_params_to_mat2d(layer_params)
- self.assertListEqual([3, 2], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]]))
-
- def testFullyConnectedLayerParamsTensorToMat2d(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- layer_params = self._fully_connected_layer_params()
- output = utils.layer_params_to_mat2d(layer_params[0])
- self.assertListEqual([2, 2], output.get_shape().as_list())
- self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]]))
-
- def testConvLayerParamsTupleToMat2d(self):
- with ops.Graph().as_default():
- random_seed.set_random_seed(200)
- layer_params = self._conv_layer_params()
- output = utils.layer_params_to_mat2d(layer_params)
- self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list())
-
- def testKron(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- mat1 = np.array([[1., 2.], [3., 4.]])
- mat2 = np.array([[5., 6.], [7., 8.]])
- mat1_tf = array_ops.constant(mat1)
- mat2_tf = array_ops.constant(mat2)
- ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf))
- ans_np = np.kron(mat1, mat2)
- self.assertAllClose(ans_tf, ans_np)
-
- def testMat2dToFullyConnectedLayerParamsTuple(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- vector_template = self._fully_connected_layer_params()
- mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]])
-
- output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
-
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 2)
- a, b = output
- self.assertAllClose(a, np.array([[5., 4.], [3., 2.]]))
- self.assertAllClose(b, np.array([1., 0.]))
-
- def testMat2dToFullyConnectedLayerParamsTensor(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- vector_template = self._fully_connected_layer_params()[0]
- mat2d = array_ops.constant([[5., 4.], [3., 2.]])
-
- output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d))
-
- self.assertAllClose(output, np.array([[5., 4.], [3., 2.]]))
-
- def testTensorsToColumn(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
-
- vector = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
- output = utils.tensors_to_column(vector)
- self.assertListEqual([4, 1], output.get_shape().as_list())
- self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None])
-
- vector = self._fully_connected_layer_params()
- output = utils.tensors_to_column(vector)
- self.assertListEqual([6, 1], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None])
-
- vector = list(vector)
- vector.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
-
- output = utils.tensors_to_column(vector)
- self.assertListEqual([10, 1], output.get_shape().as_list())
- self.assertAllClose(
- sess.run(output),
- np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None])
-
- def testColumnToTensors(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
-
- vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]]))
- colvec = array_ops.constant(np.arange(4.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
- self.assertAllClose(output, np.array([[0., 1.], [2., 3.]]))
-
- vector_template = self._fully_connected_layer_params()
- colvec = array_ops.constant(np.arange(6.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
-
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 2)
- a, b = output
- self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
- self.assertAllClose(b, np.array([4., 5.]))
-
- vector_template = list(vector_template)
- vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]]))
- colvec = array_ops.constant(np.arange(10.)[:, None])
- output = sess.run(utils.column_to_tensors(vector_template, colvec))
- self.assertIsInstance(output, tuple)
- self.assertEqual(len(output), 3)
- a, b, c = output
- self.assertAllClose(a, np.array([[0., 1.], [2., 3.]]))
- self.assertAllClose(b, np.array([4., 5.]))
- self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]]))
-
- def testPosDefInvCholesky(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- npr.seed(0)
- square = lambda x: np.dot(x, x.T)
-
- size = 3
- x = square(npr.randn(size, size))
- damp = 0.1
- identity = linalg_ops.eye(size, dtype=dtypes.float64)
-
- tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp)
- np_inv = np.linalg.inv(x + damp * np.eye(size))
- self.assertAllClose(sess.run(tf_inv), np_inv)
-
- def testPosDefInvMatrixInverse(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- random_seed.set_random_seed(200)
- npr.seed(0)
- square = lambda x: np.dot(x, x.T)
-
- size = 3
- x = square(npr.randn(size, size))
- damp = 0.1
- identity = linalg_ops.eye(size, dtype=dtypes.float64)
-
- tf_inv = utils.posdef_inv_matrix_inverse(
- array_ops.constant(x), identity, damp)
- np_inv = np.linalg.inv(x + damp * np.eye(size))
- self.assertAllClose(sess.run(tf_inv), np_inv)
-
- def testCrossReplicaMean(self):
- """Ensures that cross_replica_mean() executes only when num_shards > 1."""
- with ops.Graph().as_default():
- with tpu_function.tpu_shard_context(4):
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
- self.assertNotEqual(mean, tensor)
-
- with ops.Graph().as_default():
- with tpu_function.tpu_shard_context(1):
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
- self.assertEqual(mean, tensor)
-
- with ops.Graph().as_default():
- with self.assertRaises(ValueError): # Outside of TPU context.
- tensor = array_ops.zeros([], dtype=dtypes.float32)
- mean = utils.cross_replica_mean(tensor)
-
- def testBatchExecute(self):
- """Ensure batch_execute runs in a round-robin fashion."""
-
- def increment_var(var):
- return lambda: var.assign_add(1)
-
- with ops.Graph().as_default(), self.cached_session() as sess:
- i = variable_scope.get_variable('i', initializer=0)
- accumulators = [
- variable_scope.get_variable('var%d' % j, initializer=0)
- for j in range(3)
- ]
- thunks = [increment_var(var) for var in accumulators]
- increment_accumulators = utils.batch_execute(i, thunks, 2)
- increment_i = i.assign_add(1)
-
- sess.run(variables.global_variables_initializer())
-
- # Ensure one op per thunk.
- self.assertEqual(3, len(increment_accumulators))
-
- # Ensure round-robin execution.
- values = []
- for _ in range(5):
- sess.run(increment_accumulators)
- sess.run(increment_i)
- values.append(sess.run(accumulators))
- self.assertAllClose(
- [
- [1, 1, 0], #
- [2, 1, 1], #
- [2, 2, 2], #
- [3, 3, 2], #
- [4, 3, 3]
- ],
- values)
-
- def testExtractConvolutionPatches(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- batch_size = 10
- image_spatial_shape = [9, 10, 11]
- in_channels = out_channels = 32
- kernel_spatial_shape = [5, 3, 3]
- spatial_strides = [1, 2, 1]
- spatial_dilation = [1, 1, 1]
- padding = 'SAME'
-
- images = random_ops.random_uniform(
- [batch_size] + image_spatial_shape + [in_channels], seed=0)
- kernel_shape = kernel_spatial_shape + [in_channels, out_channels]
- kernel = random_ops.random_uniform(kernel_shape, seed=1)
-
- # Ensure shape matches expectation.
- patches = utils.extract_convolution_patches(
- images,
- kernel_shape,
- padding,
- strides=spatial_strides,
- dilation_rate=spatial_dilation)
- result_spatial_shape = (
- patches.shape.as_list()[1:1 + len(image_spatial_shape)])
- self.assertEqual(patches.shape.as_list(),
- [batch_size] + result_spatial_shape +
- kernel_spatial_shape + [in_channels])
-
- # Ensure extract...patches() + matmul() and convolution() implementation
- # give the same answer.
- outputs = nn_ops.convolution(
- images,
- kernel,
- padding,
- strides=spatial_strides,
- dilation_rate=spatial_dilation)
-
- patches_flat = array_ops.reshape(
- patches, [-1, np.prod(kernel_spatial_shape) * in_channels])
- kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
- outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
-
- outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
- self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
-
- def testExtractPointwiseConv2dPatches(self):
- with ops.Graph().as_default(), self.cached_session() as sess:
- batch_size = 10
- image_height = image_width = 8
- in_channels = out_channels = 3
- kernel_height = kernel_width = 1
- strides = [1, 1, 1, 1]
- padding = 'VALID'
-
- images = random_ops.random_uniform(
- [batch_size, image_height, image_width, in_channels], seed=0)
- kernel_shape = [kernel_height, kernel_width, in_channels, out_channels]
- kernel = random_ops.random_uniform(kernel_shape, seed=1)
-
- # Ensure shape matches expectation.
- patches = utils.extract_pointwise_conv2d_patches(images, kernel_shape)
- self.assertEqual(patches.shape.as_list(), [
- batch_size, image_height, image_width, kernel_height, kernel_width,
- in_channels
- ])
-
- # Ensure extract...patches() + matmul() and conv2d() implementation
- # give the same answer.
- outputs = nn_ops.conv2d(images, kernel, strides, padding)
-
- patches_flat = array_ops.reshape(
- patches, [-1, kernel_height * kernel_width * in_channels])
- kernel_flat = array_ops.reshape(kernel, [-1, out_channels])
- outputs_flat = math_ops.matmul(patches_flat, kernel_flat)
-
- outputs_, outputs_flat_ = sess.run([outputs, outputs_flat])
- self.assertAllClose(outputs_.flatten(), outputs_flat_.flatten())
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
deleted file mode 100644
index 3c01eb65e7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ /dev/null
@@ -1,263 +0,0 @@
-package(default_visibility = [
- "//tensorflow/contrib/kfac:__pkg__",
- "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__",
-])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-py_library(
- name = "fisher_blocks",
- srcs = ["fisher_blocks.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_factors",
- ":utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:math_ops",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_blocks_lib",
- srcs = ["fisher_blocks_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_blocks",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "fisher_factors",
- srcs = ["fisher_factors.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":linear_operator",
- ":utils",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:init_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:special_math_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- "//third_party/py/numpy",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_factors_lib",
- srcs = ["fisher_factors_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_factors",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "linear_operator",
- srcs = ["linear_operator.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/ops/linalg",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "loss_functions",
- srcs = ["loss_functions.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/distributions:distributions_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/ops/distributions",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "loss_functions_lib",
- srcs = ["loss_functions_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":loss_functions",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "curvature_matrix_vector_products",
- srcs = ["curvature_matrix_vector_products.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:gradients",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "curvature_matrix_vector_products_lib",
- srcs = ["curvature_matrix_vector_products_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":curvature_matrix_vector_products",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "layer_collection",
- srcs = ["layer_collection.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_blocks",
- ":loss_functions",
- ":utils",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:util",
- "//tensorflow/python:variable_scope",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "layer_collection_lib",
- srcs = ["layer_collection_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":layer_collection",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "kfac_optimizer",
- srcs = [
- "optimizer.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":curvature_matrix_vector_products",
- ":fisher_estimator",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
- ],
-)
-
-py_library(
- name = "kfac_optimizer_lib",
- srcs = [
- "optimizer_lib.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":kfac_optimizer",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "fisher_estimator",
- srcs = [
- "estimator.py",
- "placement.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:util",
- "//third_party/py/numpy",
- "@six_archive//:six",
- ],
-)
-
-py_library(
- name = "fisher_estimator_lib",
- srcs = [
- "estimator_lib.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":fisher_estimator",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "utils",
- srcs = ["utils.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/tpu",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:linalg_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "utils_lib",
- srcs = ["utils_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":utils",
- "//tensorflow/python:util",
- ],
-)
-
-py_library(
- name = "op_queue",
- srcs = ["op_queue.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:dataset_ops",
- "//tensorflow/python:framework_ops",
- ],
-)
-
-py_library(
- name = "op_queue_lib",
- srcs = ["op_queue_lib.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":op_queue",
- "//tensorflow/python:util",
- ],
-)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
deleted file mode 100644
index 21b5cde9b9..0000000000
--- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Curvature matrix-vector multiplication."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
-
-
-class CurvatureMatrixVectorProductComputer(object):
- """Class for computing matrix-vector products for Fishers, GGNs and Hessians.
-
- In other words we compute M*v where M is the matrix, v is the vector, and
- * refers to standard matrix/vector multiplication (not element-wise
- multiplication).
-
- The matrices are defined in terms of some differential quantity of the total
- loss function with respect to a provided list of tensors ("wrt_tensors").
- For example, the Fisher associated with a log-prob loss w.r.t. the
- parameters.
-
- The 'vecs' argument to each method are lists of tensors that must be the
- size as the corresponding ones from "wrt_tensors". They represent
- the vector being multiplied.
-
- "factors" of the matrix M are defined as matrices B such that B*B^T = M.
- Methods that multiply by the factor B take a 'loss_inner_vecs' argument
- instead of 'vecs', which must be a list of tensors with shapes given by the
- corresponding XXX_inner_shapes property.
-
- Note that matrix-vector products are not normalized by the batch size, nor
- are any damping terms added to the results. These things can be easily
- applied externally, if desired.
-
- See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf
- and https://arxiv.org/abs/1412.1193 for more information about the
- generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector
- products.
- """
-
- def __init__(self, losses, wrt_tensors):
- """Create a CurvatureMatrixVectorProductComputer object.
-
- Args:
- losses: A list of LossFunction instances whose sum defines the total loss.
- wrt_tensors: A list of Tensors to compute the differential quantities
- (defining the matrices) with respect to. See class description for more
- info.
- """
- self._losses = losses
- self._inputs_to_losses = list(loss.inputs for loss in losses)
- self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses)
- self._wrt_tensors = wrt_tensors
-
- @property
- def _total_loss(self):
- return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses))
-
- # Jacobian multiplication functions:
- def _multiply_jacobian(self, vecs):
- """Multiply vecs by the Jacobian of losses."""
- # We stop gradients at wrt_tensors to produce partial derivatives (which is
- # what we want for Jacobians).
- jacobian_vecs_flat = utils.fwd_gradients(
- self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs,
- stop_gradients=self._wrt_tensors)
- return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat)
-
- def _multiply_jacobian_transpose(self, loss_vecs):
- """Multiply vecs by the transpose Jacobian of losses."""
- loss_vecs_flat = nest.flatten(loss_vecs)
- # We stop gradients at wrt_tensors to produce partial derivatives (which is
- # what we want for Jacobians).
- return gradients_impl.gradients(
- self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat,
- stop_gradients=self._wrt_tensors)
-
- # Losses Fisher/Hessian multiplication functions:
- def _multiply_loss_fisher(self, loss_vecs):
- """Multiply loss_vecs by Fisher of total loss."""
- return tuple(
- loss.multiply_fisher(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_fisher_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Fisher of total loss."""
- return tuple(
- loss.multiply_fisher_factor(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_inner_vecs))
-
- def _multiply_loss_fisher_factor_transpose(self, loss_vecs):
- """Multiply loss_vecs by transpose factor of Fisher of total loss."""
- return tuple(
- loss.multiply_fisher_factor_transpose(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_hessian(self, loss_vecs):
- """Multiply loss_vecs by Hessian of total loss."""
- return tuple(
- loss.multiply_hessian(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- def _multiply_loss_hessian_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Hessian of total loss."""
- return tuple(
- loss.multiply_hessian_factor(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_inner_vecs))
-
- def _multiply_loss_hessian_factor_transpose(self, loss_vecs):
- """Multiply loss_vecs by transpose factor of Hessian of total loss."""
- return tuple(
- loss.multiply_hessian_factor_transpose(loss_vec)
- for loss, loss_vec in zip(self._losses, loss_vecs))
-
- # Matrix-vector product functions:
- def multiply_fisher(self, vecs):
- """Multiply vecs by Fisher of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs)
- return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs)
-
- def multiply_fisher_factor_transpose(self, vecs):
- """Multiply vecs by transpose of factor of Fisher of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- return self._multiply_loss_fisher_factor_transpose(jacobian_vecs)
-
- def multiply_fisher_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of Fisher of total loss."""
- fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose(
- loss_inner_vecs)
- return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs)
-
- def multiply_hessian(self, vecs):
- """Multiply vecs by Hessian of total loss."""
- return gradients_impl.gradients(
- gradients_impl.gradients(self._total_loss, self._wrt_tensors),
- self._wrt_tensors,
- grad_ys=vecs)
-
- def multiply_generalized_gauss_newton(self, vecs):
- """Multiply vecs by generalized Gauss-Newton of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs)
- return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs)
-
- def multiply_generalized_gauss_newton_factor_transpose(self, vecs):
- """Multiply vecs by transpose of factor of GGN of total loss."""
- jacobian_vecs = self._multiply_jacobian(vecs)
- return self._multiply_loss_hessian_factor_transpose(jacobian_vecs)
-
- def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs):
- """Multiply loss_inner_vecs by factor of GGN of total loss."""
- hessian_factor_transpose_vecs = (
- self._multiply_loss_hessian_factor_transpose(loss_inner_vecs))
- return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs)
-
- # Shape properties for multiply_XXX_factor methods:
- @property
- def fisher_factor_inner_shapes(self):
- """Shapes required by multiply_fisher_factor."""
- return tuple(loss.fisher_factor_inner_shape for loss in self._losses)
-
- @property
- def generalized_gauss_newton_factor_inner_shapes(self):
- """Shapes required by multiply_generalized_gauss_newton_factor."""
- return tuple(loss.hessian_factor_inner_shape for loss in self._losses)
diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
deleted file mode 100644
index 6e8c6404dc..0000000000
--- a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Curvature matrix-vector multiplication."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'CurvatureMatrixVectorProductComputer',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py
deleted file mode 100644
index 323234c403..0000000000
--- a/tensorflow/contrib/kfac/python/ops/estimator.py
+++ /dev/null
@@ -1,516 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Defines the high-level Fisher estimator class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import placement
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-
-# The linter is confused.
-# pylint: disable=abstract-class-instantiated
-def make_fisher_estimator(placement_strategy=None, **kwargs):
- """Creates Fisher estimator instances based on the placement strategy.
-
- For example if the `placement_strategy` is 'round_robin' then
- `FisherEstimatorRoundRobin` instance is returned.
-
- Args:
- placement_strategy: `string`, Strategy to be used for placing covariance
- variables, covariance ops and inverse ops. Check
- `placement.FisherEstimatorRoundRobin` for a concrete example.
- **kwargs: Arguments to be passed into `FisherEstimator` class initializer.
-
- Returns:
- An instance of class which inherits from `FisherEstimator` and the mixin
- which implements specific placement strategy. See,
- `FisherEstimatorRoundRobin` which inherits from `FisherEstimator` and
- `RoundRobinPlacementMixin`.
-
- Raises:
- ValueError: If the `placement_strategy` is not equal to 'round_robin'.
- """
- if placement_strategy in [None, "round_robin"]:
- return FisherEstimatorRoundRobin(**kwargs)
- else:
- raise ValueError("Unimplemented vars and ops "
- "placement strategy : {}".format(placement_strategy))
-# pylint: enable=abstract-class-instantiated
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherEstimator(object):
- """Fisher estimator class supporting various approximations of the Fisher.
-
- This is an abstract base class which does not implement a strategy for
- placing covariance variables, covariance update ops and inverse update ops.
- The placement strategies are implemented in `placement.py`. See
- `FisherEstimatorRoundRobin` for example of a concrete subclass with
- a round-robin placement strategy.
- """
-
- def __init__(self,
- variables,
- cov_ema_decay,
- damping,
- layer_collection,
- exps=(-1,),
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- name="FisherEstimator",
- compute_cholesky=False,
- compute_cholesky_inverse=False):
- """Create a FisherEstimator object.
-
- Args:
- variables: A `list` of variables or `callable` which returns the variables
- for which to estimate the Fisher. This must match the variables
- registered in layer_collection (if it is not None).
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: float. The damping factor used to stabilize training due to
- errors in the local approximation with the Fisher information matrix,
- and to regularize the update direction by making it closer to the
- gradient. (Higher damping means the update looks more like a standard
- gradient update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the Fisher
- blocks, Kronecker factors, and losses associated with the
- graph.
- exps: List of floats or ints. These represent the different matrix
- powers of the approximate Fisher that the FisherEstimator will be able
- to multiply vectors by. If the user asks for a matrix power other
- one of these (or 1, which is always supported), there will be a
- failure. (Default: (-1,))
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_prop', or 'exact'.
- (Default: 'gradients'). 'gradients' is the basic estimation approach
- from the original K-FAC paper. 'empirical' computes the 'empirical'
- Fisher information matrix (which uses the data's distribution for the
- targets, as opposed to the true Fisher which uses the model's
- distribution) and requires that each registered loss have specified
- targets. 'curvature_propagation' is a method which estimates the
- Fisher using self-products of random 1/-1 vectors times "half-factors"
- of the Fisher, as described here: https://arxiv.org/abs/1206.6464 .
- Finally, 'exact' is the obvious generalization of Curvature
- Propagation to compute the exact Fisher (modulo any additional
- diagonal or Kronecker approximations) by looping over one-hot vectors
- for each coordinate of the output instead of using 1/-1 vectors. It
- is more expensive to compute than the other three options by a factor
- equal to the output dimension, roughly speaking.
- colocate_gradients_with_ops: Whether we should request gradients be
- colocated with their respective ops. (Default: True)
- name: A string. A name given to this estimator, which is added to the
- variable scope when constructing variables and ops.
- (Default: "FisherEstimator")
- compute_cholesky: Bool. Whether or not the FisherEstimator will be
- able to multiply vectors by the Cholesky factor.
- (Default: False)
- compute_cholesky_inverse: Bool. Whether or not the FisherEstimator
- will be able to multiply vectors by the Cholesky factor inverse.
- (Default: False)
- Raises:
- ValueError: If no losses have been registered with layer_collection.
- """
- self._variables = variables
- self._cov_ema_decay = cov_ema_decay
- self._damping = damping
- self._estimation_mode = estimation_mode
- self._layers = layer_collection
- self._gradient_fns = {
- "gradients": self._get_grads_lists_gradients,
- "empirical": self._get_grads_lists_empirical,
- "curvature_prop": self._get_grads_lists_curvature_prop,
- "exact": self._get_grads_lists_exact
- }
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- self._made_vars = False
- self._exps = exps
- self._compute_cholesky = compute_cholesky
- self._compute_cholesky_inverse = compute_cholesky_inverse
-
- self._name = name
-
- @property
- def variables(self):
- if callable(self._variables):
- return self._variables()
- else:
- return self._variables
-
- @property
- def damping(self):
- return self._damping
-
- @property
- def blocks(self):
- """All registered FisherBlocks."""
- return self._layers.get_blocks()
-
- @property
- def factors(self):
- """All registered FisherFactors."""
- return self._layers.get_factors()
-
- @property
- def name(self):
- return self._name
-
- @abc.abstractmethod
- def make_vars_and_create_op_thunks(self, scope=None):
- """Make vars and create op thunks with a specific placement strategy.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the cov_devices
- argument. If cov_devices is None then no explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the inv_devices argument.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- pass
-
- def _apply_transformation(self, vecs_and_vars, transform):
- """Applies an block-wise transformation to the corresponding vectors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transform: A function of the form f(fb, vec), where vec is the vector
- to transform and fb is its corresponding block in the matrix, that
- returns the transformed vector.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
-
- vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars)
-
- trans_vecs = utils.SequenceDict()
-
- for params, fb in self._layers.fisher_blocks.items():
- trans_vecs[params] = transform(fb, vecs[params])
-
- return [(trans_vecs[var], var) for _, var in vecs_and_vars]
-
- def multiply_inverse(self, vecs_and_vars):
- """Multiplies the vecs by the corresponding (damped) inverses of the blocks.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- return self.multiply_matpower(-1, vecs_and_vars)
-
- def multiply(self, vecs_and_vars):
- """Multiplies the vectors by the corresponding (damped) blocks.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- return self.multiply_matpower(1, vecs_and_vars)
-
- def multiply_matpower(self, exp, vecs_and_vars):
- """Multiplies the vecs by the corresponding matrix powers of the blocks.
-
- Args:
- exp: A float representing the power to raise the blocks by before
- multiplying it by the vector.
- vecs_and_vars: List of (vector, variable) pairs.
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert exp in self._exps
-
- fcn = lambda fb, vec: fb.multiply_matpower(vec, exp)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def multiply_cholesky(self, vecs_and_vars, transpose=False):
- """Multiplies the vecs by the corresponding Cholesky factors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transpose: Bool. If true the Cholesky factors are transposed before
- multiplying the vecs. (Default: False)
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert self._compute_cholesky
-
- fcn = lambda fb, vec: fb.multiply_cholesky(vec, transpose=transpose)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def multiply_cholesky_inverse(self, vecs_and_vars, transpose=False):
- """Mults the vecs by the inverses of the corresponding Cholesky factors.
-
- Note: if you are using Cholesky inverse multiplication to sample from
- a matrix-variate Gaussian you will want to multiply by the transpose.
- Let L be the Cholesky factor of F and observe that
-
- L^-T * L^-1 = (L * L^T)^-1 = F^-1 .
-
- Thus we want to multiply by L^-T in order to sample from Gaussian with
- covariance F^-1.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- transpose: Bool. If true the Cholesky factor inverses are transposed
- before multiplying the vecs. (Default: False)
-
- Returns:
- A list of (transformed vector, var) pairs in the same order as
- vecs_and_vars.
- """
- assert self._compute_cholesky_inverse
-
- fcn = lambda fb, vec: fb.multiply_cholesky_inverse(vec, transpose=transpose)
- return self._apply_transformation(vecs_and_vars, fcn)
-
- def _instantiate_factors(self):
- """Instantiates FisherFactors' variables.
-
- Raises:
- ValueError: If estimation_mode was improperly specified at construction.
- """
- blocks = self.blocks
- tensors_to_compute_grads = [
- block.tensors_to_compute_grads() for block in blocks
- ]
-
- try:
- grads_lists = self._gradient_fns[self._estimation_mode](
- tensors_to_compute_grads)
- except KeyError:
- raise ValueError("Unrecognized value {} for estimation_mode.".format(
- self._estimation_mode))
-
- for grads_list, block in zip(grads_lists, blocks):
- block.instantiate_factors(grads_list, self.damping)
-
- def _check_vars_unmade_and_set_made_flag(self):
- if self._made_vars:
- raise Exception("Already made variables.")
- self._made_vars = True
-
- def made_vars(self):
- return self._made_vars
-
- def _register_matrix_functions(self):
- for block in self.blocks:
- for exp in self._exps:
- block.register_matpower(exp)
- if self._compute_cholesky:
- block.register_cholesky()
- if self._compute_cholesky_inverse:
- block.register_cholesky_inverse()
-
- def _finalize_layer_collection(self):
- self._layers.create_subgraph()
- self._layers.check_registration(self.variables)
- self._instantiate_factors()
- self._register_matrix_functions()
-
- def create_ops_and_vars_thunks(self, scope=None):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order. The i-th element of cov_variable_thunks must execute
- before the i-th element of cov_update_thunks (and also the i-th element
- of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
- must execute before the i-th element of inv_update_thunks.
-
- TL;DR (oversimplified): Execute the thunks according to the order that
- they are returned.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All thunks will execute inside
- of a variable scope of the given name. (Default: None)
- Returns:
- cov_variable_thunks: A list of thunks that make the cov variables.
- cov_update_thunks: A list of thunks that make the cov update ops.
- inv_variable_thunks: A list of thunks that make the inv variables.
- inv_update_thunks: A list of thunks that make the inv update ops.
- """
- self._check_vars_unmade_and_set_made_flag()
-
- self._finalize_layer_collection()
-
- scope = self.name if scope is None else scope
-
- cov_variable_thunks = [
- self._create_cov_variable_thunk(factor, scope)
- for factor in self.factors
- ]
- cov_update_thunks = [
- self._create_cov_update_thunk(factor, scope) for factor in self.factors
- ]
- inv_variable_thunks = [
- self._create_inv_variable_thunk(factor, scope)
- for factor in self.factors
- ]
- inv_update_thunks = [
- self._create_inv_update_thunk(factor, scope) for factor in self.factors
- ]
-
- return (cov_variable_thunks, cov_update_thunks,
- inv_variable_thunks, inv_update_thunks)
-
- def _create_cov_variable_thunk(self, factor, scope):
- """Constructs a covariance variable thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.instantiate_cov_variables()
-
- return thunk
-
- def _create_cov_update_thunk(self, factor, scope):
- """Constructs a covariance update thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.make_covariance_update_op(self._cov_ema_decay)
-
- return thunk
-
- def _create_inv_variable_thunk(self, factor, scope):
- """Constructs a inverse variable thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return factor.instantiate_inv_variables()
-
- return thunk
-
- def _create_inv_update_thunk(self, factor, scope):
- """Constructs an inverse update thunk for a single FisherFactor."""
-
- def thunk():
- with variable_scope.variable_scope(scope):
- return control_flow_ops.group(factor.make_inverse_update_ops())
-
- return thunk
-
- def _get_grads_lists_gradients(self, tensors):
- # Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnessesary ops on the default device
- grads_flat = gradients_impl.gradients(
- self._layers.eval_losses_on_samples(),
- nest.flatten(tensors),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_grads_lists_empirical(self, tensors):
- # Passing in a list of loss values is better than passing in the sum as
- # the latter creates unnecessary ops on the default device
- grads_flat = gradients_impl.gradients(
- self._layers.eval_losses(),
- nest.flatten(tensors),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_transformed_random_signs(self):
- transformed_random_signs = []
- for loss in self._layers.losses:
- with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
- transformed_random_signs.append(
- loss.multiply_fisher_factor(
- utils.generate_random_signs(loss.fisher_factor_inner_shape)))
- return transformed_random_signs
-
- def _get_grads_lists_curvature_prop(self, tensors):
- loss_inputs = list(loss.inputs for loss in self._layers.losses)
- transformed_random_signs = self._get_transformed_random_signs()
- grads_flat = gradients_impl.gradients(
- nest.flatten(loss_inputs),
- nest.flatten(tensors),
- grad_ys=nest.flatten(transformed_random_signs),
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all = nest.pack_sequence_as(tensors, grads_flat)
- return tuple((grad,) for grad in grads_all)
-
- def _get_grads_lists_exact(self, tensors):
- """No docstring required."""
- # Loop over all coordinates of all losses.
- grads_all = []
- for loss in self._layers.losses:
- with tf_ops.colocate_with(self._layers.loss_colocation_ops[loss]):
- for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
- transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
- index)
- grads_flat = gradients_impl.gradients(
- loss.inputs,
- nest.flatten(tensors),
- grad_ys=transformed_one_hot,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops)
- grads_all.append(nest.pack_sequence_as(tensors, grads_flat))
- return zip(*grads_all)
-
-
-class FisherEstimatorRoundRobin(placement.RoundRobinPlacementMixin,
- FisherEstimator):
- """Fisher estimator which provides round robin device placement strategy."""
- pass
diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py
deleted file mode 100644
index 9c9fef471f..0000000000
--- a/tensorflow/contrib/kfac/python/ops/estimator_lib.py
+++ /dev/null
@@ -1,31 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Defines the high-level Fisher estimator class."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.estimator import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'FisherEstimator',
- 'make_fisher_estimator',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
deleted file mode 100644
index 9fa6eb7dcd..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py
+++ /dev/null
@@ -1,1752 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherBlock definitions.
-
-This library contains classes for estimating blocks in a model's Fisher
-Information matrix. Suppose one has a model that parameterizes a posterior
-distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its
-Fisher Information matrix is given by,
-
- $$F(params) = E[ v(x, y, params) v(x, y, params)^T ]$$
-
-where,
-
- $$v(x, y, params) = (d / d params) log p(y | x, params)$$
-
-and the expectation is taken with respect to the data's distribution for 'x' and
-the model's posterior distribution for 'y',
-
- x ~ p(x)
- y ~ p(y | x, params)
-
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import enum # pylint: disable=g-bad-import-order
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import fisher_factors
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.util import nest
-
-# For blocks corresponding to convolutional layers, or any type of block where
-# the parameters can be thought of as being replicated in time or space,
-# we want to adjust the scale of the damping by
-# damping /= num_replications ** NORMALIZE_DAMPING_POWER
-NORMALIZE_DAMPING_POWER = 1.0
-
-# Methods for adjusting damping for FisherBlocks. See
-# compute_pi_adjusted_damping() for details.
-PI_OFF_NAME = "off"
-PI_TRACENORM_NAME = "tracenorm"
-PI_TYPE = PI_TRACENORM_NAME
-
-
-def set_global_constants(normalize_damping_power=None, pi_type=None):
- """Sets various global constants used by the classes in this module."""
- global NORMALIZE_DAMPING_POWER
- global PI_TYPE
-
- if normalize_damping_power is not None:
- NORMALIZE_DAMPING_POWER = normalize_damping_power
-
- if pi_type is not None:
- PI_TYPE = pi_type
-
-
-def normalize_damping(damping, num_replications):
- """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER."""
- if NORMALIZE_DAMPING_POWER:
- return damping / (num_replications ** NORMALIZE_DAMPING_POWER)
- return damping
-
-
-def compute_pi_tracenorm(left_cov, right_cov):
- r"""Computes the scalar constant pi for Tikhonov regularization/damping.
-
- $$\pi = \sqrt{ (trace(A) / dim(A)) / (trace(B) / dim(B)) }$$
- See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details.
-
- Args:
- left_cov: A LinearOperator object. The left Kronecker factor "covariance".
- right_cov: A LinearOperator object. The right Kronecker factor "covariance".
-
- Returns:
- The computed scalar constant pi for these Kronecker Factors (as a Tensor).
- """
- # Instead of dividing by the dim of the norm, we multiply by the dim of the
- # other norm. This works out the same in the ratio.
- left_norm = left_cov.trace() * int(right_cov.domain_dimension)
- right_norm = right_cov.trace() * int(left_cov.domain_dimension)
- return math_ops.sqrt(left_norm / right_norm)
-
-
-def compute_pi_adjusted_damping(left_cov, right_cov, damping):
-
- if PI_TYPE == PI_TRACENORM_NAME:
- pi = compute_pi_tracenorm(left_cov, right_cov)
- return (damping * pi, damping / pi)
-
- elif PI_TYPE == PI_OFF_NAME:
- return (damping, damping)
-
-
-class PackagedFunc(object):
- """A Python thunk with a stable ID.
-
- Enables stable names for lambdas.
- """
-
- def __init__(self, func, func_id):
- """Initializes PackagedFunc.
-
- Args:
- func: a zero-arg Python function.
- func_id: a hashable, function that produces a hashable, or a list/tuple
- thereof.
- """
- self._func = func
- func_id = func_id if isinstance(func_id, (tuple, list)) else (func_id,)
- self._func_id = func_id
-
- def __call__(self):
- return self._func()
-
- @property
- def func_id(self):
- """A hashable identifier for this function."""
- return tuple(elt() if callable(elt) else elt for elt in self._func_id)
-
-
-def _package_func(func, func_id):
- return PackagedFunc(func, func_id)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherBlock(object):
- """Abstract base class for objects modeling approximate Fisher matrix blocks.
-
- Subclasses must implement register_matpower, multiply_matpower,
- instantiate_factors, tensors_to_compute_grads, and num_registered_towers
- methods.
- """
-
- def __init__(self, layer_collection):
- self._layer_collection = layer_collection
-
- @abc.abstractmethod
- def instantiate_factors(self, grads_list, damping):
- """Creates and registers the component factors of this Fisher block.
-
- Args:
- grads_list: A list gradients (each a Tensor or tuple of Tensors) with
- respect to the tensors returned by tensors_to_compute_grads() that
- are to be used to estimate the block.
- damping: The damping factor (float or Tensor).
- """
- pass
-
- @abc.abstractmethod
- def register_matpower(self, exp):
- """Registers a matrix power to be computed by the block.
-
- Args:
- exp: A float representing the power to raise the block by.
- """
- pass
-
- @abc.abstractmethod
- def register_cholesky(self):
- """Registers a Cholesky factor to be computed by the block."""
- pass
-
- @abc.abstractmethod
- def register_cholesky_inverse(self):
- """Registers an inverse Cholesky factor to be computed by the block."""
- pass
-
- def register_inverse(self):
- """Registers a matrix inverse to be computed by the block."""
- self.register_matpower(-1)
-
- @abc.abstractmethod
- def multiply_matpower(self, vector, exp):
- """Multiplies the vector by the (damped) matrix-power of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- exp: A float representing the power to raise the block by before
- multiplying it by the vector.
-
- Returns:
- The vector left-multiplied by the (damped) matrix-power of the block.
- """
- pass
-
- def multiply_inverse(self, vector):
- """Multiplies the vector by the (damped) inverse of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
-
- Returns:
- The vector left-multiplied by the (damped) inverse of the block.
- """
- return self.multiply_matpower(vector, -1)
-
- def multiply(self, vector):
- """Multiplies the vector by the (damped) block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
-
- Returns:
- The vector left-multiplied by the (damped) block.
- """
- return self.multiply_matpower(vector, 1)
-
- @abc.abstractmethod
- def multiply_cholesky(self, vector, transpose=False):
- """Multiplies the vector by the (damped) Cholesky-factor of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- transpose: Bool. If true the Cholesky factor is transposed before
- multiplying the vector. (Default: False)
-
- Returns:
- The vector left-multiplied by the (damped) Cholesky-factor of the block.
- """
- pass
-
- @abc.abstractmethod
- def multiply_cholesky_inverse(self, vector, transpose=False):
- """Multiplies vector by the (damped) inverse Cholesky-factor of the block.
-
- Args:
- vector: The vector (a Tensor or tuple of Tensors) to be multiplied.
- transpose: Bool. If true the Cholesky factor inverse is transposed
- before multiplying the vector. (Default: False)
- Returns:
- Vector left-multiplied by (damped) inverse Cholesky-factor of the block.
- """
- pass
-
- @abc.abstractmethod
- def tensors_to_compute_grads(self):
- """Returns the Tensor(s) with respect to which this FisherBlock needs grads.
- """
- pass
-
- @abc.abstractproperty
- def num_registered_towers(self):
- """Number of towers registered for this FisherBlock.
-
- Typically equal to the number of towers in a multi-tower setup.
- """
- pass
-
-
-class FullFB(FisherBlock):
- """FisherBlock using a full matrix estimate (no approximations).
-
- FullFB uses a full matrix estimate (no approximations), and should only ever
- be used for very low dimensional parameters.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self, layer_collection, params):
- """Creates a FullFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters of this layer (Tensor or tuple of Tensors).
- """
- self._batch_sizes = []
- self._params = params
-
- super(FullFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- self._damping_func = _package_func(lambda: damping, (damping,))
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullFactor, (grads_list, self._batch_size))
-
- def register_matpower(self, exp):
- self._factor.register_matpower(exp, self._damping_func)
-
- def register_cholesky(self):
- self._factor.register_cholesky(self._damping_func)
-
- def register_cholesky_inverse(self):
- self._factor.register_cholesky_inverse(self._damping_func)
-
- def _multiply_matrix(self, matrix, vector, transpose=False):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = matrix.matmul(vector_flat, adjoint=transpose)
- return utils.column_to_tensors(vector, out_flat)
-
- def multiply_matpower(self, vector, exp):
- matrix = self._factor.get_matpower(exp, self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky(self, vector, transpose=False):
- matrix = self._factor.get_cholesky(self._damping_func)
- return self._multiply_matrix(matrix, vector, transpose=transpose)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- matrix = self._factor.get_cholesky_inverse(self._damping_func)
- return self._multiply_matrix(matrix, vector, transpose=transpose)
-
- def full_fisher_block(self):
- """Explicitly constructs the full Fisher block."""
- return self._factor.get_cov_as_linear_operator().to_dense()
-
- def tensors_to_compute_grads(self):
- return self._params
-
- def register_additional_tower(self, batch_size):
- """Register an additional tower.
-
- Args:
- batch_size: The batch size, used in the covariance estimator.
- """
- self._batch_sizes.append(batch_size)
-
- @property
- def num_registered_towers(self):
- return len(self._batch_sizes)
-
- @property
- def _batch_size(self):
- return math_ops.reduce_sum(self._batch_sizes)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class DiagonalFB(FisherBlock):
- """A base class for FisherBlocks that use diagonal approximations."""
-
- def register_matpower(self, exp):
- # Not needed for this. Matrix powers are computed on demand in the
- # diagonal case
- pass
-
- def register_cholesky(self):
- # Not needed for this. Cholesky's are computed on demand in the
- # diagonal case
- pass
-
- def register_cholesky_inverse(self):
- # Not needed for this. Cholesky inverses's are computed on demand in the
- # diagonal case
- pass
-
- def _multiply_matrix(self, matrix, vector):
- vector_flat = utils.tensors_to_column(vector)
- out_flat = matrix.matmul(vector_flat)
- return utils.column_to_tensors(vector, out_flat)
-
- def multiply_matpower(self, vector, exp):
- matrix = self._factor.get_matpower(exp, self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky(self, vector, transpose=False):
- matrix = self._factor.get_cholesky(self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- matrix = self._factor.get_cholesky_inverse(self._damping_func)
- return self._multiply_matrix(matrix, vector)
-
- def full_fisher_block(self):
- return self._factor.get_cov_as_linear_operator().to_dense()
-
-
-class NaiveDiagonalFB(DiagonalFB):
- """FisherBlock using a diagonal matrix approximation.
-
- This type of approximation is generically applicable but quite primitive.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self, layer_collection, params):
- """Creates a NaiveDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters of this layer (Tensor or tuple of Tensors).
- """
- self._params = params
- self._batch_sizes = []
-
- super(NaiveDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- self._damping_func = _package_func(lambda: damping, (damping,))
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size))
-
- def tensors_to_compute_grads(self):
- return self._params
-
- def register_additional_tower(self, batch_size):
- """Register an additional tower.
-
- Args:
- batch_size: The batch size, used in the covariance estimator.
- """
- self._batch_sizes.append(batch_size)
-
- @property
- def num_registered_towers(self):
- return len(self._batch_sizes)
-
- @property
- def _batch_size(self):
- return math_ops.reduce_sum(self._batch_sizes)
-
-
-class InputOutputMultiTower(object):
- """Mix-in class for blocks with inputs & outputs and multiple mini-batches."""
-
- def __init__(self, *args, **kwargs):
- self.__inputs = []
- self.__outputs = []
- super(InputOutputMultiTower, self).__init__(*args, **kwargs)
-
- def _process_data(self, grads_list):
- """Process data into the format used by the factors.
-
- This function takes inputs and grads_lists data and processes it into
- one of the formats expected by the FisherFactor classes (depending on
- the value of the global configuration variable TOWER_STRATEGY).
-
- The initial format of self._inputs is expected to be a list of Tensors
- over towers. Similarly grads_lists is expected to be a list over sources
- of such lists.
-
- If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing a single
- tensor (represented as a PartitionedTensor object) equal to the
- concatenation (across towers) of all of the elements of self._inputs. And
- similarly grads_list is formatted into a tuple (over sources) of such
- tensors (also represented as PartitionedTensors).
-
- If TOWER_STRATEGY is "separate", formatting of inputs and grads_list
- remains unchanged from the initial format (although possibly converting
- from lists into tuples).
-
- Args:
- grads_list: grads_list in its initial format (see above).
-
- Returns:
- inputs: self._inputs transformed into the appropriate format (see
- above).
- grads_list: grads_list transformed into the appropriate format (see
- above).
-
- Raises:
- ValueError: if TOWER_STRATEGY is not one of "separate" or "concat".
- """
- inputs = self._inputs
- # inputs is a list over towers of Tensors
- # grads_list is a list of list with the first index being sources and the
- # second being towers.
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Merge towers together into a PartitionedTensor. We package it in
- # a singleton tuple since the factors will expect a list over towers
- inputs = (utils.PartitionedTensor(inputs),)
- # Do the same for grads_list but preserve leading sources dimension
- grads_list = tuple((utils.PartitionedTensor(grads),)
- for grads in grads_list)
- elif fisher_factors.TOWER_STRATEGY == "separate":
- inputs = tuple(inputs)
- grads_list = tuple(grads_list)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
-
- return inputs, grads_list
-
- def tensors_to_compute_grads(self):
- """Tensors to compute derivative of loss with respect to."""
- return tuple(self._outputs)
-
- def register_additional_tower(self, inputs, outputs):
- self._inputs.append(inputs)
- self._outputs.append(outputs)
-
- @property
- def num_registered_towers(self):
- result = len(self._inputs)
- assert result == len(self._outputs)
- return result
-
- @property
- def _inputs(self):
- return self.__inputs
-
- @property
- def _outputs(self):
- return self.__outputs
-
-
-class FullyConnectedDiagonalFB(InputOutputMultiTower, DiagonalFB):
- """FisherBlock for fully-connected (dense) layers using a diagonal approx.
-
- Estimates the Fisher Information matrix's diagonal entries for a fully
- connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of
- squares" estimator.
-
- Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
- into it. We are interested in Fisher(params)[i, i]. This is,
-
- $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]$$
-
- Consider fully connected layer in this model with (unshared) weight matrix
- 'w'. For an example 'x' that produces layer inputs 'a' and output
- preactivations 's',
-
- $$v(x, y, w) = vec( a (d loss / d s)^T )$$
-
- This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
- to the layer's parameters 'w'.
- """
-
- def __init__(self, layer_collection, has_bias=False):
- """Creates a FullyConnectedDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the component Kronecker factors have an additive bias.
- (Default: False)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedDiagonalFactor,
- (inputs, grads_list, self._has_bias))
-
- self._damping_func = _package_func(lambda: damping, (damping,))
-
-
-class ConvDiagonalFB(InputOutputMultiTower, DiagonalFB):
- """FisherBlock for 2-D convolutional layers using a diagonal approx.
-
- Estimates the Fisher Information matrix's diagonal entries for a convolutional
- layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares"
- estimator.
-
- Let 'params' be a vector parameterizing a model and 'i' an arbitrary index
- into it. We are interested in Fisher(params)[i, i]. This is,
-
- $$Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i]
- = E[ v(x, y, params)[i] ^ 2 ]$$
-
- Consider a convoluational layer in this model with (unshared) filter matrix
- 'w'. For an example image 'x' that produces layer inputs 'a' and output
- preactivations 's',
-
- $$v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T )$$
-
- where 'loc' is a single (x, y) location in an image.
-
- This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding
- to the layer's parameters 'w'.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- data_format=None,
- dilations=None):
- """Creates a ConvDiagonalFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [kernel_height, kernel_width,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (e.g. "SAME").
- data_format: str or None. Format of input data.
- dilations: List of 4 ints or None. Rate for dilation along all dimensions.
-
- Raises:
- ValueError: if strides is not length-4.
- ValueError: if dilations is not length-4.
- ValueError: if channel is not last dimension.
- """
- if len(strides) != 4:
- raise ValueError("strides must contain 4 numbers.")
-
- if dilations is None:
- dilations = [1, 1, 1, 1]
-
- if len(dilations) != 4:
- raise ValueError("dilations must contain 4 numbers.")
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- self._strides = maybe_tuple(strides)
- self._padding = padding
- self._data_format = data_format
- self._dilations = maybe_tuple(dilations)
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- if len(self._filter_shape) != 4:
- raise ValueError(
- "Convolution filter must be of shape"
- " [filter_height, filter_width, in_channels, out_channels].")
-
- super(ConvDiagonalFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvDiagonalFactor,
- (inputs, grads_list, self._filter_shape, self._strides, self._padding,
- self._data_format, self._dilations, self._has_bias))
-
- def damping_func():
- return self._num_locations * normalize_damping(damping,
- self._num_locations)
-
- damping_id = (self._num_locations, "mult", "normalize_damping", damping,
- self._num_locations)
- self._damping_func = _package_func(damping_func, damping_id)
-
-
-class KroneckerProductFB(FisherBlock):
- """A base class for blocks with separate input and output Kronecker factors.
-
- The Fisher block is approximated as a Kronecker product of the input and
- output factors.
- """
-
- def _setup_damping(self, damping, normalization=None):
- """Makes functions that compute the damping values for both factors."""
- def compute_damping():
- if normalization is not None:
- maybe_normalized_damping = normalize_damping(damping, normalization)
- else:
- maybe_normalized_damping = damping
-
- return compute_pi_adjusted_damping(
- self._input_factor.get_cov_as_linear_operator(),
- self._output_factor.get_cov_as_linear_operator(),
- maybe_normalized_damping**0.5)
-
- if normalization is not None:
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- "normalize_damping", damping, normalization, "power", 0.5)
- else:
- damping_id = ("compute_pi_adjusted_damping",
- "cov", self._input_factor.name,
- "cov", self._output_factor.name,
- damping, "power", 0.5)
-
- self._input_damping_func = _package_func(lambda: compute_damping()[0],
- damping_id + ("ref", 0))
- self._output_damping_func = _package_func(lambda: compute_damping()[1],
- damping_id + ("ref", 1))
-
- def register_matpower(self, exp):
- self._input_factor.register_matpower(exp, self._input_damping_func)
- self._output_factor.register_matpower(exp, self._output_damping_func)
-
- def register_cholesky(self):
- self._input_factor.register_cholesky(self._input_damping_func)
- self._output_factor.register_cholesky(self._output_damping_func)
-
- def register_cholesky_inverse(self):
- self._input_factor.register_cholesky_inverse(self._input_damping_func)
- self._output_factor.register_cholesky_inverse(self._output_damping_func)
-
- @property
- def _renorm_coeff(self):
- """Kronecker factor multiplier coefficient.
-
- If this FisherBlock is represented as 'FB = c * kron(left, right)', then
- this is 'c'.
-
- Returns:
- 0-D Tensor.
- """
- return 1.0
-
- def _multiply_factored_matrix(self, left_factor, right_factor, vector,
- extra_scale=1.0, transpose_left=False,
- transpose_right=False):
- reshaped_vector = utils.layer_params_to_mat2d(vector)
- reshaped_out = right_factor.matmul_right(reshaped_vector,
- adjoint=transpose_right)
- reshaped_out = left_factor.matmul(reshaped_out,
- adjoint=transpose_left)
- if extra_scale != 1.0:
- reshaped_out *= math_ops.cast(extra_scale, dtype=reshaped_out.dtype)
- return utils.mat2d_to_layer_params(vector, reshaped_out)
-
- def multiply_matpower(self, vector, exp):
- left_factor = self._input_factor.get_matpower(
- exp, self._input_damping_func)
- right_factor = self._output_factor.get_matpower(
- exp, self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**exp
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale)
-
- def multiply_cholesky(self, vector, transpose=False):
- left_factor = self._input_factor.get_cholesky(self._input_damping_func)
- right_factor = self._output_factor.get_cholesky(self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**0.5
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale,
- transpose_left=transpose,
- transpose_right=not transpose)
-
- def multiply_cholesky_inverse(self, vector, transpose=False):
- left_factor = self._input_factor.get_cholesky_inverse(
- self._input_damping_func)
- right_factor = self._output_factor.get_cholesky_inverse(
- self._output_damping_func)
- extra_scale = float(self._renorm_coeff)**-0.5
- return self._multiply_factored_matrix(left_factor, right_factor, vector,
- extra_scale=extra_scale,
- transpose_left=transpose,
- transpose_right=not transpose)
-
- def full_fisher_block(self):
- """Explicitly constructs the full Fisher block.
-
- Used for testing purposes. (In general, the result may be very large.)
-
- Returns:
- The full Fisher block.
- """
- left_factor = self._input_factor.get_cov_as_linear_operator().to_dense()
- right_factor = self._output_factor.get_cov_as_linear_operator().to_dense()
- return self._renorm_coeff * utils.kronecker_product(left_factor,
- right_factor)
-
-
-class EmbeddingKFACFB(InputOutputMultiTower, KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers.
-
- This FisherBlock is similar to FullyConnectedKFACBasicFB, except that its
- input factor is approximated by a diagonal matrix. In the case that each
- example references exactly one embedding, this approximation is exact.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size):
- """Creates a EmbeddingKFACFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of Tensors. grads_list[i][j] is the
- gradient of the loss with respect to 'outputs' from source 'i' and
- tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor, (grads_list,))
- self._setup_damping(damping)
-
-
-class FullyConnectedKFACBasicFB(InputOutputMultiTower, KroneckerProductFB):
- """K-FAC FisherBlock for fully-connected (dense) layers.
-
- This uses the Kronecker-factorized approximation from the original
- K-FAC paper (https://arxiv.org/abs/1503.05671)
- """
-
- def __init__(self, layer_collection, has_bias=False):
- """Creates a FullyConnectedKFACBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the component Kronecker factors have an additive bias.
- (Default: False)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedKFACBasicFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of Tensors. grads_list[i][j] is the
- gradient of the loss with respect to 'outputs' from source 'i' and
- tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor,
- ((inputs,), self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedKroneckerFactor,
- (grads_list,))
- self._setup_damping(damping)
-
-
-class ConvKFCBasicFB(InputOutputMultiTower, KroneckerProductFB):
- r"""FisherBlock for convolutional layers using the basic KFC approx.
-
- Estimates the Fisher Information matrix's blog for a convolutional
- layer.
-
- Consider a convolutional layer in this model with (unshared) filter matrix
- 'w'. For a minibatch that produces inputs 'a' and output preactivations 's',
- this FisherBlock estimates,
-
- $$F(w) = \#locations * kronecker(E[flat(a) flat(a)^T],
- E[flat(ds) flat(ds)^T])$$
-
- where
-
- $$ds = (d / ds) log p(y | x, w)$$
- #locations = number of (x, y) locations where 'w' is applied.
-
- where the expectation is taken over all examples and locations and flat()
- concatenates an array's leading dimensions.
-
- See equation 23 in https://arxiv.org/abs/1602.01407 for details.
- """
-
- def __init__(self,
- layer_collection,
- params,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None):
- """Creates a ConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [..spatial_filter_shape..,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- padding: str. Padding method.
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- """
- self._padding = padding
- self._strides = maybe_tuple(strides)
- self._dilation_rate = maybe_tuple(dilation_rate)
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- super(ConvKFCBasicFB, self).__init__(layer_collection)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._padding, self._strides,
- self._dilation_rate, self._data_format, self._extract_patches_fn,
- self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
-
- self._setup_damping(damping, normalization=self._num_locations)
-
- @property
- def _renorm_coeff(self):
- return self._num_locations
-
-
-class DepthwiseConvDiagonalFB(ConvDiagonalFB):
- """FisherBlock for depthwise_conv2d().
-
- Equivalent to ConvDiagonalFB applied to each input channel in isolation.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- rate=None,
- data_format=None):
- """Creates a DepthwiseConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: Tensor of shape [filter_height, filter_width, in_channels,
- channel_multiplier].
- strides: List of 4 ints. Strides along all dimensions.
- padding: str. Padding method.
- rate: List of 4 ints or None. Rate for dilation along all dimensions.
- data_format: str or None. Format of input data.
-
- Raises:
- NotImplementedError: If parameters contains bias.
- ValueError: If filter is not 4-D.
- ValueError: If strides is not length-4.
- ValueError: If rates is not length-2.
- ValueError: If channels are not last dimension.
- """
- if isinstance(params, (tuple, list)):
- raise NotImplementedError("Bias not yet supported.")
-
- if params.shape.ndims != 4:
- raise ValueError("Filter must be 4-D.")
-
- if len(strides) != 4:
- raise ValueError("strides must account for 4 dimensions.")
-
- if rate is not None:
- if len(rate) != 2:
- raise ValueError("rate must only account for spatial dimensions.")
- rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- super(DepthwiseConvDiagonalFB, self).__init__(
- layer_collection=layer_collection,
- params=params,
- strides=strides,
- padding=padding,
- dilations=rate,
- data_format=data_format)
-
- # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
- filter_height, filter_width, in_channels, channel_multiplier = (
- params.shape.as_list())
- self._filter_shape = (filter_height, filter_width, in_channels,
- in_channels * channel_multiplier)
-
- def _multiply_matrix(self, matrix, vector):
- conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(
- DepthwiseConvDiagonalFB, self)._multiply_matrix(matrix, conv2d_vector)
- return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
-
-
-class DepthwiseConvKFCBasicFB(ConvKFCBasicFB):
- """FisherBlock for depthwise_conv2d().
-
- Equivalent to ConvKFCBasicFB applied to each input channel in isolation.
- """
-
- def __init__(self,
- layer_collection,
- params,
- strides,
- padding,
- rate=None,
- data_format=None):
- """Creates a DepthwiseConvKFCBasicFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: Tensor of shape [filter_height, filter_width, in_channels,
- channel_multiplier].
- strides: List of 4 ints. Strides along all dimensions.
- padding: str. Padding method.
- rate: List of 4 ints or None. Rate for dilation along all dimensions.
- data_format: str or None. Format of input data.
-
- Raises:
- NotImplementedError: If parameters contains bias.
- ValueError: If filter is not 4-D.
- ValueError: If strides is not length-4.
- ValueError: If rates is not length-2.
- ValueError: If channels are not last dimension.
- """
- if isinstance(params, (tuple, list)):
- raise NotImplementedError("Bias not yet supported.")
-
- if params.shape.ndims != 4:
- raise ValueError("Filter must be 4-D.")
-
- if len(strides) != 4:
- raise ValueError("strides must account for 4 dimensions.")
-
- if rate is not None:
- if len(rate) != 2:
- raise ValueError("rate must only account for spatial dimensions.")
- rate = [1, rate[0], rate[1], 1] # conv2d expects 4-element rate.
-
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels-last.")
-
- super(DepthwiseConvKFCBasicFB, self).__init__(
- layer_collection=layer_collection,
- params=params,
- padding=padding,
- strides=strides,
- dilation_rate=rate,
- data_format=data_format,
- extract_patches_fn="extract_image_patches")
-
- # This is a hack to overwrite the same setting in ConvKFCBasicFB.__init__().
- filter_height, filter_width, in_channels, channel_multiplier = (
- params.shape.as_list())
- self._filter_shape = (filter_height, filter_width, in_channels,
- in_channels * channel_multiplier)
-
- def _multiply_factored_matrix(self, left_factor, right_factor, vector,
- extra_scale=1.0, transpose_left=False,
- transpose_right=False):
- conv2d_vector = depthwise_conv2d_filter_to_conv2d_filter(vector)
- conv2d_result = super(
- DepthwiseConvKFCBasicFB, self)._multiply_factored_matrix(
- left_factor, right_factor, conv2d_vector, extra_scale=extra_scale,
- transpose_left=transpose_left, transpose_right=transpose_right)
- return conv2d_filter_to_depthwise_conv2d_filter(conv2d_result)
-
-
-def depthwise_conv2d_filter_to_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
- """Converts a convolution filter for use with conv2d.
-
- Transforms a filter for use with tf.nn.depthwise_conv2d() to one that's
- compatible with tf.nn.conv2d().
-
- Args:
- filter: Tensor of shape [height, width, in_channels, channel_multiplier].
- name: None or str. Name of Op.
-
- Returns:
- Tensor of shape [height, width, in_channels, out_channels].
-
- """
- with ops.name_scope(name, "depthwise_conv2d_filter_to_conv2d_filter",
- [filter]):
- filter = ops.convert_to_tensor(filter)
- filter_height, filter_width, in_channels, channel_multiplier = (
- filter.shape.as_list())
-
- results = []
- for i in range(in_channels):
- # Slice out one in_channel's filter. Insert zeros around it to force it
- # to affect that channel and that channel alone.
- elements = []
- if i > 0:
- elements.append(
- array_ops.zeros(
- [filter_height, filter_width, i, channel_multiplier]))
- elements.append(filter[:, :, i:(i + 1), :])
- if i + 1 < in_channels:
- elements.append(
- array_ops.zeros([
- filter_height, filter_width, in_channels - (i + 1),
- channel_multiplier
- ]))
-
- # Concat along in_channel.
- results.append(
- array_ops.concat(elements, axis=-2, name="in_channel_%d" % i))
-
- # Concat along out_channel.
- return array_ops.concat(results, axis=-1, name="out_channel")
-
-
-def conv2d_filter_to_depthwise_conv2d_filter(filter, name=None): # pylint: disable=redefined-builtin
- """Converts a convolution filter for use with depthwise_conv2d.
-
- Transforms a filter for use with tf.nn.conv2d() to one that's
- compatible with tf.nn.depthwise_conv2d(). Ignores all filters but those along
- the diagonal.
-
- Args:
- filter: Tensor of shape [height, width, in_channels, out_channels].
- name: None or str. Name of Op.
-
- Returns:
- Tensor of shape,
- [height, width, in_channels, channel_multiplier]
-
- Raises:
- ValueError: if out_channels is not evenly divisible by in_channels.
- """
- with ops.name_scope(name, "conv2d_filter_to_depthwise_conv2d_filter",
- [filter]):
- filter = ops.convert_to_tensor(filter)
- filter_height, filter_width, in_channels, out_channels = (
- filter.shape.as_list())
-
- if out_channels % in_channels != 0:
- raise ValueError("out_channels must be evenly divisible by in_channels.")
- channel_multiplier = out_channels // in_channels
-
- results = []
- filter = array_ops.reshape(filter, [
- filter_height, filter_width, in_channels, in_channels,
- channel_multiplier
- ])
- for i in range(in_channels):
- # Slice out output corresponding to the correct filter.
- filter_slice = array_ops.reshape(
- filter[:, :, i, i, :],
- [filter_height, filter_width, 1, channel_multiplier])
- results.append(filter_slice)
-
- # Concat along out_channel.
- return array_ops.concat(results, axis=-2, name="in_channels")
-
-
-def maybe_tuple(obj):
- if not isinstance(obj, list):
- return obj
- return tuple(obj)
-
-
-def num_conv_locations(input_shape, strides):
- """Returns the number of spatial locations a 2D Conv kernel is applied to.
-
- Args:
- input_shape: List of ints representing shape of inputs to
- tf.nn.convolution().
- strides: List of ints representing strides along spatial dimensions as
- passed in to tf.nn.convolution().
-
- Returns:
- A scalar |T| denoting the number of spatial locations for the Conv layer.
- """
- spatial_input_locations = np.prod(input_shape[1:-1])
-
- if strides is None:
- spatial_strides_divisor = 1
- else:
- spatial_strides_divisor = np.prod(strides)
-
- return spatial_input_locations // spatial_strides_divisor
-
-
-class InputOutputMultiTowerMultiUse(InputOutputMultiTower):
- """Adds methods for multi-use/time-step case to InputOutputMultiTower."""
-
- def __init__(self, num_uses=None, *args, **kwargs):
- self._num_uses = num_uses
- super(InputOutputMultiTowerMultiUse, self).__init__(*args, **kwargs)
-
- def _process_data(self, grads_list):
- """Process temporal/multi-use data into the format used by the factors.
-
- This function takes inputs and grads_lists data and processes it into
- one of the formats expected by the FisherFactor classes (depending on
- the value of the global configuration variable TOWER_STRATEGY).
-
- It accepts the data in one of two initial formats. The first possible
- format is where self._inputs is a list of list of Tensors. The first index
- is tower, the second is use/time-step. grads_list, meanwhile, is a list
- over sources of such lists of lists.
-
- The second possible data format is where self._inputs is a Tensor with
- uses/times-steps folded into the batch dimension. i.e. it is a Tensor
- of shape [num_uses * size_batch, ...] which represents a reshape of a
- Tensor of shape [num_uses, size_batch, ...]. And similarly grads_list is
- a list over sources of such Tensors.
-
- There are two possible formats which inputs and grads_list are transformed
- into.
-
- If TOWER_STRATEGY is "concat", 'inputs' becomes a tuple containing
- a single tensor (represented as a PartitionedTensor object) with all of
- the data from the towers, as well as the uses/time-steps, concatenated
- together. In this tensor the leading dimension is the batch and
- use/time-step dimensions folded together (with 'use' being the major of
- these two, so that the tensors can be thought of as reshapes of ones of
- shape [num_uses, batch_size, ...]). grads_list is similarly formatted as a
- tuple over sources of such tensors.
-
- If TOWER_STRATEGY is "separate" the inputs are formatted into lists of
- tensors over towers. Each of these tensors has a similar format to
- the tensor produced by the "concat" option, except that each contains
- only the data from a single tower. grads_list is similarly formatted
- into a tuple over sources of such tuples.
-
- Args:
- grads_list: grads_list in its initial format (see above).
-
- Returns:
- inputs: self._inputs transformed into the appropriate format (see
- above).
- grads_list: grads_list transformed into the appropriate format (see
- above).
-
- Raises:
- ValueError: If TOWER_STRATEGY is not one of "separate" or "concat".
- ValueError: If the given/initial format of self._inputs and grads_list
- isn't recognized, or doesn't agree with self._num_uses.
- """
-
- inputs = self._inputs
-
- if isinstance(inputs[0], (list, tuple)):
- num_uses = len(inputs[0])
- if self._num_uses is not None and self._num_uses != num_uses:
- raise ValueError("num_uses argument doesn't match length of inputs.")
- else:
- self._num_uses = num_uses
-
- # Check that all mini-batches/towers have the same number of uses
- if not all(len(input_) == num_uses for input_ in inputs):
- raise ValueError("Length of inputs argument is inconsistent across "
- "towers.")
-
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Reverse the tower and use/time-step indices, so that use is now first,
- # and towers is second
- inputs = tuple(zip(*inputs))
-
- # Flatten the two dimensions
- inputs = nest.flatten(inputs)
-
- # Merge everything together into a PartitionedTensor. We package it in
- # a singleton tuple since the factors will expect a list over towers
- inputs = (utils.PartitionedTensor(inputs),)
-
- elif fisher_factors.TOWER_STRATEGY == "separate":
- # Merge together the uses/time-step dimension into PartitionedTensors,
- # but keep the leading dimension (towers) intact for the factors to
- # process individually.
- inputs = tuple(utils.PartitionedTensor(input_) for input_ in inputs)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
- else:
- inputs = tuple(inputs)
-
- # Now we perform the analogous processing for grads_list
- if isinstance(grads_list[0][0], (list, tuple)):
- num_uses = len(grads_list[0][0])
- if self._num_uses is not None and self._num_uses != num_uses:
- raise ValueError("num_uses argument doesn't match length of outputs, "
- "or length of outputs is inconsistent with length of "
- "inputs.")
- else:
- self._num_uses = num_uses
-
- if not all(len(grad) == num_uses for grads in grads_list
- for grad in grads):
- raise ValueError("Length of outputs argument is inconsistent across "
- "towers.")
-
- if fisher_factors.TOWER_STRATEGY == "concat":
- # Reverse the tower and use/time-step indices, so that use is now first,
- # and towers is second
- grads_list = tuple(tuple(zip(*grads)) for grads in grads_list)
-
- # Flatten the two dimensions, leaving the leading dimension (source)
- # intact
- grads_list = tuple(nest.flatten(grads) for grads in grads_list)
-
- # Merge inner dimensions together into PartitionedTensors. We package
- # them in a singleton tuple since the factors will expect a list over
- # towers
- grads_list = tuple((utils.PartitionedTensor(grads),)
- for grads in grads_list)
-
- elif fisher_factors.TOWER_STRATEGY == "separate":
- # Merge together the uses/time-step dimension into PartitionedTensors,
- # but keep the leading dimension (towers) intact for the factors to
- # process individually.
- grads_list = tuple(tuple(utils.PartitionedTensor(grad)
- for grad in grads)
- for grads in grads_list)
-
- else:
- raise ValueError("Global config variable TOWER_STRATEGY must be one of "
- "'concat' or 'separate'.")
- else:
- grads_list = tuple(tuple(grads) for grads in grads_list)
-
- if self._num_uses is None:
- raise ValueError("You must supply a value for the num_uses argument if "
- "the number of uses cannot be inferred from inputs or "
- "outputs arguments (e.g. if they are both given in the "
- "single Tensor format, instead of as lists of Tensors.")
-
- return inputs, grads_list
-
-
-class FullyConnectedMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for fully-connected layers that share parameters.
-
- This class implements the "independence across time" approximation from the
- following paper:
- https://openreview.net/pdf?id=HyMTkQZAb
- """
-
- def __init__(self, layer_collection, has_bias=False, num_uses=None):
- """Creates a FullyConnectedMultiIndepFB block.
-
- Args:
- layer_collection: LayerCollection instance.
- has_bias: bool. If True, estimates Fisher with respect to a bias
- parameter as well as the layer's parameters.
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with uses/time folded into the
- batch dimension (instead of uses/time being a list dimension).
- (Default: None)
- """
- self._has_bias = has_bias
-
- super(FullyConnectedMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._num_uses, self._has_bias))
-
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
-
- self._setup_damping(damping, normalization=self._num_uses)
-
- @property
- def _renorm_coeff(self):
- return float(self._num_uses)
-
-
-class ConvKFCBasicMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for 2D convolutional layers using the basic KFC approx.
-
- Similar to ConvKFCBasicFB except that this version supports multiple
- uses/time-steps via a standard independence approximation. Similar to the
- "independence across time" used in FullyConnectedMultiIndepFB but generalized
- in the obvious way to conv layers.
- """
-
- def __init__(self,
- layer_collection,
- params,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- num_uses=None):
- """Creates a ConvKFCBasicMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- params: The parameters (Tensor or tuple of Tensors) of this layer. If
- kernel alone, a Tensor of shape [..spatial_filter_shape..,
- in_channels, out_channels]. If kernel and bias, a tuple of 2 elements
- containing the previous and a Tensor of shape [out_channels].
- padding: str. Padding method.
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with uses/time folded into the
- batch dimension (instead of uses/time being a list dimension).
- (Default: None)
- """
- self._padding = padding
- self._strides = maybe_tuple(strides)
- self._dilation_rate = maybe_tuple(dilation_rate)
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = isinstance(params, (tuple, list))
-
- fltr = params[0] if self._has_bias else params
- self._filter_shape = tuple(fltr.shape.as_list())
-
- super(ConvKFCBasicMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- # Infer number of locations upon which convolution is applied.
- self._num_locations = num_conv_locations(inputs[0].shape.as_list(),
- self._strides)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvInputKroneckerFactor,
- (inputs, self._filter_shape, self._padding, self._strides,
- self._dilation_rate, self._data_format, self._extract_patches_fn,
- self._has_bias))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.ConvOutputKroneckerFactor, (grads_list,))
-
- self._setup_damping(damping, normalization=
- (self._num_locations * self._num_uses))
-
- @property
- def _renorm_coeff(self):
- return self._num_locations * self._num_uses
-
-
-class EmbeddingKFACMultiIndepFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """K-FAC FisherBlock for embedding layers used multiple times in the graph.
-
- Similar to EmbeddingKFACFB except that this version supports multiple uses
- of the parameter within a single model. These uses could correspond to time
- steps in an RNN architecture, but they don't have to.
-
- Does not support bias parameters.
- """
-
- def __init__(self, layer_collection, vocab_size, num_uses=None):
- """Creates a EmbeddingKFACMultiIndepFB block.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- vocab_size: int. Size of vocabulary for this embedding layer.
- num_uses: int or None. Number of uses of the layer in the model's graph.
- Only required if the data is formatted with time folded into the batch
- dimension (instead of time being a list dimension). (Default: None)
- """
- self._vocab_size = vocab_size
-
- super(EmbeddingKFACMultiIndepFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- def instantiate_factors(self, grads_list, damping):
- """Instantiate Kronecker Factors for this FisherBlock.
-
- Args:
- grads_list: List of list of list of Tensors. grads_list[i][j][k] is the
- gradient of the loss with respect to 'outputs' from source 'i',
- tower/mini-batch 'j', and use/time-step 'k'. Each Tensor has shape
- [tower_minibatch_size, output_size].
- damping: 0-D Tensor or float. 'damping' * identity is approximately added
- to this FisherBlock's Fisher approximation.
- """
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.EmbeddingInputKroneckerFactor,
- (inputs, self._vocab_size))
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
- self._setup_damping(damping, normalization=self._num_uses)
-
- @property
- def _renorm_coeff(self):
- return float(self._num_uses)
-
-
-class SeriesFBApproximation(enum.IntEnum):
- """See FullyConnectedSeriesFB.__init__ for description and usage."""
- option1 = 1
- option2 = 2
-
-
-class FullyConnectedSeriesFB(InputOutputMultiTowerMultiUse,
- KroneckerProductFB):
- """FisherBlock for fully-connected layers that share parameters across time.
-
- This class implements the "Option 1" and "Option 2" approximation from the
- following paper:
- https://openreview.net/pdf?id=HyMTkQZAb
-
- See the end of the appendix of the paper for a pseudo-code of the
- algorithm being implemented by multiply_matpower here. Note that we are
- using pre-computed versions of certain matrix-matrix products to speed
- things up. This is explicitly explained wherever it is done.
- """
-
- def __init__(self,
- layer_collection,
- has_bias=False,
- num_uses=None,
- option=SeriesFBApproximation.option2):
- """Constructs a new `FullyConnectedSeriesFB`.
-
- Args:
- layer_collection: The collection of all layers in the K-FAC approximate
- Fisher information matrix to which this FisherBlock belongs.
- has_bias: Whether the layer includes a bias parameter.
- num_uses: int or None. Number of time-steps over which the layer
- is used. Only required if the data is formatted with time folded into
- the batch dimension (instead of time being a list dimension).
- (Default: None)
- option: A `SeriesFBApproximation` specifying the simplifying assumption
- to be used in this block. `option1` approximates the cross-covariance
- over time as a symmetric matrix, while `option2` makes
- the assumption that training sequences are infinitely long. See section
- 3.5 of the paper for more details.
- """
-
- self._has_bias = has_bias
- self._option = option
-
- super(FullyConnectedSeriesFB, self).__init__(
- layer_collection=layer_collection,
- num_uses=num_uses)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _renorm_coeff(self):
- # This should no longer be used since the multiply_X functions from the base
- # class have been overridden
- assert False
-
- def instantiate_factors(self, grads_list, damping):
- inputs, grads_list = self._process_data(grads_list)
-
- self._input_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF,
- ((inputs,), self._num_uses, self._has_bias))
- self._input_factor.register_cov_dt1()
-
- self._output_factor = self._layer_collection.make_or_get_factor(
- fisher_factors.FullyConnectedMultiKF, (grads_list, self._num_uses))
- self._output_factor.register_cov_dt1()
-
- self._setup_damping(damping, normalization=self._num_uses)
-
- def register_matpower(self, exp):
- if exp != -1:
- raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
- "multiplications.")
-
- if self._option == SeriesFBApproximation.option1:
- self._input_factor.register_option1quants(self._input_damping_func)
- self._output_factor.register_option1quants(self._output_damping_func)
- elif self._option == SeriesFBApproximation.option2:
- self._input_factor.register_option2quants(self._input_damping_func)
- self._output_factor.register_option2quants(self._output_damping_func)
- else:
- raise ValueError(
- "Unrecognized FullyConnectedSeriesFB approximation: {}".format(
- self._option))
-
- def multiply_matpower(self, vector, exp):
- if exp != -1:
- raise NotImplementedError("FullyConnectedSeriesFB only supports inverse"
- "multiplications.")
-
- # pylint: disable=invalid-name
-
- Z = utils.layer_params_to_mat2d(vector)
-
- # Derivations were done for "batch_dim==1" case so we need to convert to
- # that orientation:
- Z = array_ops.transpose(Z)
-
- if self._option == SeriesFBApproximation.option1:
-
- # Note that \\(L_A = A0^{-1/2} * U_A and L_G = G0^{-1/2} * U_G.\\)
- L_A, psi_A = self._input_factor.get_option1quants(
- self._input_damping_func)
- L_G, psi_G = self._output_factor.get_option1quants(
- self._output_damping_func)
-
- def gamma(x):
- # We are assuming that each case has the same number of time-steps.
- # If this stops being the case one shouldn't simply replace this T
- # with its average value. Instead, one needs to go back to the
- # definition of the gamma function from the paper.
- T = self._num_timesteps
- return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T))
-
- # \\(Y = \gamma( psi_G*psi_A^T )\\) (computed element-wise)
- # Even though Y is Z-independent we are recomputing it from the psi's
- # each since Y depends on both A and G quantities, and it is relatively
- # cheap to compute.
- Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A)
-
- # \\(Z = L_G^T * Z * L_A\\)
- # This is equivalent to the following computation from the original
- # pseudo-code:
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\(Z = U_G^T * Z * U_A\\)
- Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True)
-
- # \\(Z = Z .* Y\\)
- Z *= Y
-
- # \\(Z = L_G * Z * L_A^T\\)
- # This is equivalent to the following computation from the original
- # pseudo-code:
- # \\(Z = U_G * Z * U_A^T\\)
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True))
-
- elif self._option == SeriesFBApproximation.option2:
-
- # Note that \\(P_A = A_1^T * A_0^{-1} and P_G = G_1^T * G_0^{-1}\\),
- # and \\(K_A = A_0^{-1/2} * E_A\ and\ K_G = G_0^{-1/2} * E_G.\\)
- P_A, K_A, mu_A = self._input_factor.get_option2quants(
- self._input_damping_func)
- P_G, K_G, mu_G = self._output_factor.get_option2quants(
- self._output_damping_func)
-
- # Our approach differs superficially from the pseudo-code in the paper
- # in order to reduce the total number of matrix-matrix multiplies.
- # In particular, the first three computations in the pseudo code are
- # \\(Z = G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\(Z = Z - hPsi_G^T * Z * hPsi_A\\)
- # \\(Z = E_G^T * Z * E_A\\)
- # Noting that hPsi = C0^{-1/2} * C1 * C0^{-1/2}\\), so that
- # \\(C0^{-1/2} * hPsi = C0^{-1} * C1 * C0^{-1/2} = P^T * C0^{-1/2}\\)
- # the entire computation can be written as
- # \\(Z = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\( - hPsi_G^T * G0^{-1/2} * Z * A0^{-1/2} * hPsi_A) * E_A\\)
- # \\( = E_G^T * (G0^{-1/2} * Z * A0^{-1/2}\\)
- # \\( - G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2}) * E_A\\)
- # \\( = E_G^T * G0^{-1/2} * Z * A0^{-1/2} * E_A\\)
- # \\( - E_G^T* G0^{-1/2} * P_G * Z * P_A^T * A0^{-1/2} * E_A\\)
- # \\( = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A\\)
- # This final expression is computed by the following two lines:
- # \\(Z = Z - P_G * Z * P_A^T\\)
- Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True))
- # \\(Z = K_G^T * Z * K_A\\)
- Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True)
-
- # \\(Z = Z ./ (1*1^T - mu_G*mu_A^T)\\)
- # Be careful with the outer product. We don't want to accidentally
- # make it an inner-product instead.
- tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A
- # Prevent some numerical issues by setting any 0.0 eigs to 1.0
- tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype)
- Z /= tmp
-
- # We now perform the transpose/reverse version of the operations
- # derived above, whose derivation from the original pseudo-code is
- # analgous.
- # \\(Z = K_G * Z * K_A^T\\)
- Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True))
-
- # \\(Z = Z - P_G^T * Z * P_A\\)
- Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True)
-
- # \\(Z = normalize (1/E[T]) * Z\\)
- # Note that this normalization is done because we compute the statistics
- # by averaging, not summing, over time. (And the gradient is presumably
- # summed over time, not averaged, and thus their scales are different.)
- Z /= math_ops.cast(self._num_timesteps, Z.dtype)
-
- # Convert back to the "batch_dim==0" orientation.
- Z = array_ops.transpose(Z)
-
- return utils.mat2d_to_layer_params(vector, Z)
-
- # pylint: enable=invalid-name
-
- def multiply_cholesky(self, vector):
- raise NotImplementedError("FullyConnectedSeriesFB does not support "
- "Cholesky computations.")
-
- def multiply_cholesky_inverse(self, vector):
- raise NotImplementedError("FullyConnectedSeriesFB does not support "
- "Cholesky computations.")
-
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
deleted file mode 100644
index c04cf727fa..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherBlock definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.fisher_blocks import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'FisherBlock',
- 'FullFB',
- 'NaiveDiagonalFB',
- 'FullyConnectedDiagonalFB',
- 'KroneckerProductFB',
- 'EmbeddingKFACFB',
- 'FullyConnectedKFACBasicFB',
- 'ConvKFCBasicFB',
- 'ConvDiagonalFB',
- 'set_global_constants',
- 'compute_pi_tracenorm',
- 'compute_pi_adjusted_damping',
- 'num_conv_locations',
- 'normalize_damping',
- 'LEFT_MULTIPLY',
- 'RIGHT_MULTIPLY',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
deleted file mode 100644
index afa2fd1ca7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ /dev/null
@@ -1,1830 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-import contextlib
-
-import numpy as np
-import six
-
-from tensorflow.contrib.kfac.python.ops import linear_operator as lo
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops as tf_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import special_math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import moving_averages
-from tensorflow.python.util import nest
-
-
-# Whether to initialize covariance estimators at a zero matrix (or the identity
-# matrix).
-INIT_COVARIANCES_AT_ZERO = True
-
-# Whether to zero-debias the moving averages.
-ZERO_DEBIAS = True
-
-# Whether to initialize inverse (and other such matrices computed from the cov
-# matrices) to the zero matrix (or the identity matrix).
-INIT_INVERSES_AT_ZERO = True
-
-# When the number of inverses requested from a FisherFactor exceeds this value,
-# the inverses are computed using an eigenvalue decomposition.
-EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
-
-# Numerical eigenvalues computed from covariance matrix estimates are clipped to
-# be at least as large as this value before they are used to compute inverses or
-# matrix powers. Must be nonnegative.
-EIGENVALUE_CLIPPING_THRESHOLD = 0.0
-
-# Used to subsample the flattened extracted image patches. The number of
-# outer products per row of the covariance matrix should not exceed this
-# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
-_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
-
-# Used to subsample the inputs passed to the extract image patches. The batch
-# size of number of inputs to extract image patches is multiplied by this
-# factor. This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
-_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_OUTER_PRODUCTS = False
-
-# If True, then subsamples the tensor passed to compute the covariance matrix.
-_SUB_SAMPLE_INPUTS = False
-
-# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
-# passed to the factors from the blocks will be concatenated across towers
-# (lazily via PartitionedTensor objects). Otherwise a tuple of tensors over
-# towers will be passed in, and the factors will iterate over this and do the
-# cov computations separately for each one, averaging the results together.
-TOWER_STRATEGY = "concat"
-
-
-def set_global_constants(init_covariances_at_zero=None,
- zero_debias=None,
- init_inverses_at_zero=None,
- eigenvalue_decomposition_threshold=None,
- eigenvalue_clipping_threshold=None,
- max_num_outer_products_per_cov_row=None,
- sub_sample_outer_products=None,
- inputs_to_extract_patches_factor=None,
- sub_sample_inputs=None,
- tower_strategy=None):
- """Sets various global constants used by the classes in this module."""
- global INIT_COVARIANCES_AT_ZERO
- global ZERO_DEBIAS
- global INIT_INVERSES_AT_ZERO
- global EIGENVALUE_DECOMPOSITION_THRESHOLD
- global EIGENVALUE_CLIPPING_THRESHOLD
- global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
- global _SUB_SAMPLE_OUTER_PRODUCTS
- global _INPUTS_TO_EXTRACT_PATCHES_FACTOR
- global _SUB_SAMPLE_INPUTS
- global TOWER_STRATEGY
-
- if init_covariances_at_zero is not None:
- INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero
- if zero_debias is not None:
- ZERO_DEBIAS = zero_debias
- if init_inverses_at_zero is not None:
- INIT_INVERSES_AT_ZERO = init_inverses_at_zero
- if eigenvalue_decomposition_threshold is not None:
- EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
- if eigenvalue_clipping_threshold is not None:
- EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
- if max_num_outer_products_per_cov_row is not None:
- _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row
- if sub_sample_outer_products is not None:
- _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products
- if inputs_to_extract_patches_factor is not None:
- _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor
- if sub_sample_inputs is not None:
- _SUB_SAMPLE_INPUTS = sub_sample_inputs
- if tower_strategy is not None:
- TOWER_STRATEGY = tower_strategy
-
-
-def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_INVERSES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-
-
-def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return linalg_ops.eye(num_rows=shape[0], dtype=dtype)
-
-
-def diagonal_covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument
- if INIT_COVARIANCES_AT_ZERO:
- return array_ops.zeros(shape, dtype=dtype)
- return array_ops.ones(shape, dtype=dtype)
-
-
-@contextlib.contextmanager
-def place_on_device(device):
- if device is not None and len(device):
- with tf_ops.device(device):
- yield
- else:
- yield
-
-
-def compute_cov(tensor, tensor_right=None, normalizer=None):
- """Compute the empirical second moment of the rows of a 2D Tensor.
-
- This function is meant to be applied to random matrices for which the true row
- mean is zero, so that the true second moment equals the true covariance.
-
- Args:
- tensor: A 2D Tensor.
- tensor_right: An optional 2D Tensor. If provided, this function computes
- the matrix product tensor^T * tensor_right instead of tensor^T * tensor.
- normalizer: optional scalar for the estimator (by default, the normalizer is
- the number of rows of tensor).
-
- Returns:
- A square 2D Tensor with as many rows/cols as the number of input columns.
- """
- if normalizer is None:
- normalizer = array_ops.shape(tensor)[0]
- if tensor_right is None:
- cov = (
- math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast(
- normalizer, tensor.dtype))
- return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype)
- else:
- return (math_ops.matmul(tensor, tensor_right, transpose_a=True) /
- math_ops.cast(normalizer, tensor.dtype))
-
-
-def append_homog(tensor):
- """Appends a homogeneous coordinate to the last dimension of a Tensor.
-
- Args:
- tensor: A Tensor.
-
- Returns:
- A Tensor identical to the input but one larger in the last dimension. The
- new entries are filled with ones.
- """
- rank = len(tensor.shape.as_list())
- shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0)
- ones = array_ops.ones(shape, dtype=tensor.dtype)
- return array_ops.concat([tensor, ones], axis=rank - 1)
-
-
-def scope_string_from_params(params):
- """Builds a variable scope string name from the given parameters.
-
- Supported parameters are:
- * tensors
- * booleans
- * ints
- * strings
- * depth-1 tuples/lists of ints
- * any depth tuples/lists of tensors
- Other parameter types will throw an error.
-
- Args:
- params: A parameter or list of parameters.
-
- Returns:
- A string to use for the variable scope.
-
- Raises:
- ValueError: if params includes an unsupported type.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
-
- name_parts = []
- for param in params:
- if param is None:
- name_parts.append("None")
- elif isinstance(param, (tuple, list)):
- if all([isinstance(p, int) for p in param]):
- name_parts.append("-".join([str(p) for p in param]))
- else:
- name_parts.append(scope_string_from_name(param))
- elif isinstance(param, (str, int, bool)):
- name_parts.append(str(param))
- elif isinstance(param, (tf_ops.Tensor, variables.Variable)):
- name_parts.append(scope_string_from_name(param))
- elif isinstance(param, utils.PartitionedTensor):
- name_parts.append(scope_string_from_name(param.tensors))
- else:
- raise ValueError("Encountered an unsupported param type {}".format(
- type(param)))
- return "_".join(name_parts)
-
-
-def scope_string_from_name(tensor):
- if isinstance(tensor, (tuple, list)):
- return "__".join([scope_string_from_name(t) for t in tensor])
- # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape"
- return tensor.name.split(":")[0].replace("/", "_")
-
-
-def scalar_or_tensor_to_string(val):
- return repr(val) if np.isscalar(val) else scope_string_from_name(val)
-
-
-def list_to_string(lst):
- return "_".join(val if isinstance(val, six.string_types)
- else scalar_or_tensor_to_string(val) for val in lst)
-
-
-def graph_func_to_id(func):
- """Returns a hashable object that represents func's computation."""
- # TODO(b/74201126): replace with Topohash of func's output
- return func.func_id
-
-
-def graph_func_to_string(func):
- # TODO(b/74201126): replace with Topohash of func's output
- return list_to_string(func.func_id)
-
-
-def _subsample_for_cov_computation(array, name=None):
- """Subsamples the first dimension of the array.
-
- `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance
- matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer
- products per row of the covariance matrix is greater than
- `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`.
-
- Args:
- array: Tensor, of shape `[batch_size, dim_2]`.
- name: `string`, Default(None)
-
- Returns:
- A tensor of shape `[max_samples, dim_2]`.
-
- Raises:
- ValueError: If array's is not matrix-shaped.
- ValueError: If array's batch_size cannot be inferred.
-
- """
- with tf_ops.name_scope(name, "subsample", [array]):
- array = tf_ops.convert_to_tensor(array)
- if len(array.shape) != 2:
- raise ValueError("Input param array must be a matrix.")
-
- batch_size = array.shape.as_list()[0]
- if batch_size is None:
- raise ValueError("Unable to get batch_size from input param array.")
-
- num_cov_rows = array.shape.as_list()[-1]
- max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
- if batch_size <= max_batch_size:
- return array
-
- return _random_tensor_gather(array, max_batch_size)
-
-
-def _random_tensor_gather(array, max_size):
- """Generates a random set of indices and gathers the value at the indices.
-
- Args:
- array: Tensor, of shape `[batch_size, dim_2]`.
- max_size: int, Number of indices to sample.
-
- Returns:
- A tensor of shape `[max_size, ...]`.
- """
- batch_size = array.shape.as_list()[0]
- indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
- return array_ops.gather(array, indices)
-
-
-@six.add_metaclass(abc.ABCMeta)
-class FisherFactor(object):
- """Base class for objects modeling factors of approximate Fisher blocks.
-
- A FisherFactor represents part of an approximate Fisher Information matrix.
- For example, one approximation to the Fisher uses the Kronecker product of two
- FisherFactors A and B, F = kron(A, B). FisherFactors are composed with
- FisherBlocks to construct a block-diagonal approximation to the full Fisher.
-
- FisherFactors are backed by a single, non-trainable variable that is updated
- by running FisherFactor.make_covariance_update_op(). The shape and type of
- this variable is implementation specific.
-
- Note that for blocks that aren't based on approximations, a 'factor' can
- be the entire block itself, as is the case for the diagonal and full
- representations.
- """
-
- def __init__(self):
- self._cov = None
-
- @abc.abstractproperty
- def _var_scope(self):
- """Variable scope for this FisherFactor instance.
-
- Returns:
- string that unique identifies this FisherFactor instance.
- """
- pass
-
- @property
- def name(self):
- return self._var_scope
-
- @abc.abstractproperty
- def _cov_shape(self):
- """The shape of the variable backing this FisherFactor."""
- pass
-
- @abc.abstractproperty
- def _num_sources(self):
- """The number of things to sum over when updating covariance variable.
-
- The default make_covariance_update_op function will call _compute_new_cov
- with indices ranging from 0 to _num_sources-1. The typical situation is
- where the factor wants to sum the statistics it computes over multiple
- backpropped "gradients" (typically passed in via "tensors" or
- "outputs_grads" arguments).
- """
- pass
-
- @abc.abstractproperty
- def _num_towers(self):
- pass
-
- @abc.abstractproperty
- def _dtype(self):
- """dtype for variable backing this factor."""
- pass
-
- @property
- def _cov_initializer(self):
- """Function for initializing covariance variable."""
- return covariance_initializer
-
- def instantiate_cov_variables(self):
- """Makes the internal cov variable(s)."""
- assert self._cov is None
- with variable_scope.variable_scope(self._var_scope):
- self._cov = variable_scope.get_variable(
- "cov",
- initializer=self._cov_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
-
- @abc.abstractmethod
- def _compute_new_cov(self, source, tower):
- """Computes minibatch-estimated covariance for a single source.
-
- Args:
- source: int in [0, self._num_sources). Which source to use when computing
- the cov update.
- tower: int in [0, self._num_towers). Which tower to use when computing
- the cov update.
-
- Returns:
- Tensor of same shape as self.get_cov().
- """
- pass
-
- def make_covariance_update_op(self, ema_decay):
- """Constructs and returns the covariance update Op.
-
- Args:
- ema_decay: The exponential moving average decay (float or Tensor).
- Returns:
- An Op for updating the covariance Variable referenced by _cov.
- """
- new_cov_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- device = (self._get_data_device(tower)
- if TOWER_STRATEGY == "separate" else None)
- with place_on_device(device):
- new_cov_contribs.append(self._compute_new_cov(source, tower))
-
- new_cov = math_ops.add_n(new_cov_contribs) / float(self._num_towers)
-
- # Compute average of 'new_cov' across all TPU cores. On a TPU, each
- # instance of 'new_cov' will be based on a different minibatch. This ensures
- # that by the end of assign_moving_average(), all TPU cores see the same
- # value for self._cov.
- #
- # Other implementations of make_covariance_update_op() that accumulate
- # statistics in other variables should mimic this behavior.
- if utils.on_tpu():
- new_cov = utils.cross_replica_mean(new_cov)
-
- return moving_averages.assign_moving_average(
- self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS)
-
- @abc.abstractmethod
- def _get_data_device(self, tower):
- pass
-
- @abc.abstractmethod
- def instantiate_inv_variables(self):
- """Makes the internal "inverse" variable(s)."""
- pass
-
- @abc.abstractmethod
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- pass
-
- def get_cov(self):
- return self._cov
-
- @abc.abstractmethod
- def get_cov_as_linear_operator(self):
- pass
-
- @abc.abstractmethod
- def register_matpower(self, exp, damping_func):
- pass
-
- @abc.abstractmethod
- def register_cholesky(self, damping_func):
- pass
-
- @abc.abstractmethod
- def register_cholesky_inverse(self, damping_func):
- pass
-
- @abc.abstractmethod
- def get_matpower(self, exp, damping_func):
- pass
-
- @abc.abstractmethod
- def get_cholesky(self, damping_func):
- pass
-
- @abc.abstractmethod
- def get_cholesky_inverse(self, damping_func):
- pass
-
-
-class DenseSquareMatrixFactor(FisherFactor):
- """Base class for FisherFactors that are stored as dense square matrices.
-
- This class explicitly calculates and stores inverses of their `cov` matrices,
- which must be square dense matrices.
-
- Subclasses must implement the _compute_new_cov method, and the _var_scope and
- _cov_shape properties.
- """
-
- # TODO(b/69108481): This class (and its subclasses) should be refactored to
- # serve the matrix quantities it computes as both (potentially stale)
- # variables, updated by the inverse update ops, and fresh values stored in
- # tensors that recomputed once every session.run() call. Currently matpower
- # and damp_inverse have the former behavior, while eigendecomposition has
- # the latter.
-
- def __init__(self):
- self._matpower_by_exp_and_damping = {} # { (float, hashable): variable }
- self._matpower_registrations = set() # { (float, hashable) }
- self._eigendecomp = None
- self._damping_funcs_by_id = {} # {hashable: lambda}
-
- self._cholesky_registrations = set() # { hashable }
- self._cholesky_inverse_registrations = set() # { hashable }
-
- self._cholesky_by_damping = {} # { hashable: variable }
- self._cholesky_inverse_by_damping = {} # { hashable: variable }
-
- super(DenseSquareMatrixFactor, self).__init__()
-
- def get_cov_as_linear_operator(self):
- assert self.get_cov().shape.ndims == 2
- return lo.LinearOperatorFullMatrix(self.get_cov(),
- is_self_adjoint=True,
- is_square=True)
-
- def _register_damping(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- if damping_id not in self._damping_funcs_by_id:
- self._damping_funcs_by_id[damping_id] = damping_func
- return damping_id
-
- def register_inverse(self, damping_func):
- # Just for backwards compatibility of some old code and tests
- self.register_matpower(-1, damping_func)
-
- def register_matpower(self, exp, damping_func):
- """Registers a matrix power to be maintained and served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_matpower.
-
- Args:
- exp: float. The exponent to use in the matrix power.
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- if exp == 1.0:
- return
-
- damping_id = self._register_damping(damping_func)
-
- if (exp, damping_id) not in self._matpower_registrations:
- self._matpower_registrations.add((exp, damping_id))
-
- def register_cholesky(self, damping_func):
- """Registers a Cholesky factor to be maintained and served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_cholesky.
-
- Args:
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- damping_id = self._register_damping(damping_func)
-
- if damping_id not in self._cholesky_registrations:
- self._cholesky_registrations.add(damping_id)
-
- def register_cholesky_inverse(self, damping_func):
- """Registers an inverse Cholesky factor to be maintained/served on demand.
-
- This creates a variable and signals make_inverse_update_ops to make the
- corresponding update op. The variable can be read via the method
- get_cholesky_inverse.
-
- Args:
- damping_func: A function that computes a 0-D Tensor or a float which will
- be the damping value used. i.e. damping = damping_func().
- """
- damping_id = self._register_damping(damping_func)
-
- if damping_id not in self._cholesky_inverse_registrations:
- self._cholesky_inverse_registrations.add(damping_id)
-
- def instantiate_inv_variables(self):
- """Makes the internal "inverse" variable(s)."""
-
- for (exp, damping_id) in self._matpower_registrations:
- exp_string = scalar_or_tensor_to_string(exp)
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- matpower = variable_scope.get_variable(
- "matpower_exp{}_damp{}".format(exp_string, damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert (exp, damping_id) not in self._matpower_by_exp_and_damping
- self._matpower_by_exp_and_damping[(exp, damping_id)] = matpower
-
- for damping_id in self._cholesky_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- chol = variable_scope.get_variable(
- "cholesky_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert damping_id not in self._cholesky_by_damping
- self._cholesky_by_damping[damping_id] = chol
-
- for damping_id in self._cholesky_inverse_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- with variable_scope.variable_scope(self._var_scope):
- cholinv = variable_scope.get_variable(
- "cholesky_inverse_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- assert damping_id not in self._cholesky_inverse_by_damping
- self._cholesky_inverse_by_damping[damping_id] = cholinv
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- ops = []
-
- num_inverses = sum(1 for (exp, _) in self._matpower_by_exp_and_damping
- if exp == -1)
-
- num_other_matpower = len(self._matpower_by_exp_and_damping) - num_inverses
-
- other_matrix_power_registered = num_other_matpower >= 1
-
- use_eig = (
- self._eigendecomp or other_matrix_power_registered or
- num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD)
-
- # We precompute these so we don't need to evaluate them multiple times (for
- # each matrix power that uses them)
- damping_value_by_id = {damping_id: math_ops.cast(
- self._damping_funcs_by_id[damping_id](), self._dtype)
- for damping_id in self._damping_funcs_by_id}
-
- if use_eig:
- eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence
-
- for (exp, damping_id), matpower in (
- self._matpower_by_exp_and_damping.items()):
- damping = damping_value_by_id[damping_id]
- ops.append(
- matpower.assign(
- math_ops.matmul(eigenvectors *
- (eigenvalues + damping)**exp,
- array_ops.transpose(eigenvectors))))
- # These ops share computation and should be run on a single device.
- ops = [control_flow_ops.group(*ops)]
- else:
- for (exp, damping_id), matpower in (
- self._matpower_by_exp_and_damping.items()):
- assert exp == -1
- damping = damping_value_by_id[damping_id]
- ops.append(matpower.assign(utils.posdef_inv(self.get_cov(), damping)))
-
- # TODO(b/77902055): If inverses are being computed with Cholesky's
- # we can share the work. Instead this code currently just computes the
- # Cholesky a second time. It does at least share work between requests for
- # Cholesky's and Cholesky inverses with the same damping id.
- for damping_id, cholesky_inv in self._cholesky_inverse_by_damping.items():
- cholesky_ops = []
-
- damping = damping_value_by_id[damping_id]
- cholesky_value = utils.cholesky(self.get_cov(), damping)
-
- if damping_id in self._cholesky_by_damping:
- cholesky = self._cholesky_by_damping[damping_id]
- cholesky_ops.append(cholesky.assign(cholesky_value))
-
- identity = linalg_ops.eye(cholesky_value.shape.as_list()[0],
- dtype=cholesky_value.dtype)
- cholesky_inv_value = linalg_ops.matrix_triangular_solve(cholesky_value,
- identity)
- cholesky_ops.append(cholesky_inv.assign(cholesky_inv_value))
-
- ops.append(control_flow_ops.group(*cholesky_ops))
-
- for damping_id, cholesky in self._cholesky_by_damping.items():
- if damping_id not in self._cholesky_inverse_by_damping:
- damping = damping_value_by_id[damping_id]
- cholesky_value = utils.cholesky(self.get_cov(), damping)
- ops.append(cholesky.assign(cholesky_value))
-
- self._eigendecomp = False
- return ops
-
- def get_inverse(self, damping_func):
- # Just for backwards compatibility of some old code and tests
- return self.get_matpower(-1, damping_func)
-
- def get_matpower(self, exp, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- if exp != 1:
- damping_id = graph_func_to_id(damping_func)
- matpower = self._matpower_by_exp_and_damping[(exp, damping_id)]
- else:
- matpower = self.get_cov()
- identity = linalg_ops.eye(matpower.shape.as_list()[0],
- dtype=matpower.dtype)
- matpower += math_ops.cast(damping_func(), dtype=matpower.dtype)*identity
-
- assert matpower.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(matpower,
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
-
- def get_cholesky(self, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- damping_id = graph_func_to_id(damping_func)
- cholesky = self._cholesky_by_damping[damping_id]
- assert cholesky.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(cholesky,
- is_non_singular=True,
- is_square=True)
-
- def get_cholesky_inverse(self, damping_func):
- # Note that this function returns a variable which gets updated by the
- # inverse ops. It may be stale / inconsistent with the latest value of
- # get_cov().
- damping_id = graph_func_to_id(damping_func)
- cholesky_inv = self._cholesky_inverse_by_damping[damping_id]
- assert cholesky_inv.shape.ndims == 2
- return lo.LinearOperatorFullMatrix(cholesky_inv,
- is_non_singular=True,
- is_square=True)
-
- def get_eigendecomp(self):
- """Creates or retrieves eigendecomposition of self._cov."""
- # Unlike get_matpower this doesn't retrieve a stored variable, but instead
- # always computes a fresh version from the current value of get_cov().
- if not self._eigendecomp:
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self.get_cov())
-
- # The matrix self._cov is positive semidefinite by construction, but the
- # numerical eigenvalues could be negative due to numerical errors, so here
- # we clip them to be at least FLAGS.eigenvalue_clipping_threshold
- clipped_eigenvalues = math_ops.maximum(eigenvalues,
- EIGENVALUE_CLIPPING_THRESHOLD)
- self._eigendecomp = (clipped_eigenvalues, eigenvectors)
-
- return self._eigendecomp
-
-
-class FullFactor(DenseSquareMatrixFactor):
- """FisherFactor for a full matrix representation of the Fisher of a parameter.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self,
- params_grads,
- batch_size):
- self._batch_size = batch_size
- self._params_grads = tuple(utils.ensure_sequence(params_grad)
- for params_grad in params_grads)
- super(FullFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_full_" + scope_string_from_params(
- [self._params_grads, self._batch_size])
-
- @property
- def _cov_shape(self):
- size = sum(param_grad.shape.num_elements()
- for param_grad in self._params_grads[0])
- return (size, size)
-
- @property
- def _num_sources(self):
- return len(self._params_grads)
-
- @property
- def _num_towers(self):
- return 1
-
- @property
- def _dtype(self):
- return self._params_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert tower == 0
-
- # This will be a very basic rank 1 estimate
- params_grads_flat = utils.tensors_to_column(self._params_grads[source])
- return ((params_grads_flat * array_ops.transpose(
- params_grads_flat)) / math_ops.cast(self._batch_size,
- params_grads_flat.dtype))
-
- def _get_data_device(self, tower):
- return None
-
-
-class DiagonalFactor(FisherFactor):
- """A base class for FisherFactors that use diagonal approximations.
-
- A DiagonalFactor's covariance variable can be of any shape, but must contain
- exactly one entry per parameter.
- """
-
- def __init__(self):
- super(DiagonalFactor, self).__init__()
-
- def get_cov_as_linear_operator(self):
- assert self._matrix_diagonal.shape.ndims == 1
- return lo.LinearOperatorDiag(self._matrix_diagonal,
- is_self_adjoint=True,
- is_square=True)
-
- @property
- def _cov_initializer(self):
- return diagonal_covariance_initializer
-
- @property
- def _matrix_diagonal(self):
- return array_ops.reshape(self.get_cov(), [-1])
-
- def make_inverse_update_ops(self):
- return []
-
- def instantiate_inv_variables(self):
- pass
-
- def register_matpower(self, exp, damping_func):
- pass
-
- def register_cholesky(self, damping_func):
- pass
-
- def register_cholesky_inverse(self, damping_func):
- pass
-
- def get_matpower(self, exp, damping_func):
- matpower_diagonal = (self._matrix_diagonal
- + math_ops.cast(damping_func(), self._dtype))**exp
- return lo.LinearOperatorDiag(matpower_diagonal,
- is_non_singular=True,
- is_self_adjoint=True,
- is_positive_definite=True,
- is_square=True)
-
- def get_cholesky(self, damping_func):
- return self.get_matpower(0.5, damping_func)
-
- def get_cholesky_inverse(self, damping_func):
- return self.get_matpower(-0.5, damping_func)
-
-
-class NaiveDiagonalFactor(DiagonalFactor):
- """FisherFactor for a diagonal approximation of any type of param's Fisher.
-
- Note that this uses the naive "square the sum estimator", and so is applicable
- to any type of parameter in principle, but has very high variance.
- """
-
- def __init__(self,
- params_grads,
- batch_size):
- """Initializes NaiveDiagonalFactor instance.
-
- Args:
- params_grads: Sequence of Tensors, each with same shape as parameters this
- FisherFactor corresponds to. For example, the gradient of the loss with
- respect to parameters.
- batch_size: int or 0-D Tensor. Size
- """
- self._params_grads = tuple(utils.ensure_sequence(params_grad)
- for params_grad in params_grads)
- self._batch_size = batch_size
- super(NaiveDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_naivediag_" + scope_string_from_params(
- [self._params_grads, self._batch_size])
-
- @property
- def _cov_shape(self):
- size = sum(param_grad.shape.num_elements()
- for param_grad in self._params_grads[0])
- return [size, 1]
-
- @property
- def _num_sources(self):
- return len(self._params_grads)
-
- @property
- def _num_towers(self):
- return 1
-
- @property
- def _dtype(self):
- return self._params_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert tower == 0
-
- params_grads_flat = utils.tensors_to_column(self._params_grads[source])
- return (math_ops.square(params_grads_flat) / math_ops.cast(
- self._batch_size, params_grads_flat.dtype))
-
- def _get_data_device(self, tower):
- return None
-
-
-class EmbeddingInputKroneckerFactor(DiagonalFactor):
- r"""FisherFactor for input to an embedding layer.
-
- Given input_ids = [batch_size, input_size] representing indices into an
- [vocab_size, embedding_size] embedding matrix, approximate input covariance by
- a diagonal matrix,
-
- Cov(input_ids, input_ids) =
- (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2).
-
- where n_hot() constructs an n-hot binary vector and diag() constructs a
- diagonal matrix of size [vocab_size, vocab_size].
- """
-
- def __init__(self, input_ids, vocab_size, dtype=None):
- """Instantiate EmbeddingInputKroneckerFactor.
-
- Args:
- input_ids: List of Tensors of shape [batch_size, input_size] and dtype
- int32. Indices into embedding matrix. List index is tower.
- vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'.
- dtype: dtype for covariance statistics. Must be a floating point type.
- Defaults to float32.
- """
- self._input_ids = input_ids
- self._vocab_size = vocab_size
- self._cov_dtype = dtype or dtypes.float32
-
- super(EmbeddingInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_diag_embedding_" + scope_string_from_params(self._input_ids)
-
- @property
- def _cov_shape(self):
- return [self._vocab_size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._input_ids)
-
- @property
- def _dtype(self):
- return self._cov_dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- input_ids = self._input_ids[tower]
-
- if len(input_ids.shape) > 2:
- raise ValueError(
- "Input to embeddings must have rank <= 2. Found rank %d." % len(
- input_ids.shape))
-
- batch_size = array_ops.shape(input_ids)[0]
-
- # Transform indices into one-hot vectors.
- #
- # TODO(b/72714822): There must be a faster way to construct the diagonal
- # covariance matrix! This operation is O(batch_size * vocab_size), where
- # it should be O(batch_size * input_size).
- flat_input_ids = array_ops.reshape(input_ids, [-1])
- one_hots = array_ops.one_hot(flat_input_ids,
- self._vocab_size) # [?, vocab_size]
-
- # Take average across examples. Note that, because all entries have
- # magnitude zero or one, there's no need to square the entries.
- #
- # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation
- # within an example such as average.
- #
- # TODO(b/72714822): Support for partitioned embeddings.
- new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size]
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _get_data_device(self, tower):
- return self._input_ids[tower].device
-
-
-class FullyConnectedDiagonalFactor(DiagonalFactor):
- r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher.
-
- Given in = [batch_size, input_size] and out_grad = [batch_size, output_size],
- approximates the covariance as,
-
- Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0
-
- where the square is taken element-wise.
- """
-
- def __init__(self,
- inputs,
- outputs_grads,
- has_bias=False):
- """Instantiate FullyConnectedDiagonalFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, input_size]. Inputs to this
- layer. List index is towers.
- outputs_grads: List of Tensors, each of shape [batch_size, output_size],
- which are the gradients of the loss with respect to the layer's
- outputs. First index is source, second is tower.
-
- has_bias: bool. If True, append '1' to each input.
- """
- self._inputs = inputs
- self._has_bias = has_bias
- self._outputs_grads = outputs_grads
- self._squared_inputs = None
-
- super(FullyConnectedDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_diagfc_" + scope_string_from_params(
- tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
- @property
- def _cov_shape(self):
- input_size = self._inputs[0].shape[1] + self._has_bias
- output_size = self._outputs_grads[0][0].shape[1]
- return [input_size, output_size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def make_covariance_update_op(self, ema_decay):
-
- self._squared_inputs = []
- for tower in range(self._num_towers):
- inputs = self._inputs[tower]
-
- with place_on_device(self._get_data_device(tower)):
- if self._has_bias:
- inputs = append_homog(inputs)
- self._squared_inputs.append(math_ops.square(inputs))
-
- return super(FullyConnectedDiagonalFactor, self).make_covariance_update_op(
- ema_decay)
-
- def _compute_new_cov(self, source, tower):
- batch_size = array_ops.shape(self._squared_inputs[tower])[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- # The well-known special formula that uses the fact that the entry-wise
- # square of an outer product is the outer-product of the entry-wise squares.
- # The gradient is the outer product of the input and the output gradients,
- # so we just square both and then take their outer-product.
- new_cov = math_ops.matmul(
- self._squared_inputs[tower],
- math_ops.square(outputs_grad),
- transpose_a=True)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
- return new_cov
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class ConvDiagonalFactor(DiagonalFactor):
- """FisherFactor for a diagonal approx of a convolutional layer's Fisher."""
-
- def __init__(self,
- inputs,
- outputs_grads,
- filter_shape,
- strides,
- padding,
- data_format=None,
- dilations=None,
- has_bias=False):
- """Creates a ConvDiagonalFactor object.
-
- Args:
- inputs: List of Tensors of shape [batch_size, height, width, in_channels].
- Input activations to this layer. List index is towers.
- outputs_grads: List of Tensors, each of shape [batch_size,
- height, width, out_channels], which are the gradients of the loss
- with respect to the layer's outputs. First index is source, second
- index is tower.
- filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels,
- out_channels). Represents shape of kernel used in this layer.
- strides: The stride size in this layer (1-D Tensor of length 4).
- padding: The padding in this layer (1-D of Tensor length 4).
- data_format: None or str. Format of conv2d inputs.
- dilations: None or tuple of 4 ints.
- has_bias: Python bool. If True, the layer is assumed to have a bias
- parameter in addition to its filter parameter.
-
- Raises:
- ValueError: If inputs, output_grads, and filter_shape do not agree on
- in_channels or out_channels.
- ValueError: If strides, dilations are not length-4 lists of ints.
- ValueError: If data_format does not put channel last.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- if any(input_.shape.ndims != 4 for input_ in inputs):
- raise ValueError("inputs must be a list of 4-D Tensors.")
- if any(input_.shape.as_list()[-1] != filter_shape[-2] for input_ in inputs):
- raise ValueError("inputs and filter_shape must agree on in_channels.")
- for i, outputs_grad in enumerate(outputs_grads):
- if any(output_grad.shape.ndims != 4 for output_grad in outputs_grad):
- raise ValueError("outputs[%d] must be 4-D Tensor." % i)
- if any(output_grad.shape.as_list()[-1] != filter_shape[-1]
- for output_grad in outputs_grad):
- raise ValueError(
- "outputs[%d] and filter_shape must agree on out_channels." % i)
- if len(strides) != 4:
- raise ValueError("strides must be length-4 list of ints.")
- if dilations is not None and len(dilations) != 4:
- raise ValueError("dilations must be length-4 list of ints.")
-
- self._inputs = inputs
- self._outputs_grads = outputs_grads
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._data_format = data_format
- self._dilations = dilations
- self._has_bias = has_bias
- self._patches = None
-
- super(ConvDiagonalFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convdiag_" + scope_string_from_params(
- tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads)))
-
- @property
- def _cov_shape(self):
- filter_height, filter_width, in_channels, out_channels = self._filter_shape
- return [
- filter_height * filter_width * in_channels + self._has_bias,
- out_channels
- ]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def make_covariance_update_op(self, ema_decay):
- filter_height, filter_width, _, _ = self._filter_shape
-
- # TODO(b/64144716): there is potential here for a big savings in terms
- # of memory use.
- if self._dilations is None:
- rates = (1, 1, 1, 1)
- else:
- rates = tuple(self._dilations)
-
- self._patches = []
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- patches = array_ops.extract_image_patches(
- self._inputs[tower],
- ksizes=[1, filter_height, filter_width, 1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- if self._has_bias:
- patches = append_homog(patches)
-
- self._patches.append(patches)
-
- return super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay)
-
- def _compute_new_cov(self, source, tower):
- patches = self._patches[tower]
- batch_size = array_ops.shape(patches)[0]
- outputs_grad = self._outputs_grads[source][tower]
-
- new_cov = self._convdiag_sum_of_squares(patches, outputs_grad)
- new_cov /= math_ops.cast(batch_size, new_cov.dtype)
-
- return new_cov
-
- def _convdiag_sum_of_squares(self, patches, outputs_grad):
- # This computes the sum of the squares of the per-training-case "gradients".
- # It does this simply by computing a giant tensor containing all of these,
- # doing an entry-wise square, and them summing along the batch dimension.
- case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches,
- outputs_grad)
- return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class FullyConnectedKroneckerFactor(DenseSquareMatrixFactor):
- """Kronecker factor for the input or output side of a fully-connected layer.
- """
-
- def __init__(self,
- tensors,
- has_bias=False):
- """Instantiate FullyConnectedKroneckerFactor.
-
- Args:
- tensors: List of list of Tensors, each of shape [batch_size, n]. The
- Tensors are typically either a layer's inputs or its output's gradients.
- The first list index is source, the second is tower.
- has_bias: bool. If True, append '1' to each row.
- """
- # The tensor argument is either a tensor of input activations or a tensor of
- # output pre-activation gradients.
- self._has_bias = has_bias
- self._tensors = tensors
- super(FullyConnectedKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_fckron_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors)) + (self._has_bias,))
-
- @property
- def _cov_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._tensors)
-
- @property
- def _num_towers(self):
- return len(self._tensors[0])
-
- @property
- def _dtype(self):
- return self._tensors[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- tensor = self._tensors[source][tower]
- if self._has_bias:
- tensor = append_homog(tensor)
- return compute_cov(tensor)
-
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
-
-class ConvInputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the input side of a convolutional layer.
-
- Estimates E[ a a^T ] where a is the inputs to a convolutional layer given
- example x. Expectation is taken over all examples and locations.
-
- Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See
- Section 3.1 Estimating the factors.
- """
-
- def __init__(self,
- inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- extract_patches_fn=None,
- has_bias=False,
- sub_sample_inputs=None,
- sub_sample_patches=None):
- """Initializes ConvInputKroneckerFactor.
-
- Args:
- inputs: List of Tensors of shape [batch_size, ..spatial_input_size..,
- in_channels]. Inputs to layer. List index is tower.
- filter_shape: List of ints. Contains [..spatial_filter_size..,
- in_channels, out_channels]. Shape of convolution kernel.
- padding: str. Padding method for layer. "SAME" or "VALID".
- strides: List of ints or None. Contains [..spatial_filter_strides..] if
- 'extract_patches_fn' is compatible with tf.nn.convolution(), else
- [1, ..spatial_filter_strides, 1].
- dilation_rate: List of ints or None. Rate for dilation along each spatial
- dimension if 'extract_patches_fn' is compatible with
- tf.nn.convolution(), else [1, ..spatial_dilation_rates.., 1].
- data_format: str or None. Format of input data.
- extract_patches_fn: str or None. Name of function that extracts image
- patches. One of "extract_convolution_patches", "extract_image_patches",
- "extract_pointwise_conv2d_patches".
- has_bias: bool. If True, append 1 to in_channel.
- sub_sample_inputs: `bool`. If True, then subsample the inputs from which
- the image patches are extracted. (Default: None)
- sub_sample_patches: `bool`, If `True` then subsample the extracted
- patches.(Default: None)
- """
- self._inputs = inputs
- self._filter_shape = filter_shape
- self._strides = strides
- self._padding = padding
- self._dilation_rate = dilation_rate
- self._data_format = data_format
- self._extract_patches_fn = extract_patches_fn
- self._has_bias = has_bias
- if sub_sample_inputs is None:
- self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
- else:
- self._sub_sample_inputs = sub_sample_inputs
-
- if sub_sample_patches is None:
- self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
- else:
- self._sub_sample_patches = sub_sample_patches
- super(ConvInputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convinkron_" + scope_string_from_params(
- tuple(self._inputs) +
- tuple((self._filter_shape, self._strides, self._padding,
- self._dilation_rate, self._data_format, self._has_bias)))
-
- @property
- def _cov_shape(self):
- spatial_filter_shape = self._filter_shape[0:-2]
- in_channels = self._filter_shape[-2]
- size = np.prod(spatial_filter_shape) * in_channels + self._has_bias
- return [size, size]
-
- @property
- def _num_sources(self):
- return 1
-
- @property
- def _num_towers(self):
- return len(self._inputs)
-
- @property
- def _dtype(self):
- return self._inputs[0].dtype
-
- def _compute_new_cov(self, source, tower):
- assert source == 0
-
- inputs = self._inputs[tower]
- if self._sub_sample_inputs:
- batch_size = inputs.shape.as_list()[0]
- max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
- inputs = _random_tensor_gather(inputs, max_size)
-
- # TODO(b/64144716): there is potential here for a big savings in terms of
- # memory use.
- if self._extract_patches_fn in [None, "extract_convolution_patches"]:
- patches = utils.extract_convolution_patches(
- inputs,
- self._filter_shape,
- padding=self._padding,
- strides=self._strides,
- dilation_rate=self._dilation_rate,
- data_format=self._data_format)
-
- elif self._extract_patches_fn == "extract_image_patches":
- assert inputs.shape.ndims == 4
- assert len(self._filter_shape) == 4
- assert len(self._strides) == 4, self._strides
- if self._dilation_rate is None:
- rates = [1, 1, 1, 1]
- else:
- rates = self._dilation_rate
- assert len(rates) == 4
- assert rates[0] == rates[-1] == 1
- patches = array_ops.extract_image_patches(
- inputs,
- ksizes=[1] + list(self._filter_shape[0:-2]) + [1],
- strides=self._strides,
- rates=rates,
- padding=self._padding)
-
- elif self._extract_patches_fn == "extract_pointwise_conv2d_patches":
- assert self._strides in [None, [1, 1, 1, 1], (1, 1, 1, 1)]
- assert self._filter_shape[0] == self._filter_shape[1] == 1
- patches = utils.extract_pointwise_conv2d_patches(
- inputs, self._filter_shape, data_format=None)
-
- else:
- raise NotImplementedError(self._extract_patches_fn)
-
- flatten_size = np.prod(self._filter_shape[0:-1])
- # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde
- # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14),
- # where M = minibatch size, |T| = number of spatial locations,
- # |Delta| = number of spatial offsets, and J = number of input maps
- # for convolutional layer l.
- patches_flat = array_ops.reshape(patches, [-1, flatten_size])
-
- # We append a homogenous coordinate to patches_flat if the layer has
- # bias parameters. This gives us [[A_l]]_H from the paper.
- if self._sub_sample_patches:
- patches_flat = _subsample_for_cov_computation(patches_flat)
-
- if self._has_bias:
- patches_flat = append_homog(patches_flat)
- # We call compute_cov without passing in a normalizer. compute_cov uses
- # the first dimension of patches_flat i.e. M|T| as the normalizer by
- # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with
- # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from
- # the paper but has a different scale here for consistency with
- # ConvOutputKroneckerFactor.
- # (Tilde omitted over A for clarity.)
- return compute_cov(patches_flat)
-
- def _get_data_device(self, tower):
- return self._inputs[tower].device
-
-
-class ConvOutputKroneckerFactor(DenseSquareMatrixFactor):
- r"""Kronecker factor for the output side of a convolutional layer.
-
- Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer
- given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over
- all examples and locations.
-
- Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See
- Section 3.1 Estimating the factors.
- """
-
- def __init__(self, outputs_grads, data_format=None):
- """Initializes ConvOutputKroneckerFactor.
-
- Args:
- outputs_grads: List of list of Tensors. Each Tensor is of shape
- [batch_size, ..spatial_input_size.., out_channels]. First list index
- is source, the second is tower.
- data_format: None or str. Format of outputs_grads.
-
- Raises:
- ValueError: If channels are not final dimension.
- """
- if not utils.is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last.")
- self._out_channels = outputs_grads[0][0].shape.as_list()[-1]
- self._outputs_grads = outputs_grads
- super(ConvOutputKroneckerFactor, self).__init__()
-
- @property
- def _var_scope(self):
- return "ff_convoutkron_" + scope_string_from_params(
- nest.flatten(self._outputs_grads))
-
- @property
- def _cov_shape(self):
- size = self._out_channels
- return [size, size]
-
- @property
- def _num_sources(self):
- return len(self._outputs_grads)
-
- @property
- def _num_towers(self):
- return len(self._outputs_grads[0])
-
- @property
- def _dtype(self):
- return self._outputs_grads[0][0].dtype
-
- def _compute_new_cov(self, source, tower):
- outputs_grad = self._outputs_grads[source][tower]
-
- # reshaped_tensor below is the matrix DS_l defined in the KFC paper
- # (tilde omitted over S for clarity). It has shape M|T| x I, where
- # M = minibatch size, |T| = number of spatial locations, and
- # I = number of output maps for convolutional layer l.
- reshaped_tensor = array_ops.reshape(outputs_grad, [-1, self._out_channels])
- # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov,
- # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l
- # as defined in the paper, with shape I x I.
- # (Tilde omitted over S for clarity.)
- return compute_cov(reshaped_tensor)
-
- def _get_data_device(self, tower):
- return self._outputs_grads[0][tower].device
-
-
-class FullyConnectedMultiKF(FullyConnectedKroneckerFactor):
- """Kronecker factor for a fully connected layer used multiple times."""
-
- def __init__(self,
- tensors,
- num_uses=None,
- has_bias=False):
- """Constructs a new `FullyConnectedMultiKF`.
-
- Args:
- tensors: List of list of Tensors of shape, each of shape
- [num_uses * batch_size, n], and is a reshape version of a Tensor of
- shape [num_uses, batch_size, n]. Each of these tensors is usually a
- layer's inputs or its output's gradients. The first list index is
- sources, the second is towers.
- num_uses: int. The number of time-steps / uses.
- has_bias: bool. If True, '1' is appended to each row.
- """
-
- self._num_uses = num_uses
-
- self._cov_dt1 = None
- self._make_cov_dt1 = False
- self._option1quants_by_damping = {}
- self._option2quants_by_damping = {}
- self._option1quants_registrations = set()
- self._option2quants_registrations = set()
-
- super(FullyConnectedMultiKF, self).__init__(tensors=tensors,
- has_bias=has_bias)
-
- @property
- def _num_timesteps(self):
- return self._num_uses
-
- @property
- def _var_scope(self):
- return "ff_fc_multi_" + scope_string_from_params(
- tuple(nest.flatten(self._tensors))
- + (self._num_timesteps, self._has_bias,))
-
- def make_covariance_update_op(self, ema_decay):
-
- op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay)
-
- if self._cov_dt1 is not None:
- new_cov_dt1_contribs = []
- for source in range(self._num_sources):
- for tower in range(self._num_towers):
- with place_on_device(self._get_data_device(tower)):
- new_cov_dt1_contribs.append(self._compute_new_cov_dt1(source,
- tower))
-
- new_cov_dt1 = (math_ops.add_n(new_cov_dt1_contribs)
- / float(self._num_towers))
-
- # See comments in FisherFactor.make_covariance_update_op() for details.
- if utils.on_tpu():
- new_cov_dt1 = utils.cross_replica_mean(new_cov_dt1)
-
- op2 = moving_averages.assign_moving_average(
- self._cov_dt1, new_cov_dt1, ema_decay, zero_debias=ZERO_DEBIAS)
-
- # TODO(b/69112164):
- # It's important that _cov and _cov_dt1 remain consistent with each
- # other while the inverse ops are happening. How can we ensure this?
- # We will need to add explicit synchronization for this to
- # work with asynchronous training.
- op = control_flow_ops.group(op, op2)
-
- return op
-
- def _compute_new_cov_dt1(self, source, tower): # pylint: disable=missing-docstring
- tensor = self._tensors[source][tower]
- if self._has_bias:
- # This appending is technically done twice (the other time is for
- # _compute_new_cov())
- tensor = append_homog(tensor)
-
- total_len = array_ops.shape(tensor)[0]
- batch_size = total_len // self._num_timesteps
-
- tensor_present = tensor[:-batch_size, :]
- tensor_future = tensor[batch_size:, :]
-
- # We specify a normalizer for this computation to ensure a PSD Fisher
- # block estimate. This is equivalent to padding with zeros, as was done
- # in Section B.2 of the appendix.
- return compute_cov(
- tensor_future, tensor_right=tensor_present, normalizer=total_len)
-
- def _get_data_device(self, tower):
- return self._tensors[0][tower].device
-
- @property
- def _vec_shape(self):
- size = self._tensors[0][0].shape[1] + self._has_bias
- return [size]
-
- def get_option1quants(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- return self._option1quants_by_damping[damping_id]
-
- def get_option2quants(self, damping_func):
- damping_id = graph_func_to_id(damping_func)
- return self._option2quants_by_damping[damping_id]
-
- def get_cov_dt1(self):
- assert self._cov_dt1 is not None
- return self._cov_dt1
-
- def register_cov_dt1(self):
- self._make_cov_dt1 = True
-
- def instantiate_cov_variables(self):
- super(FullyConnectedMultiKF, self).instantiate_cov_variables()
- assert self._cov_dt1 is None
- if self._make_cov_dt1:
- with variable_scope.variable_scope(self._var_scope):
- self._cov_dt1 = variable_scope.get_variable(
- "cov_dt1",
- initializer=init_ops.zeros_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
-
- def register_option1quants(self, damping_func):
- damping_id = self._register_damping(damping_func)
- if damping_id not in self._option1quants_registrations:
- self._option1quants_registrations.add(damping_id)
-
- def register_option2quants(self, damping_func):
- damping_id = self._register_damping(damping_func)
- if damping_id not in self._option2quants_registrations:
- self._option2quants_registrations.add(damping_id)
-
- def instantiate_inv_variables(self):
- super(FullyConnectedMultiKF, self).instantiate_inv_variables()
-
- for damping_id in self._option1quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- # It's questionable as to whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Lmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- psi = variable_scope.get_variable(
- "psi_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option1quants_by_damping
- self._option1quants_by_damping[damping_id] = (Lmat, psi)
-
- for damping_id in self._option2quants_registrations:
- damping_func = self._damping_funcs_by_id[damping_id]
- damping_string = graph_func_to_string(damping_func)
- # It's questionable as to whether we should initialize with stuff like
- # this at all. Ideally these values should never be used until they are
- # updated at least once.
- with variable_scope.variable_scope(self._var_scope):
- Pmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Lmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- Kmat = variable_scope.get_variable( # pylint: disable=invalid-name
- "Kmat_damp{}".format(damping_string),
- initializer=inverse_initializer,
- shape=self._cov_shape,
- trainable=False,
- dtype=self._dtype)
- mu = variable_scope.get_variable(
- "mu_damp{}".format(damping_string),
- initializer=init_ops.ones_initializer,
- shape=self._vec_shape,
- trainable=False,
- dtype=self._dtype)
-
- assert damping_id not in self._option2quants_by_damping
- self._option2quants_by_damping[damping_id] = (Pmat, Kmat, mu)
-
- def make_inverse_update_ops(self):
- """Create and return update ops corresponding to registered computations."""
- # TODO(b/69918258): Add correctness tests for this method.
- # pylint: disable=invalid-name
-
- ops = []
-
- if (len(self._option1quants_by_damping) +
- len(self._option2quants_by_damping)):
-
- # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from
- # the pseudo-code in the original paper. Because the computations for
- # the A and G case are essentially the same they can both be performed by
- # the same class (this one).
-
- C1 = self.get_cov_dt1()
-
- # Get the eigendecomposition of C0 (= self.get_cov())
- eigen_e, eigen_V = self.get_eigendecomp()
-
- # TODO(b/69678661): Note, there is an implicit assumption here that C1
- # and C0 (as represented here by its eigen-decomp) are consistent. This
- # could fail to be the case if self._cov and self._cov_dt1 are not updated
- # consistently, or are somehow read between or during the cov updates.
- # Can this possibly happen? Is there a way to prevent it?
-
- for damping_id, (Lmat_var,
- psi_var) in self._option1quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # The following line imposes the symmetry assumed by "Option 1" on C1.
- # Strangely the code can work okay with this line commented out,
- # depending on how psd_eig is defined. I'm not sure why.
- C1 = (C1 + array_ops.transpose(C1)) / 2.0
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, C1), invsqrtC0)
-
- # Compute the decomposition U*diag(psi)*U^T = hPsi
- psi, U = utils.posdef_eig(hPsi)
-
- # L = C0^(-1/2) * U
- Lmat = math_ops.matmul(invsqrtC0, U)
-
- ops.append(Lmat_var.assign(Lmat))
- ops.append(psi_var.assign(psi))
-
- for damping_id, (Pmat_var, Kmat_var,
- mu_var) in self._option2quants_by_damping.items():
-
- damping = self._damping_funcs_by_id[damping_id]()
- damping = math_ops.cast(damping, self._dtype)
-
- # compute C0^(-1/2)
- invsqrtC0 = math_ops.matmul(
- eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True)
-
- # Might need to enforce symmetry lost due to numerical issues.
- invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0
-
- # Compute the product C0^(-1/2) * C1
- invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1)
-
- # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi})
- hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0)
-
- # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi
- # Note that we using the notation mu instead of "m" for the eigenvalues.
- # Instead of computing the product hPsi^T * hPsi and then doing an
- # eigen-decomposition of this we just compute the SVD of hPsi and then
- # square the singular values to get the eigenvalues. For a justification
- # of this approach, see:
- # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition
- sqrtmu, _, E = linalg_ops.svd(hPsi)
- mu = math_ops.square(sqrtmu)
-
- # Mathematically, the eigenvalues should not should not exceed 1.0, but
- # due to numerical issues, or possible issues with inconsistent
- # values of C1 and (the eigen-decomposition of) C0 they might. So
- # we enforce this condition.
- mu = math_ops.minimum(mu, 1.0)
-
- # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1)
- Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True)
-
- # K = C_0^(-1/2) * E
- Kmat = math_ops.matmul(invsqrtC0, E)
-
- ops.append(Pmat_var.assign(Pmat))
- ops.append(Kmat_var.assign(Kmat))
- ops.append(mu_var.assign(mu))
-
- ops += super(FullyConnectedMultiKF, self).make_inverse_update_ops()
- return [control_flow_ops.group(*ops)]
-
- # pylint: enable=invalid-name
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
deleted file mode 100644
index 2d8e378a93..0000000000
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""FisherFactor definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.fisher_factors import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "inverse_initializer", "covariance_initializer",
- "diagonal_covariance_initializer", "scope_string_from_params",
- "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor",
- "InverseProvidingFactor", "FullFactor", "DiagonalFactor",
- "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor",
- "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor",
- "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor",
- "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with",
- "compute_cov", "append_homog"
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py
deleted file mode 100644
index 43aa713edc..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection.py
+++ /dev/null
@@ -1,1269 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple K-FAC
-optimizers being used, for example.)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import defaultdict
-from collections import OrderedDict
-from contextlib import contextmanager
-from functools import partial
-import warnings
-
-import math
-import six
-
-from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb
-from tensorflow.contrib.kfac.python.ops import loss_functions as lf
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.util import nest
-
-# Names for various approximations that can be requested for Fisher blocks.
-APPROX_KRONECKER_NAME = "kron"
-APPROX_DIAGONAL_NAME = "diagonal"
-APPROX_FULL_NAME = "full"
-
-_GENERIC_APPROX_TO_BLOCK_TYPES = {
- APPROX_FULL_NAME: fb.FullFB,
- APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB,
-}
-
-_FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB,
- APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB,
-}
-
-_CONV2D_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB,
- APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB,
-}
-
-_EMBEDDING_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_NAME: fb.EmbeddingKFACFB
-}
-
-APPROX_KRONECKER_INDEP_NAME = "kron_indep"
-APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1"
-APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2"
-
-_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB,
- APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB,
- option=1),
- APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB,
- option=2)
-}
-
-_CONV2D_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.ConvKFCBasicMultiIndepFB
-}
-
-_EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES = {
- APPROX_KRONECKER_INDEP_NAME: fb.EmbeddingKFACMultiIndepFB
-}
-
-# Possible value for `reuse` keyword argument. Sets `reuse` to
-# tf.get_variable_scope().reuse.
-VARIABLE_SCOPE = "VARIABLE_SCOPE"
-
-_DEFAULT_LAYER_COLLECTION = None
-
-
-def get_default_layer_collection():
- """Get default LayerCollection."""
- if _DEFAULT_LAYER_COLLECTION is None:
- raise ValueError(
- "Attempted to retrieve default LayerCollection when none is set. Use "
- "LayerCollection.as_default().")
-
- return _DEFAULT_LAYER_COLLECTION
-
-
-def set_default_layer_collection(layer_collection):
- global _DEFAULT_LAYER_COLLECTION
-
- if _DEFAULT_LAYER_COLLECTION is not None and layer_collection is not None:
- raise ValueError("Default LayerCollection is already set.")
-
- _DEFAULT_LAYER_COLLECTION = layer_collection
-
-
-class LayerParametersDict(OrderedDict):
- """An OrderedDict where keys are Tensors or tuples of Tensors.
-
- Ensures that no Tensor is associated with two different keys.
- """
-
- def __init__(self, *args, **kwargs):
- self._tensors = set()
- super(LayerParametersDict, self).__init__(*args, **kwargs)
-
- def __setitem__(self, key, value):
- key = self._canonicalize_key(key)
- tensors = key if isinstance(key, (tuple, list)) else (key,)
- key_collisions = self._tensors.intersection(tensors)
- if key_collisions:
- raise ValueError("Key(s) already present: {}".format(key_collisions))
- self._tensors.update(tensors)
- super(LayerParametersDict, self).__setitem__(key, value)
-
- def __delitem__(self, key):
- key = self._canonicalize_key(key)
- self._tensors.remove(key)
- super(LayerParametersDict, self).__delitem__(key)
-
- def __getitem__(self, key):
- key = self._canonicalize_key(key)
- return super(LayerParametersDict, self).__getitem__(key)
-
- def __contains__(self, key):
- key = self._canonicalize_key(key)
- return super(LayerParametersDict, self).__contains__(key)
-
- def _canonicalize_key(self, key):
- if isinstance(key, (list, tuple)):
- return tuple(key)
- return key
-
-
-# TODO(b/68034464): add capability for LayerCollection to be "finalized"
-# and do this when it gets used by FisherEstimator / KfacOptimizer.
-
-
-class LayerCollection(object):
- """Registry of information about layers and losses.
-
- Note that you need to create a new one of these for each MatrixEstimator or
- KfacOptimizer.
-
- Attributes:
- fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer
- parameters (Tensors or tuples of Tensors) to FisherBlock instances.
- fisher_factors: an OrderedDict mapping tuples to FisherFactor instances.
- losses: a list of LossFunction objects. The loss to be optimized is their
- sum.
- loss_colocation_ops: ops to colocate loss function evaluations with. These
- will typically be the inputs to the losses.
- """
-
- def __init__(self,
- graph=None,
- name="LayerCollection"):
- warnings.warn(
- "tf.contrib.kfac is deprecated and will be removed by 2018-11-01. "
- "Use https://pypi.python.org/pypi/kfac instead.")
- self.fisher_blocks = LayerParametersDict()
- self.fisher_factors = OrderedDict()
- self._linked_parameters = dict(
- ) # dict mapping sets of variables to optionally specified approximations.
- self._graph = graph or ops.get_default_graph()
- self._loss_dict = {} # {str: LossFunction}
- self._subgraph = None
- self._default_generic_approximation = APPROX_DIAGONAL_NAME
- self._default_embedding_approximation = APPROX_KRONECKER_NAME
- self._default_fully_connected_approximation = APPROX_KRONECKER_NAME
- self._default_conv2d_approximation = APPROX_KRONECKER_NAME
- self._default_fully_connected_multi_approximation = (
- APPROX_KRONECKER_INDEP_NAME)
- self._default_conv2d_multi_approximation = (
- APPROX_KRONECKER_INDEP_NAME)
- self._default_embedding_multi_approximation = APPROX_KRONECKER_INDEP_NAME
- self.loss_colocation_ops = {}
- self._vars_to_uses = defaultdict(lambda: 0)
-
- with variable_scope.variable_scope(None, default_name=name) as scope:
- self._var_scope = scope.name
-
- @property
- def losses(self):
- """Tuple of LossFunction objects registered with this LayerCollection."""
- return nest.flatten(self.towers_by_loss)
-
- @property
- def towers_by_loss(self):
- """Tuple across losses of LossFunction objects registered to each tower."""
- return tuple(tuple(lst) for lst in self._loss_dict.values())
-
- @property
- def registered_variables(self):
- """A tuple of all of the variables currently registered."""
- tuple_of_tuples = (utils.ensure_sequence(key) for key, block
- in six.iteritems(self.fisher_blocks))
- flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_)
- return flat_tuple
-
- @property
- def linked_parameters(self):
- """Groups of parameters with an optionally specified approximation.
-
- Linked parameters can be added using `define_linked_parameters`.
- If an approximation is specified, then this approximation will be used
- when registering a layer with exactly these parameters, unless an
- approximation is specified when calling the registration function.
-
- Returns:
- A `dict` mapping tuples of parameters to an optional string.
- """
- return self._linked_parameters
-
- @property
- def default_embedding_approximation(self):
- return self._default_embedding_approximation
-
- def set_default_embedding_approximation(self, value):
- if value != APPROX_KRONECKER_NAME:
- raise ValueError(
- "{} is not a valid approximation for embedding variables.".format(
- value))
- self._default_embedding_approximation = value
-
- @property
- def default_generic_approximation(self):
- return self._default_generic_approximation
-
- def set_default_generic_approximation(self, value):
- if value not in _GENERIC_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for generic variables.".format(
- value))
- self._default_generic_approximation = value
-
- @property
- def default_fully_connected_approximation(self):
- return self._default_fully_connected_approximation
-
- def set_default_fully_connected_approximation(self, value):
- if value not in _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for fully connected layers.".format(
- value))
- self._default_fully_connected_approximation = value
-
- @property
- def default_conv2d_approximation(self):
- return self._default_conv2d_approximation
-
- def set_default_conv2d_approximation(self, value):
- if value not in _CONV2D_APPROX_TO_BLOCK_TYPES:
- raise ValueError(
- "{} is not a valid approximation for 2d convolutional layers.".format(
- value))
- self._default_conv2d_approximation = value
-
- @property
- def default_fully_connected_multi_approximation(self):
- return self._default_fully_connected_multi_approximation
-
- def set_default_fully_connected_multi_approximation(self, value):
- if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES:
- raise ValueError("{} is not a valid approximation for a fully-connected "
- "multi layer.".format(value))
- self._default_fully_connected_multi_approximation = value
-
- @property
- def default_conv2d_multi_approximation(self):
- return self._default_conv2d_multi_approximation
-
- @property
- def default_embedding_multi_approximation(self):
- return self._default_embedding_multi_approximation
-
- def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE):
- """Validates and registers the layer_key associated with the fisher_block.
-
- Args:
- layer_key: A variable or tuple of variables. The key to check for in
- existing registrations and to register if valid.
- fisher_block: The associated `FisherBlock`.
- reuse: Method to use for inserting new `FisherBlock's. One of True, False,
- or `VARIABLE_SCOPE`.
-
- Raises:
- ValueError: If `layer_key` was already registered and reuse is `False`,
- if `layer_key` was registered with a different block type, or if
- `layer_key` shares any variables with but is not equal to a previously
- registered key.
- KeyError: If `reuse` is `True` but `layer_key` was not previously
- registered.
-
- Returns:
- The `FisherBlock` registered under `layer_key`. If `layer_key` was already
- registered, this will be the previously registered `FisherBlock`.
- """
- if reuse is VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse is True or (reuse is variable_scope.AUTO_REUSE and
- layer_key in self.fisher_blocks):
- result = self.fisher_blocks[layer_key]
- if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck
- raise ValueError(
- "Attempted to register FisherBlock of type %s when existing "
- "FisherBlock has type %s." % (type(fisher_block), type(result)))
- return result
- if reuse is False and layer_key in self.fisher_blocks:
- raise ValueError("FisherBlock for %s is already in LayerCollection." %
- (layer_key,))
-
- # Insert fisher_block into self.fisher_blocks.
- if layer_key in self.fisher_blocks:
- raise ValueError("Duplicate registration: {}".format(layer_key))
- # Raise an error if any variable in layer_key has been registered in any
- # other blocks.
- variable_to_block = {
- var: (params, block)
- for (params, block) in self.fisher_blocks.items()
- for var in utils.ensure_sequence(params)
- }
- for variable in utils.ensure_sequence(layer_key):
- if variable in variable_to_block:
- prev_key, prev_block = variable_to_block[variable]
- raise ValueError(
- "Attempted to register layer_key {} with block {}, but variable {}"
- " was already registered in key {} with block {}.".format(
- layer_key, fisher_block, variable, prev_key, prev_block))
- self.fisher_blocks[layer_key] = fisher_block
- return fisher_block
-
- def register_loss_function(self,
- loss,
- colocation_op,
- base_name,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a LossFunction object.
-
- Args:
- loss: The LossFunction object.
- colocation_op: The op to colocate the loss function's computations with.
- base_name: The name to derive a new unique name from is the name argument
- is None.
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: (OPTIONAL) bool or str. If True, adds `loss` as an additional
- tower for the existing loss function.
-
- Raises:
- ValueError: If reuse == True and name == None.
- ValueError: If reuse == True and seed != None.
- KeyError: If reuse == True and no existing LossFunction with `name` found.
- KeyError: If reuse == False and existing LossFunction with `name` found.
- """
-
- name = name or self._graph.unique_name(base_name)
-
- if reuse == VARIABLE_SCOPE:
- reuse = variable_scope.get_variable_scope().reuse
-
- if reuse:
- if name is None:
- raise ValueError(
- "If reuse is enabled, loss function's name must be set.")
-
- loss_list = self._loss_dict.get(name, None)
-
- if loss_list is None:
- raise KeyError(
- "Unable to find loss function named {}. Register a new loss "
- "function with reuse=False.".format(name))
- else:
- if name in self._loss_dict:
- raise KeyError(
- "Loss function named {} already exists. Set reuse=True to append "
- "another tower.".format(name))
-
- loss_list = []
- self._loss_dict[name] = loss_list
-
- loss_list.append(loss)
- self.loss_colocation_ops[loss] = colocation_op
-
- def _get_use_count_map(self):
- """Returns a dict mapping variables to their number of registrations."""
- return self._vars_to_uses
-
- def _add_uses(self, params, uses):
- """Register additional uses by params in the graph.
-
- Args:
- params: Variable or tuple of Variables. Parameters for a layer.
- uses: int or float. Number of additional uses for these parameters.
- """
- params = params if isinstance(params, (tuple, list)) else (params,)
- for var in params:
- self._vars_to_uses[var] += uses
-
- def check_registration(self, variables):
- """Checks that all variable uses have been registered properly.
-
- Args:
- variables: List of variables.
-
- Raises:
- ValueError: If any registered variables are not included in the list.
- ValueError: If any variable in the list is not registered.
- ValueError: If any variable in the list is registered with the wrong
- number of "uses" in the subgraph recorded (vs the number of times that
- variable is actually used in the subgraph).
- """
- # Note that overlapping parameters (i.e. those that share variables) will
- # be caught by layer_collection.LayerParametersDict during registration.
-
- reg_use_map = self._get_use_count_map()
-
- error_messages = []
-
- for var in variables:
- total_uses = self.subgraph.variable_uses(var)
- reg_uses = reg_use_map[var]
-
- if reg_uses == 0:
- error_messages.append("Variable {} not registered.".format(var))
- elif (not math.isinf(reg_uses)) and reg_uses != total_uses:
- error_messages.append(
- "Variable {} registered with wrong number of uses ({} "
- "registrations vs {} uses).".format(var, reg_uses, total_uses))
-
- num_get_vars = len(reg_use_map)
-
- if num_get_vars > len(variables):
- error_messages.append("{} registered variables were not included in list."
- .format(num_get_vars - len(variables)))
-
- if error_messages:
- error_messages = [
- "Found the following errors with variable registration:"
- ] + error_messages
- raise ValueError("\n\t".join(error_messages))
-
- def get_blocks(self):
- return self.fisher_blocks.values()
-
- def get_factors(self):
- return self.fisher_factors.values()
-
- @property
- def graph(self):
- return self._graph
-
- @property
- def subgraph(self):
- return self._subgraph
-
- def define_linked_parameters(self, params, approximation=None):
- """Identify a set of parameters that should be grouped together.
-
- During automatic graph scanning, any matches containing variables that have
- been identified as part of a linked group will be filtered out unless
- the match parameters are exactly equal to the ones specified in the linked
- group.
-
- Args:
- params: A variable, or a tuple or list of variables. The variables
- to be linked.
- approximation: Optional string specifying the type of approximation to use
- for these variables. If unspecified, this layer collection's default
- approximation for the layer type will be used.
-
- Raises:
- ValueError: If the parameters were already registered in a layer or
- identified as part of an incompatible group.
- """
- params = frozenset(utils.ensure_sequence(params))
-
- # Check if any of the variables in `params` is already in
- # 'self.fisher_blocks.keys()`.
- for registered_params, fisher_block in self.fisher_blocks.items():
- registered_params_set = set(utils.ensure_sequence(registered_params))
- for variable in params:
- if (variable in registered_params_set and
- params != registered_params_set):
- raise ValueError(
- "Can`t link parameters {}, variable {} was already registered in "
- "group {} with layer {}".format(params, variable,
- registered_params, fisher_block))
-
- # Check if any of the variables in `params` is already in
- # 'self.linked_parameters`.
- for variable in params:
- for other_linked_params in self.linked_parameters:
- if variable in other_linked_params:
- raise ValueError("Can`t link parameters {}, variable {} was already "
- "linked in group {}.".format(params, variable,
- other_linked_params))
- self._linked_parameters[params] = approximation
-
- def create_subgraph(self):
- if not self.losses:
- raise ValueError("Must have at least one registered loss.")
- inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses))
- self._subgraph = utils.SubGraph(inputs_to_losses)
-
- def eval_losses(self):
- """Return evaluated losses (colocated with inputs to losses)."""
- evals = []
- for loss in self.losses:
- with ops.colocate_with(self.loss_colocation_ops[loss]):
- evals.append(loss.evaluate())
- return evals
-
- def eval_losses_on_samples(self):
- """Return losses evaluated on samples (colocated with inputs to losses)."""
- evals = []
- for loss in self.losses:
- with ops.colocate_with(self.loss_colocation_ops[loss]):
- evals.append(loss.evaluate_on_sample())
- return evals
-
- def total_loss(self):
- return math_ops.add_n(self.eval_losses())
-
- def total_sampled_loss(self):
- return math_ops.add_n(self.eval_losses_on_samples())
-
- def _get_linked_approx(self, params):
- """If params were linked, return their specified approximation."""
- params_set = frozenset(utils.ensure_sequence(params))
- if params_set in self.linked_parameters:
- return self.linked_parameters[params_set]
- else:
- return None
-
- def _get_block_type(self, params, approx, default, approx_to_type):
- if approx is None:
- approx = self._get_linked_approx(params)
- if approx is None:
- approx = default
-
- if approx not in approx_to_type:
- raise ValueError("Bad value {} for approx.".format(approx))
-
- return approx_to_type[approx], approx
-
- def register_embedding(self,
- params,
- inputs,
- outputs,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers an embedding layer.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices
- into embedding matrix.
- outputs: Tensor of shape [batch_size, embedding_size]. Outputs
- produced by layer.
- approx: str or None. If not None must be "kron". The Fisher
- approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_embedding_approximation,
- _EMBEDDING_APPROX_TO_BLOCK_TYPES)
-
- if isinstance(params, (tuple, list)):
- raise ValueError("Bias not supported.")
- vocab_size = int(params.shape[0])
- block = self.register_block(
- params, block_type(self, vocab_size), reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_fully_connected(self,
- params,
- inputs,
- outputs,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a fully connected layer.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: Tensor of shape [batch_size, input_size]. Inputs to layer.
- outputs: Tensor of shape [batch_size, output_size]. Outputs
- produced by layer.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
-
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_approximation,
- _FULLY_CONNECTED_APPROX_TO_BLOCK_TYPES)
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_conv2d(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a call to tf.nn.conv2d().
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: List of 4 ints. Strides for convolution kernel.
- padding: string. see tf.nn.conv2d for valid values.
- inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs
- to layer.
- outputs: Tensor of shape [batch_size, height, width, out_channels].
- Output produced by layer.
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
-
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_approximation,
- _CONV2D_APPROX_TO_BLOCK_TYPES)
-
- # It feels bad to pass in configuration that has to do with the internal
- # implementation. And then we can`t use the same constructor for both
- # anymore and are thus forced to use this ugly if-statement.
- # TODO(b/74793309): Clean this up?
- if approx == APPROX_KRONECKER_NAME:
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches"),
- reuse=reuse)
- elif approx == APPROX_DIAGONAL_NAME:
- assert strides[0] == strides[-1] == 1
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- dilations=dilations,
- data_format=data_format),
- reuse=reuse)
- else:
- raise NotImplementedError(approx)
-
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_convolution(self,
- params,
- inputs,
- outputs,
- padding,
- strides=None,
- dilation_rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.convolution().
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [..filter_spatial_size..,
- in_channels, out_channels]. Bias should have shape [out_channels].
- inputs: Tensor of shape [batch_size, ..input_spatial_size.., in_channels].
- Inputs to layer.
- outputs: Tensor of shape [batch_size, ..output_spatial_size..,
- out_channels]. Output produced by layer.
- padding: string. see tf.nn.conv2d for valid values.
- strides: List of ints of length len(..input_spatial_size..). Strides for
- convolution kernel in spatial dimensions.
- dilation_rate: List of ints of length len(..input_spatial_size..).
- Dilations along spatial dimension.
- data_format: str or None. Format of data.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- # TODO(b/74793309): Have this use _get_block_type like the other
- # registration functions?
- assert approx is None or approx == APPROX_KRONECKER_NAME
-
- block = self.register_block(
- params,
- fb.ConvKFCBasicFB(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- dilation_rate=dilation_rate,
- data_format=data_format),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_depthwise_conv2d(self,
- params,
- inputs,
- outputs,
- strides,
- padding,
- rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.depthwise_conv2d().
-
- Args:
- params: 4-D Tensor of shape [filter_height, filter_width,
- in_channels, channel_multiplier]. Convolutional filter.
- inputs: Tensor of shape [batch_size, input_height, input_width,
- in_channels]. Inputs to layer.
- outputs: Tensor of shape [batch_size, output_height, output_width,
- in_channels * channel_multiplier]. Output produced by depthwise conv2d.
- strides: List of ints of length 4. Strides along all dimensions.
- padding: string. see tf.nn.conv2d for valid values.
- rate: None or List of ints of length 2. Dilation rates in spatial
- dimensions.
- data_format: str or None. Format of data.
- approx: str or None. If not None must "diagonal". The Fisher
- approximation to use. If None the default value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- # TODO(b/74793309): Have this use _get_block_type like the other
- # registration functions?
- assert approx is None or approx == APPROX_DIAGONAL_NAME
- assert data_format in [None, "NHWC"]
-
- block = self.register_block(
- params,
- fb.DepthwiseConvDiagonalFB(
- layer_collection=self,
- params=params,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- self._add_uses(params, 1)
-
- def register_separable_conv2d(self,
- depthwise_params,
- pointwise_params,
- inputs,
- depthwise_outputs,
- pointwise_outputs,
- strides,
- padding,
- rate=None,
- data_format=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Register a call to tf.nn.separable_conv2d().
-
- Note: This requires access to intermediate outputs between depthwise and
- pointwise convolutions.
-
- Args:
- depthwise_params: 4-D Tensor of shape [filter_height, filter_width,
- in_channels, channel_multiplier]. Filter for depthwise conv2d.
- pointwise_params: 4-D Tensor of shape [1, 1, in_channels *
- channel_multiplier, out_channels]. Filter for pointwise conv2d.
- inputs: Tensor of shape [batch_size, input_height, input_width,
- in_channels]. Inputs to layer.
- depthwise_outputs: Tensor of shape [batch_size, output_height,
- output_width, in_channels * channel_multiplier]. Output produced by
- depthwise conv2d.
- pointwise_outputs: Tensor of shape [batch_size, output_height,
- output_width, out_channels]. Output produced by pointwise conv2d.
- strides: List of ints of length 4. Strides for depthwise conv2d kernel in
- all dimensions.
- padding: string. see tf.nn.conv2d for valid values.
- rate: None or List of ints of length 2. Dilation rate of depthwise conv2d
- kernel in spatial dimensions.
- data_format: str or None. Format of data.
- approx: str or None. If not None must be one of "kron" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse.
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- self.register_depthwise_conv2d(
- params=depthwise_params,
- inputs=inputs,
- outputs=depthwise_outputs,
- strides=strides,
- padding=padding,
- rate=rate,
- data_format=data_format,
- approx=APPROX_DIAGONAL_NAME,
- reuse=reuse)
-
- self.register_conv2d(
- params=pointwise_params,
- inputs=depthwise_outputs,
- outputs=pointwise_outputs,
- strides=[1, 1, 1, 1],
- padding="VALID",
- data_format=data_format,
- approx=approx,
- reuse=reuse)
-
- def register_generic(self,
- params,
- batch_size,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers a generic layer.
-
- Args:
- params: Tensor or tuple of Tensors corresponding to the parameters.
- batch_size: 0-D Tensor. Size of the minibatch (for this tower).
- approx: str or None. It not None, must be one of "full" or "diagonal".
- The Fisher approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `batch_size` to the total
- mini-batch size use when estimating the Fisher block for this layer
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_generic_approximation,
- _GENERIC_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(params, block_type(self, params), reuse=reuse)
- block.register_additional_tower(batch_size)
-
- self._add_uses(params, float("inf"))
-
- def register_fully_connected_multi(self, params, inputs, outputs,
- num_uses=None, approx=None,
- reuse=VARIABLE_SCOPE):
- """Register fully connected layers with shared parameters.
-
- This can handle general fully-connected layers with shared parameters, but
- has specialized approximations to deal with the case where there is a
- meaningful linear order to the share instances (such as in an RNN).
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [input_size, output_size].
- Bias should have shape [output_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size]. Inputs
- to layer. The list indexes each use in the graph (which might
- correspond to a "time-step" in an RNN). OR, can be single Tensor, of
- shape [num_uses * batch_size , input_size], which is a reshaped version
- of a Tensor of shape [num_uses, batch_size, input_size].
- outputs: A list of Tensors, the same length as `inputs`, each of shape
- [batch_size, output_size]. Outputs produced by layer. The list indexes
- each use in the graph (which might correspond to a "time-step" in an
- RNN). Needs to correspond with the order used in `inputs`. OR, can be
- a single Tensor of shape [num_uses * batch_size, output_size], which is
- a reshaped version of a Tensor of shape [num_uses, batch_size,
- output_size].
- num_uses: int or None. The number uses/time-steps in the graph where the
- layer appears. Only needed if both inputs and outputs are given in the
- single Tensor format. (Default: None)
- approx: str or None. If not None, must be of "kron_indep", "kron_series_1"
- or "kron_series_2". The Fisher approximation to use. If None the default
- value is used. (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_fully_connected_multi_approximation,
- _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES)
-
- # TODO(b/70283649): something along the lines of find_canonical_output
- # should be added back in here (and for the other block types, arguably).
-
- has_bias = isinstance(params, (tuple, list))
- block = self.register_block(params, block_type(self, has_bias=has_bias,
- num_uses=num_uses),
- reuse=reuse)
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- def register_conv2d_multi(self,
- params,
- strides,
- padding,
- inputs,
- outputs,
- num_uses=None,
- data_format=None,
- dilations=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers convolutional layers with shared parameters.
-
- Args:
- params: Tensor or 2-tuple of Tensors corresponding to weight and bias of
- this layer. Weight matrix should have shape [kernel_height,
- kernel_width, in_channels, out_channels]. Bias should have shape
- [out_channels].
- strides: 1-D Tensor of length 4. Strides for convolution kernel.
- padding: string. see tf.nn.conv2d for valid values.
- inputs: A list of Tensors, each of shape [batch_size, height, width,
- in_channels]. Inputs to layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). OR, can be single
- Tensor, of shape [num_uses * batch_size, height, width, in_channels],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- height, width, in_channels].
- outputs: A list of Tensors, each of shape [batch_size, height, width,
- out_channels]. Output produced by layer. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- Needs to correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, height, width,
- out_channels], which is a reshaped version of a Tensor of shape
- [num_uses, batch_size, height, width, out_channels].
- num_uses: int or None. The number uses/time-steps in the graph where the
- layer appears. Only needed if both inputs and outputs are given in the
- single Tensor format. (Default: None)
- data_format: str or None. Format of data.
- dilations: List of 4 ints. Dilations along each dimension.
- approx: str or None. If not None must by "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_conv2d_multi_approximation,
- _CONV2D_MULTI_APPROX_TO_BLOCK_TYPES)
-
- block = self.register_block(
- params,
- block_type(
- layer_collection=self,
- params=params,
- padding=padding,
- strides=strides,
- data_format=data_format,
- dilation_rate=dilations,
- extract_patches_fn="extract_image_patches",
- num_uses=num_uses),
- reuse=reuse)
-
- block.register_additional_tower(inputs, outputs)
- if isinstance(inputs, (tuple, list)):
- assert len(inputs) == len(outputs)
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- # TODO(b/74108452): change the loss registration functions names to refer
- # to "loss functions" instead of distributions. Following naming convention
- # of the loss function classes themselves.
-
- def register_embedding_multi(self,
- params,
- inputs,
- outputs,
- num_uses=None,
- approx=None,
- reuse=VARIABLE_SCOPE):
- """Registers embedding layers with shared parameters.
-
- Args:
- params: Embedding matrix of shape [vocab_size, embedding_size].
- inputs: A list of Tensors, each of shape [batch_size, input_size] and
- dtype int32. Indices into embedding matrix. The list indexes each use
- in the graph (which might correspond to a "time-step" in an RNN).
- OR, can be single Tensor, of shape [num_uses*batch_size, input_size],
- which is a reshaped version of a Tensor of shape [num_uses, batch_size,
- input_size].
- outputs: A list of Tensors, each of shape [batch_size, embedding_size].
- Outputs produced by layer. The list indexes each use in the graph
- (which might correspond to a "time-step" in an RNN). Needs to
- correspond with the order used in `inputs`. OR, can be a
- single Tensor, of shape [num_uses * batch_size, embedding_size], which
- is a reshaped version of a Tensor of shape [num_uses, batch_size,
- embedding_size].
- num_uses: int or None. The number uses/time-steps in the graph where the
- layer appears. Only needed if both inputs and outputs are given in the
- single Tensor format. (Default: None)
- approx: str or None. If not None must by "kron_indep". The Fisher
- approximation to use. If None the default value is used.
- (Default: None)
- reuse: bool or str. If True, this adds `inputs` and `outputs` as an
- additional mini-batch/tower of data to use when estimating the Fisher
- block for this layer (which must have already been registered). If
- "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. (Note that the
- word `use` here has a completely different meaning to "use in the graph"
- as it pertains to the `inputs`, `outputs`, and `num_uses` arguments.)
- (Default: "VARIABLE_SCOPE")
-
- Raises:
- ValueError: For improper value to `approx`.
- KeyError: If reuse == True but no FisherBlock found for `params`.
- ValueError: If reuse == True and FisherBlock found but of the wrong type.
- """
- block_type, approx = self._get_block_type(
- params, approx, self.default_embedding_multi_approximation,
- _EMBEDDING_MULTI_APPROX_TO_BLOCK_TYPES)
-
- if isinstance(params, (tuple, list)):
- raise ValueError("Bias not supported.")
- vocab_size = int(params.shape[0])
-
- block = self.register_block(
- params, block_type(self, vocab_size, num_uses=num_uses), reuse=reuse)
- block.register_additional_tower(inputs, outputs)
-
- if isinstance(inputs, (tuple, list)):
- self._add_uses(params, len(inputs))
- else:
- self._add_uses(params, 1)
-
- def register_categorical_predictive_distribution(self,
- logits,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a categorical predictive distribution.
-
- Args:
- logits: The logits of the distribution (i.e. its parameters).
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "categorical_predictive_distribution",
- name=name, reuse=reuse)
-
- def register_normal_predictive_distribution(self,
- mean,
- var=0.5,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a normal predictive distribution.
-
- Args:
- mean: The mean vector defining the distribution.
- var: The variance (must be a scalar). Note that the default value of
- 0.5 corresponds to a standard squared error loss (target -
- prediction)**2. If your squared error loss is of the form
- 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5)
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `mean` and `var` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.NormalMeanNegativeLogProbLoss(mean, var, targets=targets,
- seed=seed)
- self.register_loss_function(loss, mean,
- "normal_predictive_distribution",
- name=name, reuse=reuse)
-
- def register_multi_bernoulli_predictive_distribution(self,
- logits,
- seed=None,
- targets=None,
- name=None,
- reuse=VARIABLE_SCOPE):
- """Registers a multi-Bernoulli predictive distribution.
-
- Args:
- logits: The logits of the distribution (i.e. its parameters).
- seed: The seed for the RNG (for debugging) (Default: None)
- targets: (OPTIONAL) The targets for the loss function. Only required if
- one wants to call total_loss() instead of total_sampled_loss().
- total_loss() is required, for example, to estimate the
- "empirical Fisher" (instead of the true Fisher).
- (Default: None)
- name: (OPTIONAL) str or None. Unique name for this loss function. If None,
- a new name is generated. (Default: None)
- reuse: bool or str. If True, this adds `logits` as an additional
- mini-batch/tower of inputs to the loss-function/predictive distribution
- (which must have already been registered). If "VARIABLE_SCOPE", use
- tf.get_variable_scope().reuse. (Default: "VARIABLE_SCOPE")
- """
- loss = lf.MultiBernoulliNegativeLogProbLoss(logits, targets=targets,
- seed=seed)
- self.register_loss_function(loss, logits,
- "multi_bernoulli_predictive_distribution",
- name=name, reuse=reuse)
-
- def make_or_get_factor(self, cls, args):
- """Insert `cls(args)` into 'self.fisher_factors` if not already present.
-
- Wraps constructor in `tf.variable_scope()` to ensure variables constructed
- in `cls.__init__` are placed under this LayerCollection's scope.
-
- Args:
- cls: Class that implements FisherFactor.
- args: Tuple of arguments to pass into `cls's constructor. Must be
- hashable.
-
- Returns:
- Instance of `cls` found in self.fisher_factors.
- """
- try:
- hash(args)
- except TypeError:
- raise TypeError(
- ("Unable to use (cls, args) = ({}, {}) as a key in "
- "LayerCollection.fisher_factors. The pair cannot be hashed.").format(
- cls, args))
-
- key = cls, args
- if key not in self.fisher_factors:
- with variable_scope.variable_scope(self._var_scope):
- self.fisher_factors[key] = cls(*args)
- return self.fisher_factors[key]
-
- @contextmanager
- def as_default(self):
- """Sets this LayerCollection as the default."""
- set_default_layer_collection(self)
- yield
- set_default_layer_collection(None)
diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
deleted file mode 100644
index 9f46853807..0000000000
--- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Registry for layers and their parameters/variables.
-
-This represents the collection of all layers in the approximate Fisher
-information matrix to which a particular FisherBlock may belong. That is, we
-might have several layer collections for one TF graph (if we have multiple K-FAC
-optimizers being used, for example.)
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.layer_collection import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "get_default_layer_collection",
- "set_default_layer_collection",
- "LayerParametersDict",
- "LayerCollection",
- "APPROX_KRONECKER_NAME",
- "APPROX_DIAGONAL_NAME",
- "APPROX_FULL_NAME",
- "VARIABLE_SCOPE",
- "APPROX_KRONECKER_INDEP_NAME",
- "APPROX_KRONECKER_SERIES_1_NAME",
- "APPROX_KRONECKER_SERIES_2_NAME"
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/linear_operator.py b/tensorflow/contrib/kfac/python/ops/linear_operator.py
deleted file mode 100644
index 61cb955ae8..0000000000
--- a/tensorflow/contrib/kfac/python/ops/linear_operator.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""SmartMatrices definitions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kfac.python.ops import utils
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.linalg import linalg
-from tensorflow.python.ops.linalg import linalg_impl
-from tensorflow.python.ops.linalg import linear_operator_util as lou
-
-
-class LinearOperatorExtras(object): # pylint: disable=missing-docstring
-
- def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
-
- with self._name_scope(name, values=[x]):
- if isinstance(x, ops.IndexedSlices):
- return self._matmul_sparse(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- x = ops.convert_to_tensor(x, name="x")
- self._check_input_dtype(x)
-
- self_dim = -2 if adjoint else -1
- arg_dim = -1 if adjoint_arg else -2
- self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
-
- return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- def matmul_right(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
-
- with self._name_scope(name, values=[x]):
-
- if isinstance(x, ops.IndexedSlices):
- return self._matmul_right_sparse(
- x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
- x = ops.convert_to_tensor(x, name="x")
- self._check_input_dtype(x)
-
- self_dim = -1 if adjoint else -2
- arg_dim = -2 if adjoint_arg else -1
- self.shape[self_dim].assert_is_compatible_with(x.get_shape()[arg_dim])
-
- return self._matmul_right(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
-
-
-class LinearOperatorFullMatrix(LinearOperatorExtras,
- linalg.LinearOperatorFullMatrix):
-
- # TODO(b/78117889) Remove this definition once core LinearOperator
- # has _matmul_right.
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- return lou.matmul_with_broadcast(
- x, self._matrix, adjoint_a=adjoint_arg, adjoint_b=adjoint)
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- assert not adjoint and not adjoint_arg
- return utils.matmul_sparse_dense(x, self._matrix)
-
-
-class LinearOperatorDiag(LinearOperatorExtras, # pylint: disable=missing-docstring
- linalg.LinearOperatorDiag):
-
- def _matmul_right(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- x = linalg_impl.adjoint(x) if adjoint_arg else x
- return diag_mat * x
-
- def _matmul_sparse(self, x, adjoint=False, adjoint_arg=False):
- diag_mat = math_ops.conj(self._diag) if adjoint else self._diag
- assert not adjoint_arg
- return utils.matmul_diag_sparse(diag_mat, x)
-
- def _matmul_right_sparse(self, x, adjoint=False, adjoint_arg=False):
- raise NotImplementedError
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py
deleted file mode 100644
index c8cebc42cb..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions.py
+++ /dev/null
@@ -1,754 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.contrib.distributions.python.ops import onehot_categorical
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.distributions import bernoulli
-from tensorflow.python.ops.distributions import categorical
-from tensorflow.python.ops.distributions import normal
-
-
-@six.add_metaclass(abc.ABCMeta)
-class LossFunction(object):
- """Abstract base class for loss functions.
-
- Note that unlike typical loss functions used in neural networks these are
- summed and not averaged across cases in the batch, since this is what the
- users of this class (FisherEstimator and MatrixVectorProductComputer) will
- be expecting. The implication of this is that you will may want to
- normalize things like Fisher-vector products by the batch size when you
- use this class. It depends on the use case.
- """
-
- @abc.abstractproperty
- def targets(self):
- """The targets being predicted by the model.
-
- Returns:
- None or Tensor of appropriate shape for calling self._evaluate() on.
- """
- pass
-
- @abc.abstractproperty
- def inputs(self):
- """The inputs to the loss function (excluding the targets)."""
- pass
-
- def evaluate(self):
- """Evaluate the loss function on the targets."""
- if self.targets is not None:
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.targets))
- else:
- raise Exception("Cannot evaluate losses with unspecified targets.")
-
- @abc.abstractmethod
- def _evaluate(self, targets):
- """Evaluates the negative log probability of the targets.
-
- Args:
- targets: Tensor that distribution can calculate log_prob() of.
-
- Returns:
- negative log probability of each target, summed across all targets.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian(self, vector):
- """Right-multiply a vector by the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Hessian. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor(self, vector):
- """Right-multiply a vector by a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'hessian_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'hessian_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_hessian_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Hessian.
-
- Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives)
- of the loss function with respect to its inputs. Typically this will be
- block-diagonal across different cases in the batch, since the loss function
- is typically summed across cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Hessian,
- but will agree with the one used in the other methods of this class.
-
- Args:
- index: A tuple representing in the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'hessian_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_hessian_factor."""
- pass
-
- @abc.abstractproperty
- def hessian_factor_inner_static_shape(self):
- """Static version of hessian_factor_inner_shape."""
- pass
-
-
-@six.add_metaclass(abc.ABCMeta)
-class NegativeLogProbLoss(LossFunction):
- """Abstract base class for loss functions that are negative log probs."""
-
- def __init__(self, seed=None):
- self._default_seed = seed
- super(NegativeLogProbLoss, self).__init__()
-
- @property
- def inputs(self):
- return self.params
-
- @abc.abstractproperty
- def params(self):
- """Parameters to the underlying distribution."""
- pass
-
- @abc.abstractmethod
- def multiply_fisher(self, vector):
- """Right-multiply a vector by the Fisher.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by the Fisher. Will be of the same shape(s)
- as the 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor(self, vector):
- """Right-multiply a vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be of the shape given by the
- 'fisher_factor_inner_shape' property.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_transpose(self, vector):
- """Right-multiply a vector by the transpose of a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- Note that B can be any matrix satisfying B * B^T = F where F is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- vector: The vector to multiply. Must be the same shape(s) as the
- 'inputs' property.
-
- Returns:
- The vector right-multiplied by B^T. Will be of the shape given by the
- 'fisher_factor_inner_shape' property.
- """
- pass
-
- @abc.abstractmethod
- def multiply_fisher_factor_replicated_one_hot(self, index):
- """Right-multiply a replicated-one-hot vector by a factor B of the Fisher.
-
- Here the 'Fisher' is the Fisher information matrix (i.e. expected outer-
- product of gradients) with respect to the parameters of the underlying
- probability distribution (whose log-prob defines the loss). Typically this
- will be block-diagonal across different cases in the batch, since the
- distribution is usually (but not always) conditionally iid across different
- cases.
-
- A 'replicated-one-hot' vector means a tensor which, for each slice along the
- batch dimension (assumed to be dimension 0), is 1.0 in the entry
- corresponding to the given index and 0 elsewhere.
-
- Note that B can be any matrix satisfying B * B^T = H where H is the Fisher,
- but will agree with the one used in the other methods of this class.
-
- Args:
- index: A tuple representing in the index of the entry in each slice that
- is 1.0. Note that len(index) must be equal to the number of elements
- of the 'fisher_factor_inner_shape' tensor minus one.
-
- Returns:
- The vector right-multiplied by B. Will be of the same shape(s) as the
- 'inputs' property.
- """
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_shape(self):
- """The shape of the tensor returned by multiply_fisher_factor."""
- pass
-
- @abc.abstractproperty
- def fisher_factor_inner_static_shape(self):
- """Static version of fisher_factor_inner_shape."""
- pass
-
- @abc.abstractmethod
- def sample(self, seed):
- """Sample 'targets' from the underlying distribution."""
- pass
-
- def evaluate_on_sample(self, seed=None):
- """Evaluates the log probability on a random sample.
-
- Args:
- seed: int or None. Random seed for this draw from the distribution.
-
- Returns:
- Log probability of sampled targets, summed across examples.
- """
- if seed is None:
- seed = self._default_seed
- # We treat the targets as "constant". It's only the inputs that get
- # "back-propped" through.
- return self._evaluate(array_ops.stop_gradient(self.sample(seed)))
-
-
-# TODO(jamesmartens): should this just inherit from object to avoid "diamond"
-# inheritance, or is there a better way?
-class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses whose inputs are 'natural' parameters.
-
- Note that the Hessian and Fisher for natural parameters of exponential-
- family models are the same, hence the purpose of this class.
- See here: https://arxiv.org/abs/1412.1193
-
- 'Natural parameters' are defined for exponential-family models. See for
- example: https://en.wikipedia.org/wiki/Exponential_family
- """
-
- def multiply_hessian(self, vector):
- return self.multiply_fisher(vector)
-
- def multiply_hessian_factor(self, vector):
- return self.multiply_fisher_factor(vector)
-
- def multiply_hessian_factor_transpose(self, vector):
- return self.multiply_fisher_factor_transpose(vector)
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- return self.multiply_fisher_factor_replicated_one_hot(index)
-
- @property
- def hessian_factor_inner_shape(self):
- return self.fisher_factor_inner_shape
-
- @property
- def hessian_factor_inner_static_shape(self):
- return self.fisher_factor_inner_shape
-
-
-class DistributionNegativeLogProbLoss(NegativeLogProbLoss):
- """Base class for neg log prob losses that use the TF Distribution classes."""
-
- def __init__(self, seed=None):
- super(DistributionNegativeLogProbLoss, self).__init__(seed=seed)
-
- @abc.abstractproperty
- def dist(self):
- """The underlying tf.distributions.Distribution."""
- pass
-
- def _evaluate(self, targets):
- return -math_ops.reduce_sum(self.dist.log_prob(targets))
-
- def sample(self, seed):
- return self.dist.sample(seed=seed)
-
-
-class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a normal distribution parameterized by a mean vector.
-
-
- Note that the covariance is treated as a constant 'var' times the identity.
- Also note that the Fisher for such a normal distribution with respect the mean
- parameter is given by:
-
- F = (1/var) * I
-
- See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf.
- """
-
- def __init__(self, mean, var=0.5, targets=None, seed=None):
- self._mean = mean
- self._var = var
- self._targets = targets
- super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var))
-
- @property
- def params(self):
- return self._mean
-
- def multiply_fisher(self, vector):
- return (1. / self._var) * vector
-
- def multiply_fisher_factor(self, vector):
- return self._var**-0.5 * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- ones_slice = array_ops.expand_dims(
- array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype),
- axis=-1)
- output_slice = self._var**-0.5 * ones_slice
- return insert_slice_in_zeros(output_slice, 1, int(self._mean.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._mean)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._mean.shape
-
-
-class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss):
- """Negative log prob loss for a normal distribution with mean and variance.
-
- This class parameterizes a multivariate normal distribution with n independent
- dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not
- assume the variance is held constant. The Fisher Information for n = 1
- is given by,
-
- F = [[1 / variance, 0],
- [ 0, 0.5 / variance^2]]
-
- where the parameters of the distribution are concatenated into a single
- vector as [mean, variance]. For n > 1, the mean parameter vector is
- concatenated with the variance parameter vector.
-
- See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation.
- """
-
- def __init__(self, mean, variance, targets=None, seed=None):
- assert len(mean.shape) == 2, "Expect 2D mean tensor."
- assert len(variance.shape) == 2, "Expect 2D variance tensor."
- self._mean = mean
- self._variance = variance
- self._targets = targets
- super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._variance))
-
- @property
- def params(self):
- return self._mean, self._variance
-
- def _concat(self, mean, variance):
- return array_ops.concat([mean, variance], axis=-1)
-
- def _split(self, params):
- return array_ops.split(params, 2, axis=-1)
-
- @property
- def _fisher_mean(self):
- return 1. / self._variance
-
- @property
- def _fisher_mean_factor(self):
- return 1. / math_ops.sqrt(self._variance)
-
- @property
- def _fisher_var(self):
- return 1. / (2 * math_ops.square(self._variance))
-
- @property
- def _fisher_var_factor(self):
- return 1. / (math_ops.sqrt(2.) * self._variance)
-
- def multiply_fisher(self, vecs):
- mean_vec, var_vec = vecs
- return (self._fisher_mean * mean_vec, self._fisher_var * var_vec)
-
- def multiply_fisher_factor(self, vecs):
- mean_vec, var_vec = self._split(vecs)
- return (self._fisher_mean_factor * mean_vec,
- self._fisher_var_factor * var_vec)
-
- def multiply_fisher_factor_transpose(self, vecs):
- mean_vec, var_vec = vecs
- return self._concat(self._fisher_mean_factor * mean_vec,
- self._fisher_var_factor * var_vec)
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- index = index[0]
-
- if index < int(self._mean.shape[-1]):
- # Index corresponds to mean parameter.
- mean_slice = self._fisher_mean_factor[:, index]
- mean_slice = array_ops.expand_dims(mean_slice, axis=-1)
- mean_output = insert_slice_in_zeros(mean_slice, 1, int(
- self._mean.shape[1]), index)
- var_output = array_ops.zeros_like(mean_output)
- else:
- index -= int(self._mean.shape[-1])
- # Index corresponds to variance parameter.
- var_slice = self._fisher_var_factor[:, index]
- var_slice = array_ops.expand_dims(var_slice, axis=-1)
- var_output = insert_slice_in_zeros(var_slice, 1,
- int(self._variance.shape[1]), index)
- mean_output = array_ops.zeros_like(var_output)
-
- return mean_output, var_output
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.concat(
- [
- array_ops.shape(self._mean)[:-1],
- 2 * array_ops.shape(self._mean)[-1:]
- ],
- axis=0)
-
- @property
- def fisher_factor_inner_static_shape(self):
- shape = self._mean.shape.as_list()
- return tensor_shape.TensorShape(shape[-1:] + [2 * shape[-1]])
-
- def multiply_hessian(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_transpose(self, vector):
- raise NotImplementedError()
-
- def multiply_hessian_factor_replicated_one_hot(self, index):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_shape(self):
- raise NotImplementedError()
-
- @property
- def hessian_factor_inner_static_shape(self):
- raise NotImplementedError()
-
-
-class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution parameterized by logits.
-
-
- Note that the Fisher (for a single case) of a categorical distribution, with
- respect to the natural parameters (i.e. the logits), is given by:
-
- F = diag(p) - p*p^T
-
- where p = softmax(logits). F can be factorized as F = B * B^T where
-
- B = diag(q) - p*q^T
-
- where q is the entry-wise square root of p. This is easy to verify using the
- fact that q^T*q = 1.
- """
-
- def __init__(self, logits, targets=None, seed=None):
- """Instantiates a CategoricalLogitsNegativeLogProbLoss.
-
- Args:
- logits: Tensor of shape [batch_size, output_size]. Parameters for
- underlying distribution.
- targets: None or Tensor of shape [output_size]. Each elements contains an
- index in [0, output_size).
- seed: int or None. Default random seed when sampling.
- """
- self._logits = logits
- self._targets = targets
- super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return categorical.Categorical(logits=self._logits)
-
- @property
- def _probs(self):
- return self.dist.probs
-
- @property
- def _sqrt_probs(self):
- return math_ops.sqrt(self._probs)
-
- @property
- def params(self):
- return self._logits
-
- def multiply_fisher(self, vector):
- probs = self._probs
- return vector * probs - probs * math_ops.reduce_sum(
- vector * probs, axis=-1, keepdims=True)
-
- def multiply_fisher_factor(self, vector):
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- return sqrt_probs * vector - probs * math_ops.reduce_sum(
- sqrt_probs * vector, axis=-1, keepdims=True)
-
- def multiply_fisher_factor_transpose(self, vector):
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum(
- probs * vector, axis=-1, keepdims=True)
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- probs = self._probs
- sqrt_probs = self._sqrt_probs
- sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1)
- padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1,
- int(sqrt_probs.shape[1]), index[0])
- return padded_slice - probs * sqrt_probs_slice
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._logits)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._logits.shape
-
-
-class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss,
- NaturalParamsNegativeLogProbLoss):
- """Neg log prob loss for multiple Bernoulli distributions param'd by logits.
-
- Represents N independent Bernoulli distributions where N = len(logits). Its
- Fisher Information matrix is given by,
-
- F = diag(p * (1-p))
- p = sigmoid(logits)
-
- As F is diagonal with positive entries, its factor B is,
-
- B = diag(sqrt(p * (1-p)))
- """
-
- def __init__(self, logits, targets=None, seed=None):
- self._logits = logits
- self._targets = targets
- super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed)
-
- @property
- def targets(self):
- return self._targets
-
- @property
- def dist(self):
- return bernoulli.Bernoulli(logits=self._logits)
-
- @property
- def _probs(self):
- return self.dist.probs
-
- @property
- def params(self):
- return self._logits
-
- def multiply_fisher(self, vector):
- return self._probs * (1 - self._probs) * vector
-
- def multiply_fisher_factor(self, vector):
- return math_ops.sqrt(self._probs * (1 - self._probs)) * vector
-
- def multiply_fisher_factor_transpose(self, vector):
- return self.multiply_fisher_factor(vector) # it's symmetric in this case
-
- def multiply_fisher_factor_replicated_one_hot(self, index):
- assert len(index) == 1, "Length of index was {}".format(len(index))
- probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1)
- output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice))
- return insert_slice_in_zeros(output_slice, 1, int(self._logits.shape[1]),
- index[0])
-
- @property
- def fisher_factor_inner_shape(self):
- return array_ops.shape(self._logits)
-
- @property
- def fisher_factor_inner_static_shape(self):
- return self._logits.shape
-
-
-def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position):
- """Inserts slice into a larger tensor of zeros.
-
- Forms a new tensor which is the same shape as slice_to_insert, except that
- the dimension given by 'dim' is expanded to the size given by 'dim_size'.
- 'position' determines the position (index) at which to insert the slice within
- that dimension.
-
- Assumes slice_to_insert.shape[dim] = 1.
-
- Args:
- slice_to_insert: The slice to insert.
- dim: The dimension which to expand with zeros.
- dim_size: The new size of the 'dim' dimension.
- position: The position of 'slice_to_insert' in the new tensor.
-
- Returns:
- The new tensor.
-
- Raises:
- ValueError: If the slice's shape at the given dim is not 1.
- """
- slice_shape = slice_to_insert.shape
- if slice_shape[dim] != 1:
- raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but "
- "was {}".format(dim, slice_to_insert.shape[dim]))
-
- before = [0] * int(len(slice_shape))
- after = before[:]
- before[dim] = position
- after[dim] = dim_size - position - 1
-
- return array_ops.pad(slice_to_insert, list(zip(before, after)))
-
-
-class OnehotCategoricalLogitsNegativeLogProbLoss(
- CategoricalLogitsNegativeLogProbLoss):
- """Neg log prob loss for a categorical distribution with onehot targets.
-
- Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying
- distribution is OneHotCategorical as opposed to Categorical.
- """
-
- @property
- def dist(self):
- return onehot_categorical.OneHotCategorical(logits=self._logits)
diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
deleted file mode 100644
index 4279cb2792..0000000000
--- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Loss functions to be used by LayerCollection."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.loss_functions import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "LossFunction",
- "NegativeLogProbLoss",
- "NaturalParamsNegativeLogProbLoss",
- "DistributionNegativeLogProbLoss",
- "NormalMeanNegativeLogProbLoss",
- "NormalMeanVarianceNegativeLogProbLoss",
- "CategoricalLogitsNegativeLogProbLoss",
- "OnehotCategoricalLogitsNegativeLogProbLoss",
- "MultiBernoulliNegativeLogProbLoss",
- "insert_slice_in_zeros",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py
deleted file mode 100644
index b6d9d37a31..0000000000
--- a/tensorflow/contrib/kfac/python/ops/op_queue.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Helper for choosing which op to run next in a distributed setting."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops as tf_ops
-
-
-class OpQueue(object):
- """Class for choosing which Op to run next.
-
- Constructs an infinitely repeating sequence of Ops in shuffled order.
-
- In K-FAC, this can be used to distribute inverse update operations among
- workers.
- """
-
- def __init__(self, ops, seed=None):
- """Initializes an OpQueue.
-
- Args:
- ops: list of TensorFlow Ops. Ops to be selected from. All workers must
- initialize with the same set of ops.
- seed: int or None. Random seed used when shuffling order of ops.
- """
- self._ops_by_name = {op.name: op for op in ops}
-
- # Construct a (shuffled) Dataset with Op names.
- op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops)))
- op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names)
- .shuffle(len(ops), seed=seed).repeat())
- self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next()
-
- @property
- def ops(self):
- """Ops this OpQueue can return in next_op()."""
- return self._ops_by_name.values()
-
- def next_op(self, sess):
- """Chooses which op to run next.
-
- Note: This call will make a call to sess.run().
-
- Args:
- sess: tf.Session.
-
- Returns:
- Next Op chosen from 'ops'.
- """
- # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii')
- # returns a str.
- next_op_name = sess.run(self._next_op_name).decode('ascii')
- return self._ops_by_name[next_op_name]
diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
deleted file mode 100644
index 09c9a4ab33..0000000000
--- a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Helper for choosing which op to run next in a distributed setting."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.op_queue import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- 'OpQueue',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py
deleted file mode 100644
index 38605259b5..0000000000
--- a/tensorflow/contrib/kfac/python/ops/optimizer.py
+++ /dev/null
@@ -1,727 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""The KFAC optimizer."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-# pylint disable=long-line
-from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp
-from tensorflow.contrib.kfac.python.ops import estimator as est
-# pylint enable=long-line
-
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables as tf_variables
-from tensorflow.python.training import gradient_descent
-
-
-class KfacOptimizer(gradient_descent.GradientDescentOptimizer):
- """The KFAC Optimizer (https://arxiv.org/abs/1503.05671)."""
-
- def __init__(self,
- learning_rate,
- cov_ema_decay,
- damping,
- layer_collection,
- var_list=None,
- momentum=0.9,
- momentum_type="regular",
- norm_constraint=None,
- name="KFAC",
- estimation_mode="gradients",
- colocate_gradients_with_ops=True,
- batch_size=None,
- placement_strategy=None,
- **kwargs):
- """Initializes the KFAC optimizer with the given settings.
-
- Args:
- learning_rate: The base learning rate for the optimizer. Should probably
- be set to 1.0 when using momentum_type = 'qmodel', but can still be
- set lowered if desired (effectively lowering the trust in the
- quadratic model.)
- cov_ema_decay: The decay factor used when calculating the covariance
- estimate moving averages.
- damping: The damping factor used to stabilize training due to errors in
- the local approximation with the Fisher information matrix, and to
- regularize the update direction by making it closer to the gradient.
- If damping is adapted during training then this value is used for
- initializing damping variable.
- (Higher damping means the update looks more like a standard gradient
- update - see Tikhonov regularization.)
- layer_collection: The layer collection object, which holds the fisher
- blocks, Kronecker factors, and losses associated with the
- graph. The layer_collection cannot be modified after KfacOptimizer's
- initialization.
- var_list: Optional list or tuple of variables to train. Defaults to the
- list of variables collected in the graph under the key
- `GraphKeys.TRAINABLE_VARIABLES`.
- momentum: The momentum decay constant to use. Only applies when
- momentum_type is 'regular' or 'adam'. (Default: 0.9)
- momentum_type: The type of momentum to use in this optimizer, one of
- 'regular', 'adam', or 'qmodel'. (Default: 'regular')
- norm_constraint: float or Tensor. If specified, the update is scaled down
- so that its approximate squared Fisher norm v^T F v is at most the
- specified value. May only be used with momentum type 'regular'.
- (Default: None)
- name: The name for this optimizer. (Default: 'KFAC')
- estimation_mode: The type of estimator to use for the Fishers. Can be
- 'gradients', 'empirical', 'curvature_propagation', or 'exact'.
- (Default: 'gradients'). See the doc-string for FisherEstimator for
- more a more detailed description of these options.
- colocate_gradients_with_ops: Whether we should request gradients we
- compute in the estimator be colocated with their respective ops.
- (Default: True)
- batch_size: The size of the mini-batch. Only needed when momentum_type
- == 'qmodel' or when automatic adjustment is used. (Default: None)
- placement_strategy: string, Device placement strategy used when creating
- covariance variables, covariance ops, and inverse ops.
- (Default: `None`)
- **kwargs: Arguments to be passed to specific placement
- strategy mixin. Check `placement.RoundRobinPlacementMixin` for example.
-
- Raises:
- ValueError: If the momentum type is unsupported.
- ValueError: If clipping is used with momentum type other than 'regular'.
- ValueError: If no losses have been registered with layer_collection.
- ValueError: If momentum is non-zero and momentum_type is not 'regular'
- or 'adam'.
- """
- warnings.warn(
- "third_party.tensorflow.contrib.kfac is deprecated."
- "This will be removed on 15-07-2018. Check README for further details.",
- DeprecationWarning)
- # Parameters to be passed to the Fisher estimator:
- self._variables = var_list or tf_variables.trainable_variables
- self._cov_ema_decay = cov_ema_decay
- self._layers = layer_collection
- self._estimation_mode = estimation_mode
- self._colocate_gradients_with_ops = colocate_gradients_with_ops
-
- # The below parameters are required only if damping needs to be adapted.
- # These parameters can be set by calling
- # set_damping_adaptation_params() explicitly.
- self._damping_adaptation_decay = 0.95
- self._damping_adaptation_interval = 5
- # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval)
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._adapt_damping = False
- self._min_damping = 1e-5
- self._prev_train_batch = None
- self._is_chief = False
- self._loss_fn = None
- self._damping_constant = damping
- self._damping = None
- self._rho = None
- self._prev_loss = None
- self._q_model_change = None
- self._update_damping_op = None
-
- momentum_type = momentum_type.lower()
- legal_momentum_types = ["regular", "adam", "qmodel"]
-
- if momentum_type not in legal_momentum_types:
- raise ValueError("Unsupported momentum type {}. Must be one of {}."
- .format(momentum_type, legal_momentum_types))
- if momentum_type != "regular" and norm_constraint is not None:
- raise ValueError("Update clipping is only supported with momentum "
- "type 'regular'.")
- if momentum_type not in ["regular", "adam"] and momentum != 0:
- raise ValueError("Momentum must be unspecified if using a momentum_type "
- "other than 'regular' or 'adam'.")
-
- # Extra parameters of the optimizer
- self._momentum = momentum
- self._momentum_type = momentum_type
- self._norm_constraint = norm_constraint
- self._batch_size = batch_size
- self._placement_strategy = placement_strategy
-
- with variable_scope.variable_scope(name):
- self._fisher_est = est.make_fisher_estimator(
- placement_strategy=placement_strategy,
- variables=self._variables,
- cov_ema_decay=self._cov_ema_decay,
- damping=self.damping,
- layer_collection=self._layers,
- exps=(-1,),
- estimation_mode=self._estimation_mode,
- colocate_gradients_with_ops=self._colocate_gradients_with_ops,
- **kwargs)
-
- super(KfacOptimizer, self).__init__(learning_rate, name=name)
-
- def set_damping_adaptation_params(self,
- is_chief,
- prev_train_batch,
- loss_fn,
- min_damping=1e-5,
- damping_adaptation_decay=0.99,
- damping_adaptation_interval=5):
- """Sets parameters required to adapt damping during training.
-
- When called, enables damping adaptation according to the Levenberg-Marquardt
- style rule described in Section 6.5 of "Optimizing Neural Networks with
- Kronecker-factored Approximate Curvature".
-
- Note that this function creates Tensorflow variables which store a few
- scalars and are accessed by the ops which update the damping (as part
- of the training op returned by the minimize() method).
-
- Args:
- is_chief: `Boolean`, `True` if the worker is chief.
- prev_train_batch: Training data used to minimize loss in the previous
- step. This will be used to evaluate loss by calling
- `loss_fn(prev_train_batch)`.
- loss_fn: `function` that takes as input training data tensor and returns
- a scalar loss.
- min_damping: `float`(Optional), Minimum value the damping parameter
- can take. Default value 1e-5.
- damping_adaptation_decay: `float`(Optional), The `damping` parameter is
- multiplied by the `damping_adaptation_decay` every
- `damping_adaptation_interval` number of iterations. Default value 0.99.
- damping_adaptation_interval: `int`(Optional), Number of steps in between
- updating the `damping` parameter. Default value 5.
-
- Raises:
- ValueError: If `set_damping_adaptation_params` is already called and the
- the `adapt_damping` is `True`.
- """
- if self._adapt_damping:
- raise ValueError("Damping adaptation parameters already set.")
-
- with variable_scope.variable_scope(self.get_name()):
- self._adapt_damping = True
- self._is_chief = is_chief
- self._prev_train_batch = prev_train_batch
- self._loss_fn = loss_fn
- self._damping_adaptation_decay = damping_adaptation_decay
- self._damping_adaptation_interval = damping_adaptation_interval
- self._omega = (
- self._damping_adaptation_decay**self._damping_adaptation_interval)
- self._min_damping = min_damping
-
- self._rho = variable_scope.get_variable(
- "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio.
- self._prev_loss = variable_scope.get_variable(
- "prev_loss", shape=(), dtype=dtypes.float32, trainable=False)
- self._q_model_change = variable_scope.get_variable(
- "q_model_change", shape=(), dtype=dtypes.float32, trainable=False)
- self._damping = variable_scope.get_variable(
- "damping", initializer=self._damping_constant, trainable=False)
-
- @property
- def variables(self):
- return self._fisher_est.variables
-
- @property
- def damping(self):
- if self._damping:
- return self._damping
- else:
- return self._damping_constant
-
- @property
- def damping_adaptation_interval(self):
- return self._damping_adaptation_interval
-
- def make_vars_and_create_op_thunks(self):
- """Make vars and create op thunks.
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.make_vars_and_create_op_thunks(scope=scope)
-
- def create_ops_and_vars_thunks(self):
- """Create thunks that make the ops and vars on demand.
-
- This function returns 4 lists of thunks: cov_variable_thunks,
- cov_update_thunks, inv_variable_thunks, and inv_update_thunks.
-
- The length of each list is the number of factors and the i-th element of
- each list corresponds to the i-th factor (given by the "factors" property).
-
- Note that the execution of these thunks must happen in a certain
- partial order. The i-th element of cov_variable_thunks must execute
- before the i-th element of cov_update_thunks (and also the i-th element
- of inv_update_thunks). Similarly, the i-th element of inv_variable_thunks
- must execute before the i-th element of inv_update_thunks.
-
- TL;DR (oversimplified): Execute the thunks according to the order that
- they are returned.
-
- Returns:
- cov_variable_thunks: A list of thunks that make the cov variables.
- cov_update_thunks: A list of thunks that make the cov update ops.
- inv_variable_thunks: A list of thunks that make the inv variables.
- inv_update_thunks: A list of thunks that make the inv update ops.
- """
- scope = self.get_name() + "/" + self._fisher_est.name
- return self._fisher_est.create_ops_and_vars_thunks(scope=scope)
-
- def minimize(self, *args, **kwargs):
- # Should this variable scope encompass everything below? Or will the super-
- # class make another copy of the same name scope?
- with variable_scope.variable_scope(self.get_name()):
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- if set(kwargs["var_list"]) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- if self._adapt_damping and self._is_chief:
- global_step = kwargs.get("global_step", None)
- if not global_step:
- raise KeyError("global_step needs to be passed to optimizer.minimize "
- "if damping parameter is adapted.")
- update_damping_op = self._update_damping(self._prev_train_batch,
- global_step)
- with ops.control_dependencies([update_damping_op]):
- loss = args[0]
- loss_assign_op = state_ops.assign(self._prev_loss, loss)
- train_op = super(KfacOptimizer, self).minimize(*args, **kwargs)
- return control_flow_ops.group(loss_assign_op, train_op)
- else:
- return super(KfacOptimizer, self).minimize(*args, **kwargs)
-
- def compute_gradients(self, *args, **kwargs):
- # args[1] could be our var_list
- if len(args) > 1:
- var_list = args[1]
- else:
- kwargs["var_list"] = kwargs.get("var_list") or self.variables
- var_list = kwargs["var_list"]
-
- if set(var_list) != set(self.variables):
- raise ValueError("var_list doesn't match with set of Fisher-estimating "
- "variables.")
- return super(KfacOptimizer, self).compute_gradients(*args, **kwargs)
-
- def apply_gradients(self, grads_and_vars, *args, **kwargs):
- """Applies gradients to variables.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- *args: Additional arguments for super.apply_gradients.
- **kwargs: Additional keyword arguments for super.apply_gradients.
-
- Returns:
- An `Operation` that applies the specified gradients.
- """
- # In Python 3, grads_and_vars can be a zip() object which can only be
- # iterated over once. By converting it to a list, we ensure that it can be
- # iterated over more than once.
- grads_and_vars = list(grads_and_vars)
-
- # Compute step.
- steps_and_vars = self._compute_update_steps(grads_and_vars)
-
- # Update trainable variables with this step.
- return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args,
- **kwargs)
-
- def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars):
- """Computes the squared (approximate) Fisher norm of the updates.
-
- This is defined as v^T F v, where F is the approximate Fisher matrix
- as computed by the estimator, and v = F^{-1} g, where g is the gradient.
- This is computed efficiently as v^T g.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- Scalar representing the squared norm.
-
- Raises:
- ValueError: if the two list arguments do not contain the same variables,
- in the same order.
- """
- for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars):
- if gvar is not pgvar:
- raise ValueError("The variables referenced by the two arguments "
- "must match.")
- terms = [
- math_ops.reduce_sum(grad * pgrad)
- for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars)
- ]
- return math_ops.reduce_sum(terms)
-
- def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars):
- """Computes the scale factor for the update to satisfy the norm constraint.
-
- Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
- F is the approximate Fisher matrix, and r is the update vector, i.e.
- -alpha * v, where alpha is the learning rate, and v is the preconditioned
- gradient.
-
- This is based on Section 5 of Ba et al., Distributed Second-Order
- Optimization using Kronecker-Factored Approximations. Note that they
- absorb the learning rate alpha (which they denote eta_max) into the formula
- for the coefficient, while in our implementation, the rescaling is done
- before multiplying by alpha. Hence, our formula differs from theirs by a
- factor of alpha.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- Scalar representing the coefficient which should be applied to the
- preconditioned gradients to satisfy the norm constraint.
- """
- sq_norm_grad = self._squared_fisher_norm(grads_and_vars,
- precon_grads_and_vars)
- sq_norm_up = sq_norm_grad * self._learning_rate**2
- return math_ops.minimum(1.,
- math_ops.sqrt(self._norm_constraint / sq_norm_up))
-
- def _clip_updates(self, grads_and_vars, precon_grads_and_vars):
- """Rescales the preconditioned gradients to satisfy the norm constraint.
-
- Rescales the preconditioned gradients such that the resulting update r
- (after multiplying by the learning rate) will satisfy the norm constraint.
- This constraint is that r^T F r <= C, where F is the approximate Fisher
- matrix, and C is the norm_constraint attribute. See Section 5 of
- Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
- Approximations.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradient, variable) pairs.
- Must be the result of calling `self._fisher_est.multiply_inverse`
- on `grads_and_vars`.
-
- Returns:
- List of (rescaled preconditioned gradient, variable) pairs.
- """
- coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars)
- return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars]
-
- def _compute_prev_updates(self, variables):
- """Computes previous updates as negative velocities scaled by learning rate.
-
- Args:
- variables: List of variables in the graph that the update will be
- applied to.
-
- Returns:
- List of previous updates applied to the `variables`.
- """
- return list(
- -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name)
- for var in variables)
-
- def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads,
- variables):
- """Compute optimal update hyperparameters from the quadratic model.
-
- More specifically, if L is the loss we minimize a quadratic approximation
- of L(theta + d) which we denote by qmodel(d) with
- d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where
-
- qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) .
-
- Unlike in the KL clipping approach we use the non-approximated quadratic
- model where the curvature matrix C is the true Fisher on the current
- mini-batch (computed without any approximations beyond mini-batch sampling),
- with the usual Tikhonov damping/regularization applied,
-
- C = F + damping * I
-
- See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of
- the formula. See Appendix C for a discussion of the trick of using
- a factorized Fisher matrix to more efficiently compute the required
- vector-matrix-vector products.
-
- Note that the elements of all 4 lists passed to this function must
- be in correspondence with each other.
-
- Args:
- precon_grads: List of preconditioned gradients.
- prev_updates: List of updates computed at the previous iteration.
- grads: List of gradients.
- variables: List of variables in the graph that the update will be
- applied to. (Note that this function doesn't actually apply the
- update.)
-
- Returns:
- (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the
- quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0)
- = qmodel(alpha*precon_grad + mu*prev_update) - L(theta).
- """
-
- cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._layers.losses,
- variables)
-
- # compute the matrix-vector products with the transposed Fisher factor
- fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads)
- fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates)
- batch_size = math_ops.cast(
- self._batch_size, dtype=fft_precon_grads[0].dtype)
-
- # compute the entries of the 2x2 matrix
- m_11 = (
- _inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size +
- self.damping * _inner_product_list(precon_grads, precon_grads))
-
- m_21 = (
- _inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size +
- self.damping * _inner_product_list(prev_updates, precon_grads))
-
- m_22 = (
- _inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size +
- self.damping * _inner_product_list(prev_updates, prev_updates))
-
- def non_zero_prevupd_case():
- r"""Computes optimal (alpha, mu) given non-zero previous update.
-
- We solve the full 2x2 linear system. See Martens & Grosse (2015),
- Section 7, definition of $\alpha^*$ and $\mu^*$.
-
- Returns:
- (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize
- the quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0).
- """
- m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]])
-
- c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)],
- [_inner_product_list(grads, prev_updates)]])
-
- sol = -1. * _two_by_two_solve(m, c)
- alpha = sol[0]
- mu = sol[1]
- qmodel_change = 0.5 * math_ops.reduce_sum(sol * c)
-
- return alpha, mu, qmodel_change
-
- def zero_prevupd_case():
- r"""Computes optimal (alpha, mu) given all-zero previous update.
-
- The linear system reduces to 1x1. See Martens & Grosse (2015),
- Section 6.4, definition of $\alpha^*$.
-
- Returns:
- (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the
- quadratic model, and
- qmodel_change = qmodel(alpha*precon_grad) - qmodel(0)
- """
- m = m_11
- c = _inner_product_list(grads, precon_grads)
-
- alpha = -c / m
- mu = 0.0
- qmodel_change = 0.5 * alpha * c
-
- return alpha, mu, qmodel_change
-
- return control_flow_ops.cond(
- math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case)
-
- def _assign_q_model_change(self, q_model_change):
- """Assigns `q_model_change` to `self._q_model_change` if damping is adapted.
-
- Note only the chief worker does the assignment.
-
- Args:
- q_model_change: Scalar tensor of type `float32`.
-
- Returns:
- If `adapt_damping` is `True` then returns an assign op, Otherwise returns
- a no_op().
- """
- if self._adapt_damping and self._is_chief:
- q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change)
- else:
- q_model_assign_op = control_flow_ops.no_op()
- return q_model_assign_op
-
- def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars,
- precon_grads_and_vars):
- """Wrapper function for `self._compute_qmodel_hyperparams`.
-
- Constructs a list of preconditioned gradients and variables. Also creates a
- op to assign the computed q model change to `self._q_model_change`.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
- precon_grads_and_vars: List of (preconditioned gradients, variable)
- pairs.
-
- Returns:
- (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize
- the quadratic model, `q_model_assign_op` assigns the computed q model
- change to `self._q_model_change`.
- """
- precon_grads = list(
- precon_grad for (precon_grad, _) in precon_grads_and_vars)
- grads = list(grad for (grad, _) in grads_and_vars)
- variables = list(var for (_, var) in grads_and_vars)
- prev_updates = self._compute_prev_updates(variables)
- # Compute optimal velocity update parameters according to quadratic model
- alpha, mu, q_model_change = self._compute_qmodel_hyperparams(
- precon_grads, prev_updates, grads, variables)
-
- return alpha, mu, self._assign_q_model_change(q_model_change)
-
- def _compute_update_steps(self, grads_and_vars):
- """Computes the update steps for the variables given the gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs.
-
- Returns:
- A list of tuple (assign_op ,var) where `assign_op` assigns the update
- steps to `var`.
- """
-
- if self._momentum_type == "regular":
- # Compute "preconditioned" gradient.
- precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
- # Apply "KL clipping" if asked for.
- if self._norm_constraint is not None:
- precon_grads_and_vars = self._clip_updates(grads_and_vars,
- precon_grads_and_vars)
-
- # Update the velocity with this and return it as the step.
- if self._adapt_damping and self._is_chief:
- _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
- grads_and_vars, precon_grads_and_vars)
- with ops.control_dependencies([q_model_assign_op]):
- return self._update_velocities(precon_grads_and_vars, self._momentum)
- else:
- return self._update_velocities(precon_grads_and_vars, self._momentum)
- elif self._momentum_type == "adam":
- # Update velocity.
- velocities_and_vars = self._update_velocities(grads_and_vars,
- self._momentum)
- # Return "preconditioned" velocity vector as the step.
- return self._fisher_est.multiply_inverse(velocities_and_vars)
-
- elif self._momentum_type == "qmodel":
- # Compute "preconditioned" gradient.
- precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars)
-
- # Compute optimal velocity update parameters according to quadratic model
- alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper(
- grads_and_vars, precon_grads_and_vars)
-
- with ops.control_dependencies([q_model_assign_op]):
- return self._update_velocities(
- precon_grads_and_vars, mu, vec_coeff=-alpha)
-
- def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0):
- """Updates the velocities of the variables with the given vectors.
-
- Args:
- vecs_and_vars: List of (vector, variable) pairs.
- decay: How much to decay the old velocity by. This is often referred to
- as the 'momentum constant'.
- vec_coeff: Coefficient to apply to the vectors before adding them to the
- velocity.
-
- Returns:
- A list of (velocity, var) indicating the new velocity for each var.
- """
-
- def _update_velocity(vec, var):
- velocity = self._zeros_slot(var, "velocity", self._name)
- with ops.colocate_with(velocity):
- # NOTE(mattjj): read/modify/write race condition not suitable for async.
-
- # Compute the new velocity for this variable.
- new_velocity = decay * velocity + vec_coeff * vec
-
- # Save the updated velocity.
- return (array_ops.identity(velocity.assign(new_velocity)), var)
-
- # Go through variable and update its associated part of the velocity vector.
- return [_update_velocity(vec, var) for vec, var in vecs_and_vars]
-
- def _update_damping(self, prev_batch, global_step):
- """Adapts damping parameter. Check KFAC (Section 6.5) for the details.
-
- The damping parameter is updated according to the Levenberg-Marquardt rule
- every `self._damping_adaptation_interval` iterations.
-
- Args:
- prev_batch: Tensor or tuple of tensors which can be passed to
- `self._loss_fn` to evaluate loss.
- global_step: `Variable` which keeps track of number of times the training
- variables have been updated.
- Returns:
- A `tf.cond` op which updates the damping parameter.
- """
- def compute_damping():
- """"Adapts damping parameter based on "reduction ratio".
-
- Reduction ratio captures how closely the quadratic approximation to the
- loss function approximates the actual loss within a trust region. The
- damping update tries to make the damping as small as possible while
- maintaining the property that the quadratic model remains a good local
- approximation to the loss function.
-
- Returns:
- An Op to assign newly computed damping value to `self._damping`.
- """
- prev_batch_loss = self._loss_fn(prev_batch)
- with ops.control_dependencies([prev_batch_loss]):
- rho_assign = self._rho.assign(
- (prev_batch_loss - self._prev_loss) / self._q_model_change)
- with ops.control_dependencies([rho_assign]):
- new_damping = control_flow_ops.case(
- [(self._rho < 0.25, lambda: self.damping / self._omega),
- (self._rho > 0.75, lambda: self.damping * self._omega)],
- lambda: self.damping)
- with ops.control_dependencies([new_damping]):
- new_damping_min = math_ops.maximum(new_damping, self._min_damping)
- return control_flow_ops.group(self._damping.assign(new_damping_min))
-
- return control_flow_ops.cond(
- math_ops.equal(
- math_ops.mod(global_step + 1, self._damping_adaptation_interval),
- 0), compute_damping, control_flow_ops.no_op)
-
-
-def _inner_product_list(list1, list2):
- return math_ops.add_n(
- [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)])
-
-
-def _two_by_two_solve(m, c):
- # it might be better just to crank out the exact formula for 2x2 inverses
- return math_ops.matmul(linalg_ops.matrix_inverse(m), c)
diff --git a/tensorflow/contrib/kfac/python/ops/placement.py b/tensorflow/contrib/kfac/python/ops/placement.py
deleted file mode 100644
index c4454325ae..0000000000
--- a/tensorflow/contrib/kfac/python/ops/placement.py
+++ /dev/null
@@ -1,114 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Implements placement strategies for cov and inv ops, cov variables."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-from tensorflow.python.framework import ops as tf_ops
-
-
-def _make_thunk_on_device(func, device):
- def thunk():
- with tf_ops.device(device):
- return func()
- return thunk
-
-
-class RoundRobinPlacementMixin(object):
- """Implements round robin placement strategy for ops and variables."""
-
- def __init__(self, cov_devices=None, inv_devices=None, **kwargs):
- """Initializes the RoundRobinPlacementMixin class.
-
- Args:
- cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
- computations will be placed on these devices in a round-robin fashion.
- Can be None, which means that no devices are specified.
- **kwargs: Need something here?
-
- """
- super(RoundRobinPlacementMixin, self).__init__(**kwargs)
- self._cov_devices = cov_devices
- self._inv_devices = inv_devices
-
- def make_vars_and_create_op_thunks(self, scope=None):
- """Make vars and create op thunks w/ a round-robin device placement start.
-
- For each factor, all of that factor's cov variables and their associated
- update ops will be placed on a particular device. A new device is chosen
- for each factor by cycling through list of devices in the
- `self._cov_devices` attribute. If `self._cov_devices` is `Non`e then no
- explicit device placement occurs.
-
- An analogous strategy is followed for inverse update ops, with the list of
- devices being given by the `self._inv_devices` attribute.
-
- Inverse variables on the other hand are not placed on any specific device
- (they will just use the current the device placement context, whatever
- that happens to be). The idea is that the inverse variable belong where
- they will be accessed most often, which is the device that actually applies
- the preconditioner to the gradient. The user will be responsible for setting
- the device context for this.
-
- Args:
- scope: A string or None. If None it will be set to the name of this
- estimator (given by the name property). All variables will be created,
- and all thunks will execute, inside of a variable scope of the given
- name. (Default: None)
-
- Returns:
- cov_update_thunks: List of cov update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- inv_update_thunks: List of inv update thunks. Corresponds one-to-one with
- the list of factors given by the "factors" property.
- """
- # Note: `create_ops_and_vars_thunks` is implemented in `FisherEstimator`.
- (cov_variable_thunks_raw, cov_update_thunks_raw, inv_variable_thunks_raw,
- inv_update_thunks_raw) = self.create_ops_and_vars_thunks(scope=scope)
-
- if self._cov_devices:
- cov_update_thunks = []
- for cov_variable_thunk, cov_update_thunk, device in zip(
- cov_variable_thunks_raw, cov_update_thunks_raw,
- itertools.cycle(self._cov_devices)):
- with tf_ops.device(device):
- cov_variable_thunk()
- cov_update_thunks.append(_make_thunk_on_device(cov_update_thunk,
- device))
- else:
- for cov_variable_thunk in cov_variable_thunks_raw:
- cov_variable_thunk()
- cov_update_thunks = cov_update_thunks_raw
-
- for inv_variable_thunk in inv_variable_thunks_raw:
- inv_variable_thunk()
-
- if self._inv_devices:
- inv_update_thunks = []
- for inv_update_thunk, device in zip(inv_update_thunks_raw,
- itertools.cycle(self._inv_devices)):
- inv_update_thunks.append(_make_thunk_on_device(inv_update_thunk,
- device))
- else:
- inv_update_thunks = inv_update_thunks_raw
-
- return cov_update_thunks, inv_update_thunks
diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py
deleted file mode 100644
index 144295f4c7..0000000000
--- a/tensorflow/contrib/kfac/python/ops/utils.py
+++ /dev/null
@@ -1,709 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.tpu.python.ops import tpu_ops
-from tensorflow.contrib.tpu.python.tpu import tpu_function
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variables
-
-# Method used for inverting matrices.
-POSDEF_INV_METHOD = "cholesky"
-POSDEF_EIG_METHOD = "self_adjoint"
-
-
-def set_global_constants(posdef_inv_method=None):
- """Sets various global constants used by the classes in this module."""
- global POSDEF_INV_METHOD
-
- if posdef_inv_method is not None:
- POSDEF_INV_METHOD = posdef_inv_method
-
-
-class SequenceDict(object):
- """A dict convenience wrapper that allows getting/setting with sequences."""
-
- def __init__(self, iterable=None):
- self._dict = dict(iterable or [])
-
- def __getitem__(self, key_or_keys):
- if isinstance(key_or_keys, (tuple, list)):
- return list(map(self.__getitem__, key_or_keys))
- else:
- return self._dict[key_or_keys]
-
- def __setitem__(self, key_or_keys, val_or_vals):
- if isinstance(key_or_keys, (tuple, list)):
- for key, value in zip(key_or_keys, val_or_vals):
- self[key] = value
- else:
- self._dict[key_or_keys] = val_or_vals
-
- def items(self):
- return list(self._dict.items())
-
-
-def tensors_to_column(tensors):
- """Converts a tensor or list of tensors to a column vector.
-
- Args:
- tensors: A tensor or list of tensors.
-
- Returns:
- The tensors reshaped into vectors and stacked on top of each other.
- """
- if isinstance(tensors, (tuple, list)):
- return array_ops.concat(
- tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0)
- else:
- return array_ops.reshape(tensors, [-1, 1])
-
-
-def column_to_tensors(tensors_template, colvec):
- """Converts a column vector back to the shape of the given template.
-
- Args:
- tensors_template: A tensor or list of tensors.
- colvec: A 2d column vector with the same shape as the value of
- tensors_to_column(tensors_template).
-
- Returns:
- X, where X is tensor or list of tensors with the properties:
- 1) tensors_to_column(X) = colvec
- 2) X (or its elements) have the same shape as tensors_template (or its
- elements)
- """
- if isinstance(tensors_template, (tuple, list)):
- offset = 0
- tensors = []
- for tensor_template in tensors_template:
- sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32)
- tensor = array_ops.reshape(colvec[offset:(offset + sz)],
- tensor_template.shape)
- tensors.append(tensor)
- offset += sz
-
- tensors = tuple(tensors)
- else:
- tensors = array_ops.reshape(colvec, tensors_template.shape)
-
- return tensors
-
-
-def kronecker_product(mat1, mat2):
- """Computes the Kronecker product two matrices."""
- m1, n1 = mat1.get_shape().as_list()
- mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
- m2, n2 = mat2.get_shape().as_list()
- mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
- return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
-
-
-def layer_params_to_mat2d(vector):
- """Converts a vector shaped like layer parameters to a 2D matrix.
-
- In particular, we reshape the weights/filter component of the vector to be
- 2D, flattening all leading (input) dimensions. If there is a bias component,
- we concatenate it to the reshaped weights/filter component.
-
- Args:
- vector: A Tensor or pair of Tensors shaped like layer parameters.
-
- Returns:
- A 2D Tensor with the same coefficients and the same output dimension.
- """
- if isinstance(vector, (tuple, list)):
- w_part, b_part = vector
- w_part_reshaped = array_ops.reshape(w_part,
- [-1, w_part.shape.as_list()[-1]])
- return array_ops.concat(
- (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0)
- elif isinstance(vector, ops.IndexedSlices):
- return vector
- else: # Tensor or Tensor-like.
- return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]])
-
-
-def mat2d_to_layer_params(vector_template, mat2d):
- """Converts a canonical 2D matrix representation back to a vector.
-
- Args:
- vector_template: A Tensor or pair of Tensors shaped like layer parameters.
- mat2d: A 2D Tensor with the same shape as the value of
- layer_params_to_mat2d(vector_template).
-
- Returns:
- A Tensor or pair of Tensors with the same coefficients as mat2d and the same
- shape as vector_template.
- """
- if isinstance(vector_template, (tuple, list)):
- w_part, b_part = mat2d[:-1], mat2d[-1]
- return array_ops.reshape(w_part, vector_template[0].shape), b_part
- elif isinstance(vector_template, ops.IndexedSlices):
- if not isinstance(mat2d, ops.IndexedSlices):
- raise TypeError(
- "If vector_template is an IndexedSlices, so should mat2d.")
- return mat2d
- else:
- return array_ops.reshape(mat2d, vector_template.shape)
-
-
-def posdef_inv(tensor, damping):
- """Computes the inverse of tensor + damping * identity."""
- identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
- damping = math_ops.cast(damping, dtype=tensor.dtype)
- return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping)
-
-
-def posdef_inv_matrix_inverse(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) directly."""
- return linalg_ops.matrix_inverse(tensor + damping * identity)
-
-
-def posdef_inv_cholesky(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) with Cholesky."""
- chol = linalg_ops.cholesky(tensor + damping * identity)
- return linalg_ops.cholesky_solve(chol, identity)
-
-
-def posdef_inv_eig(tensor, identity, damping):
- """Computes inverse(tensor + damping * identity) with eigendecomposition."""
- eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(
- tensor + damping * identity)
- return math_ops.matmul(
- eigenvectors / eigenvalues, eigenvectors, transpose_b=True)
-
-
-posdef_inv_functions = {
- "matrix_inverse": posdef_inv_matrix_inverse,
- "cholesky": posdef_inv_cholesky,
- "eig": posdef_inv_eig,
-}
-
-
-def posdef_eig(mat):
- """Computes the eigendecomposition of a positive semidefinite matrix."""
- return posdef_eig_functions[POSDEF_EIG_METHOD](mat)
-
-
-def posdef_eig_svd(mat):
- """Computes the singular values and left singular vectors of a matrix."""
- evals, evecs, _ = linalg_ops.svd(mat)
-
- return evals, evecs
-
-
-def posdef_eig_self_adjoint(mat):
- """Computes eigendecomposition using self_adjoint_eig."""
- evals, evecs = linalg_ops.self_adjoint_eig(mat)
- evals = math_ops.abs(evals) # Should be equivalent to svd approach.
-
- return evals, evecs
-
-
-posdef_eig_functions = {
- "self_adjoint": posdef_eig_self_adjoint,
- "svd": posdef_eig_svd,
-}
-
-
-def cholesky(tensor, damping):
- """Computes the inverse of tensor + damping * identity."""
- identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype)
- damping = math_ops.cast(damping, dtype=tensor.dtype)
- return linalg_ops.cholesky(tensor + damping * identity)
-
-
-class SubGraph(object):
- """Defines a subgraph given by all the dependencies of a given set of outputs.
- """
-
- def __init__(self, outputs):
- # Set of all ancestor Tensors, Ops to 'outputs'.
- self._members = set()
-
- self._iter_add(outputs)
-
- def _iter_add(self, root):
- """Iteratively adds all of nodes' ancestors using depth first search."""
- stack = [root]
- while stack:
- nodes = stack.pop()
- for node in nodes:
- if node in self._members:
- continue
- self._members.add(node)
-
- if isinstance(node, ops.Tensor):
- stack.append((node.op,))
- elif isinstance(node, ops.Operation):
- stack.append(node.inputs)
-
- def is_member(self, node):
- """Check if 'node' is in this subgraph."""
- return node in self._members
-
- def variable_uses(self, var):
- """Computes number of times a variable is used.
-
- Args:
- var: Variable or ResourceVariable instance.
-
- Returns:
- Number of times a variable is used within this subgraph.
-
- Raises:
- ValueError: If 'var' is not a variable type.
- """
- if isinstance(var, resource_variable_ops.ResourceVariable):
- var = var.handle
- elif isinstance(var, variables.Variable):
- var = var.value()
- else:
- raise ValueError("%s does not appear to be a variable." % str(var))
-
- return len(self._members.intersection(set(var.consumers())))
-
- def filter_list(self, node_list):
- """Filters 'node_list' to nodes in this subgraph."""
- filtered_list = []
- for node in node_list:
- if self.is_member(node):
- filtered_list.append(node)
- return filtered_list
-
-
-def generate_random_signs(shape, dtype=dtypes.float32):
- """Generate a random tensor with {-1, +1} entries."""
- ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32)
- return 2 * math_ops.cast(ints, dtype=dtype) - 1
-
-
-def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None):
- """Compute forward-mode gradients."""
- # See b/37888268.
-
- # This version of forward-mode autodiff is based on code by Tim Cooijmans
- # and handles list arguments and certain special cases such as when the
- # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are
- # generated by the first gradients_impl.gradients call.
-
- us = [array_ops.zeros_like(y) + float("nan") for y in ys]
- dydxs = gradients_impl.gradients(
- ys, xs, grad_ys=us, stop_gradients=stop_gradients)
-
- # Deal with strange types that gradients_impl.gradients returns but can't
- # deal with.
- dydxs = [
- ops.convert_to_tensor(dydx)
- if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs
- ]
- dydxs = [
- array_ops.zeros_like(x) if dydx is None else dydx
- for x, dydx in zip(xs, dydxs)
- ]
-
- dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs)
-
- return dysdx
-
-
-def on_tpu():
- """Returns True when building a TPU computation."""
- return tpu_function.get_tpu_context().number_of_shards is not None
-
-
-def cross_replica_mean(tensor, name=None):
- """Takes mean value of a Tensor across all TPU cores.
-
- Args:
- tensor: Tensor to be synchronized.
- name: None or string. Name of Op.
-
- Returns:
- Average of Tensor across all TPU cores.
-
- Raises:
- ValueError: If called outside of TPU context.
- """
- with ops.name_scope(name, "cross_replica_mean", [tensor]):
- num_shards = tpu_function.get_tpu_context().number_of_shards
- if num_shards is None:
- raise ValueError(
- "Cannot take cross_replica_mean() outside of TPU Context.")
- if num_shards == 1:
- return tensor
- return tpu_ops.cross_replica_sum(tensor / num_shards)
-
-
-def ensure_sequence(obj):
- """If `obj` isn't a tuple or list, return a tuple containing `obj`."""
- if isinstance(obj, (tuple, list)):
- return obj
- else:
- return (obj,)
-
-
-def batch_execute(global_step, thunks, batch_size, name=None):
- """Executes a subset of ops per global step.
-
- Given a list of thunks, each of which produces a single stateful op,
- ensures that exactly 'batch_size' ops are run per global step. Ops are
- scheduled in a round-robin fashion. For example, with 3 ops
-
- global_step | op0 | op1 | op2
- ------------+-----+-----+-----
- 0 | x | x |
- ------------+-----+-----+-----
- 1 | x | | x
- ------------+-----+-----+-----
- 2 | | x | x
- ------------+-----+-----+-----
- 3 | x | x |
- ------------+-----+-----+-----
- 4 | x | | x
-
- Does not guarantee order of op execution within a single global step.
-
- Args:
- global_step: Tensor indicating time. Determines which ops run.
- thunks: List of thunks. Each thunk encapsulates one op. Return values are
- ignored.
- batch_size: int. Number of ops to execute per global_step.
- name: string or None. Name scope for newly added ops.
-
- Returns:
- List of ops. Exactly 'batch_size' ops are guaranteed to have an effect
- every global step.
- """
-
- def true_fn(thunk):
- """Ensures thunk is executed and returns an Op (not a Tensor)."""
-
- def result():
- with ops.control_dependencies([thunk()]):
- return control_flow_ops.no_op()
-
- return result
-
- def false_fn(_):
- """Executes a no-op."""
-
- def result():
- return control_flow_ops.no_op()
-
- return result
-
- with ops.name_scope(name, "batch_execute"):
- true_fns = [true_fn(thunk) for thunk in thunks]
- false_fns = [false_fn(thunk) for thunk in thunks]
- num_thunks = len(thunks)
- conditions = [
- math_ops.less(
- math_ops.mod(batch_size - 1 + global_step * batch_size - j,
- num_thunks), batch_size) for j in range(num_thunks)
- ]
- result = [
- control_flow_ops.cond(condition, true_fn, false_fn)
- for (condition, true_fn,
- false_fn) in zip(conditions, true_fns, false_fns)
- ]
- return result
-
-
-def extract_convolution_patches(inputs,
- filter_shape,
- padding,
- strides=None,
- dilation_rate=None,
- name=None,
- data_format=None):
- """Extracts inputs to each output coordinate in tf.nn.convolution.
-
- This is a generalization of tf.extract_image_patches() to tf.nn.convolution(),
- where the number of spatial dimensions may be something other than 2.
-
- Assumes,
- - First dimension of inputs is batch_size
- - Convolution filter is applied to all input channels.
-
- Args:
- inputs: Tensor of shape [batch_size, ..spatial_image_shape..,
- ..spatial_filter_shape.., in_channels]. Inputs to tf.nn.convolution().
- filter_shape: List of ints. Shape of filter passed to tf.nn.convolution().
- padding: string. Padding method. One of "VALID", "SAME".
- strides: None or list of ints. Strides along spatial dimensions.
- dilation_rate: None or list of ints. Dilation along spatial dimensions.
- name: None or str. Name of Op.
- data_format: None or str. Format of data.
-
- Returns:
- Tensor of shape [batch_size, ..spatial_image_shape..,
- ..spatial_filter_shape.., in_channels]
-
- Raises:
- ValueError: If data_format does not put channel last.
- ValueError: If inputs and filter disagree on in_channels.
- """
- if not is_data_format_channel_last(data_format):
- raise ValueError("Channel must be last dimension.")
- with ops.name_scope(name, "extract_convolution_patches",
- [inputs, filter_shape, padding, strides, dilation_rate]):
- batch_size = inputs.shape.as_list()[0]
- in_channels = inputs.shape.as_list()[-1]
-
- # filter_shape = spatial_filter_shape + [in_channels, out_channels]
- spatial_filter_shape = filter_shape[:-2]
- if in_channels != filter_shape[-2]:
- raise ValueError("inputs and filter_shape must agree on in_channels.")
-
- # Map each input feature to a location in the output.
- out_channels = np.prod(spatial_filter_shape) * in_channels
- filters = linalg_ops.eye(out_channels)
- filters = array_ops.reshape(
- filters,
- list(spatial_filter_shape) + [in_channels, out_channels])
-
- result = nn_ops.convolution(
- inputs,
- filters,
- padding=padding,
- strides=strides,
- dilation_rate=dilation_rate)
- spatial_output_shape = result.shape.as_list()[1:-1]
- result = array_ops.reshape(result,
- [batch_size or -1] + spatial_output_shape +
- list(spatial_filter_shape) + [in_channels])
-
- return result
-
-
-def extract_pointwise_conv2d_patches(inputs,
- filter_shape,
- name=None,
- data_format=None):
- """Extract patches for a 1x1 conv2d.
-
- Args:
- inputs: 4-D Tensor of shape [batch_size, height, width, in_channels].
- filter_shape: List of 4 ints. Shape of filter to apply with conv2d()
- name: None or str. Name for Op.
- data_format: None or str. Format for data. See 'data_format' in
- tf.nn.conv2d() for details.
-
- Returns:
- Tensor of shape [batch_size, ..spatial_input_shape..,
- ..spatial_filter_shape.., in_channels]
-
- Raises:
- ValueError: if inputs is not 4-D.
- ValueError: if filter_shape is not [1, 1, ?, ?]
- ValueError: if data_format is not channels-last.
- """
- if inputs.shape.ndims != 4:
- raise ValueError("inputs must have 4 dims.")
- if len(filter_shape) != 4:
- raise ValueError("filter_shape must have 4 dims.")
- if filter_shape[0] != 1 or filter_shape[1] != 1:
- raise ValueError("filter_shape must have shape 1 along spatial dimensions.")
- if not is_data_format_channel_last(data_format):
- raise ValueError("data_format must be channels last.")
- with ops.name_scope(name, "extract_pointwise_conv2d_patches",
- [inputs, filter_shape]):
- ksizes = [1, 1, 1, 1] # Spatial shape is 1x1.
- strides = [1, 1, 1, 1] # Operate on all pixels.
- rates = [1, 1, 1, 1] # Dilation has no meaning with spatial shape = 1.
- padding = "VALID" # Doesn't matter.
- result = array_ops.extract_image_patches(inputs, ksizes, strides, rates,
- padding)
-
- batch_size, input_height, input_width, in_channels = inputs.shape.as_list()
- filter_height, filter_width, in_channels, _ = filter_shape
- return array_ops.reshape(result, [
- batch_size, input_height, input_width, filter_height, filter_width,
- in_channels
- ])
-
-
-def is_data_format_channel_last(data_format):
- """True if data_format puts channel last."""
- if data_format is None:
- return True
- return data_format.endswith("C")
-
-
-def matmul_sparse_dense(A, B, name=None, transpose_a=False, transpose_b=False): # pylint: disable=invalid-name
- """Computes matmul(A, B) where A is sparse, B is dense.
-
- Args:
- A: tf.IndexedSlices with dense shape [m, n].
- B: tf.Tensor with shape [n, k].
- name: str. Name of op.
- transpose_a: Bool. If true we transpose A before multiplying it by B.
- (Default: False)
- transpose_b: Bool. If true we transpose B before multiplying it by A.
- (Default: False)
-
- Returns:
- tf.IndexedSlices resulting from matmul(A, B).
-
- Raises:
- ValueError: If A doesn't represent a matrix.
- ValueError: If B is not rank-2.
- """
- with ops.name_scope(name, "matmul_sparse_dense", [A, B]):
- if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2:
- raise ValueError("A must represent a matrix. Found: %s." % A)
- if B.shape.ndims != 2:
- raise ValueError("B must be a matrix.")
- new_values = math_ops.matmul(
- A.values, B, transpose_a=transpose_a, transpose_b=transpose_b)
- return ops.IndexedSlices(
- new_values,
- A.indices,
- dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]]))
-
-
-def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name
- """Computes matmul(A, B) where A is a diagonal matrix, B is sparse.
-
- Args:
- A_diag: diagonal entries of matrix A of shape [m, m].
- B: tf.IndexedSlices. Represents matrix of shape [m, n].
- name: str. Name of op.
-
- Returns:
- tf.IndexedSlices resulting from matmul(A, B).
-
- Raises:
- ValueError: If A_diag is not rank-1.
- ValueError: If B doesn't represent a matrix.
- """
- with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]):
- A_diag = ops.convert_to_tensor(A_diag)
- if A_diag.shape.ndims != 1:
- raise ValueError("A_diag must be a rank-1 Tensor.")
- if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2:
- raise ValueError("B must represent a matrix. Found: %s." % B)
- a = array_ops.gather(A_diag, B.indices)
- a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1))
- return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape)
-
-
-class PartitionedTensor(object):
- """A Tensor partitioned across its 0-th dimension."""
-
- def __init__(self, tensors):
- """Initializes PartitionedTensor.
-
- Args:
- tensors: List of Tensors. All Tensors must agree on shape (excepting
- batch dimension) and dtype.
-
- Raises:
- ValueError: If 'tensors' has length zero.
- ValueError: if contents of 'tensors' don't agree on shape or dtype.
- """
- if not tensors:
- raise ValueError("tensors must be a list of 1+ Tensors.")
-
- dtype = tensors[0].dtype
- if not all(tensor.dtype == dtype for tensor in tensors):
- raise ValueError("all tensors must have dtype = %s." % dtype)
-
- shape = tensors[0].shape[1:]
- if not all(tensor.shape[1:] == shape for tensor in tensors):
- raise ValueError("All tensors must have shape = %s (excluding batch "
- "dimension)." % shape)
-
- self.tensors = tensors
- self._concats = {} # {device: Tensor}
-
- @property
- def shape(self):
- feature_shape = self.tensors[0].shape[1:]
- batch_size = sum([tensor.shape[0] for tensor in self.tensors],
- tensor_shape.Dimension(0))
- return tensor_shape.TensorShape([batch_size]).concatenate(feature_shape)
-
- def get_shape(self):
- return self.shape
-
- @property
- def dtype(self):
- return self.tensors[0].dtype
-
- def __str__(self):
- return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
- self.tensors[0].name, self.dtype.name, tuple(self.shape.as_list()))
-
- def __hash__(self):
- return hash(tuple(self.tensors))
-
- def __eq__(self, other):
- if not isinstance(other, PartitionedTensor):
- return False
- return self.tensors == other.tensors
-
- def __ne__(self, other):
- return not self == other # pylint: disable=g-comparison-negation
-
- def __getitem__(self, key):
- return self.as_tensor()[key]
-
- def as_tensor(self, dtype=None, name=None, as_ref=False):
- with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
- assert not as_ref
- assert dtype in [None, self.dtype]
- result = array_ops.concat(self.tensors, axis=0)
-
- # Cache 'result' if we haven't already cached a value for this device.
- if result.device not in self._concats:
- self._concats[result.device] = result
- return self._concats[result.device]
-
- @property
- def device(self):
- # PartitionedTensors in general do not live on a single device. If the
- # device cannot be determined unambiguously this property will return None.
- device = self.tensors[0].device
- if all(tensor.device == device for tensor in self.tensors):
- return device
- return None
-
-
-ops.register_tensor_conversion_function(
- PartitionedTensor,
- lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))
-
-
-# TODO(b/69623235): Add a function for finding tensors that share gradients
-# to eliminate redundant fisher factor computations.
diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py
deleted file mode 100644
index 330d222dbf..0000000000
--- a/tensorflow/contrib/kfac/python/ops/utils_lib.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Utility functions."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,line-too-long,wildcard-import
-from tensorflow.contrib.kfac.python.ops.utils import *
-from tensorflow.python.util.all_util import remove_undocumented
-# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- "set_global_constants",
- "SequenceDict",
- "tensors_to_column",
- "column_to_tensors",
- "kronecker_product",
- "layer_params_to_mat2d",
- "mat2d_to_layer_params",
- "posdef_inv",
- "posdef_inv_matrix_inverse",
- "posdef_inv_cholesky",
- "posdef_inv_funcs",
- "SubGraph",
- "generate_random_signs",
- "fwd_gradients",
- "ensure_sequence",
- "batch_execute",
- "extract_convolution_patches",
- "extract_pointwise_conv2d_patches",
- "is_data_format_channel_last",
- "matmul_sparse_dense",
- "matmul_diag_sparse",
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
index 7ede193029..124515e5a6 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops_test.py
@@ -109,7 +109,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
return sparse_ids, sparse_weights
def test_safe_embedding_lookup_sparse_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -122,7 +122,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
3.0, [0] * 4, [0] * 4, embedding_weights[0][2], [0] * 4])
def test_safe_embedding_lookup_sparse_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -136,7 +136,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][2], embedding_weights[0][3]])
def test_safe_embedding_lookup_sparse_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_2d()
@@ -150,7 +150,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_2d()
@@ -164,7 +164,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
(embedding_weights[0] + embedding_weights[1]) / 2.0])
def test_safe_embedding_lookup_sparse_partitioned_inconsistent_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_2d()
@@ -179,7 +179,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, sparse_weights)
def test_safe_embedding_lookup_sparse_3d_return_zero_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -192,7 +192,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
], [embedding_weights[0][2], [0] * 4, [0] * 4]])
def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -208,7 +208,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_no_weights(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
sparse_ids, _ = self._ids_and_weights_3d()
@@ -224,7 +224,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
]])
def test_safe_embedding_lookup_sparse_3d_partitioned(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, _ = self._ids_and_weights_3d()
@@ -241,7 +241,7 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_ids, sparse_weights = self._ids_and_weights_3d()
@@ -276,7 +276,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_scattered_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
@@ -288,7 +288,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1])
def test_scattered_embedding_multiple_partition(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=7)
values = constant_op.constant([4, 4, 5])
@@ -304,7 +304,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertGreater(embedding_diff, 0)
def test_scattered_embedding_coverage(self):
- with self.test_session():
+ with self.cached_session():
size = 8
embedding_weights = self._random_weights(size=size, num_shards=3)
values = constant_op.constant(["foo"])
@@ -316,7 +316,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertEqual(len(np.unique(embedding_lookup_result[0])), size)
def test_scattered_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -329,7 +329,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][2])
def test_scattered_embedding_lookup_sparse(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights(num_shards=3)
sparse_tensor = sparse_tensor_lib.SparseTensor(
values=["foo", "bar", "foo", "bar"],
@@ -358,7 +358,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
embeds = np.random.randn(n_embed, d_embed)
idx = np.random.randint(0, n_embed, idx_shape)
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -370,7 +370,7 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
idx = np.random.randint(0, 5, 10)
idx2d = np.random.randint(0, 5, (10, 2))
- with self.test_session():
+ with self.cached_session():
embedded_np = embeds[idx]
embedded_np2d = embeds[idx2d]
embedded_tf = embedding_ops.embedding_lookup_unique(embeds, idx).eval()
@@ -408,7 +408,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
return embedding_weights
def test_hashed_embedding_consistency(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant(["foo", "foo"])
# The first three sampled_candidates are equal, so the first three
@@ -429,7 +429,7 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
embedding_lookup_result[1][3])
def test_hashed_embedding_multi_dimension(self):
- with self.test_session():
+ with self.cached_session():
embedding_weights = self._random_weights()
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
@@ -467,7 +467,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_shape(self):
"""Verifies the shape of the output tensor."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -481,7 +481,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values(self):
"""Verifies the values in a trivial case."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .2, .3])
@@ -495,7 +495,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sampled_candidates(self):
"""Verifies the values for given sampled_candidates."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a", "a", "b", "c", "d", "e", "f"],
indices=[[1, 0], [2, 0], [2, 1], [2, 2], [2, 3], [2, 4], [2, 5]],
@@ -520,7 +520,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_output_values_with_sign_hash(self):
"""Verifies the values in a trivial case with hash_signs=True."""
- with self.test_session():
+ with self.cached_session():
sp_values = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[1, 0]], dense_shape=[3, 1])
params = constant_op.constant([.1, .1, .1])
@@ -537,7 +537,7 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
def test_distributive_property(self):
"""Verifies the distributive property of matrix multiplication."""
- with self.test_session():
+ with self.cached_session():
params = constant_op.constant([.1, .2, .3])
sp_values_a = sparse_tensor_lib.SparseTensor(
values=["a"], indices=[[0, 0]], dense_shape=[3, 1])
@@ -710,7 +710,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
[1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
@@ -749,7 +749,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
- with self.test_session():
+ with self.cached_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -767,7 +767,7 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
diff --git a/tensorflow/contrib/layers/python/layers/encoders_test.py b/tensorflow/contrib/layers/python/layers/encoders_test.py
index e8528e9890..1a2aa710d5 100644
--- a/tensorflow/contrib/layers/python/layers/encoders_test.py
+++ b/tensorflow/contrib/layers/python/layers/encoders_test.py
@@ -34,14 +34,14 @@ def _get_const_var(name, shape, value):
class EncodersTest(test.TestCase):
def testBowEncoderSparse(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc = encoders.bow_encoder(docs, 4, 3)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([2, 3], enc.eval().shape)
def testBowEncoderSparseTensor(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
enc = encoders.bow_encoder(sparse_docs, 4, 3)
@@ -49,28 +49,28 @@ class EncodersTest(test.TestCase):
self.assertAllEqual([2, 3], enc.eval().shape)
def testBowEncoderSparseEmptyRow(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3], [0, 0]]
enc = encoders.bow_encoder(docs, 4, 5)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([3, 5], enc.eval().shape)
def testBowEncoderDense(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3], [0, 0], [0, 0]]
enc = encoders.bow_encoder(docs, 4, 3, sparse_lookup=False)
sess.run(variables.global_variables_initializer())
self.assertAllEqual([4, 3], enc.eval().shape)
def testBowEncoderSparseTensorDenseLookup(self):
- with self.test_session():
+ with self.cached_session():
docs = [[0, 1]]
sparse_docs = sparse_ops.dense_to_sparse_tensor(docs)
with self.assertRaises(TypeError):
encoders.bow_encoder(sparse_docs, 4, 3, sparse_lookup=False)
def testBowEncodersSharingEmbeddings(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc_1 = encoders.bow_encoder(docs, 4, 3, scope='test')
enc_2 = encoders.bow_encoder(docs, 4, 3, scope='test', reuse=True)
@@ -79,7 +79,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncodersSharingEmbeddingsInheritedScopes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
with variable_scope.variable_scope('test'):
enc_1 = encoders.bow_encoder(docs, 4, 3)
@@ -90,7 +90,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncodersSharingEmbeddingsSharedScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[0, 1], [2, 3]]
enc_1 = encoders.bow_encoder(docs, 4, 3, scope='bow')
variable_scope.get_variable_scope().reuse_variables()
@@ -100,7 +100,7 @@ class EncodersTest(test.TestCase):
self.assertAllEqual(avg_1, avg_2)
def testBowEncoderReuseEmbeddingsVariable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[1, 1], [2, 3]]
with variable_scope.variable_scope('test'):
v = _get_const_var('embeddings', (4, 3),
@@ -111,7 +111,7 @@ class EncodersTest(test.TestCase):
self.assertAllClose([[3., 4., 5.], [7.5, 8.5, 9.5]], enc.eval())
def testEmbedSequence(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
docs = [[1, 1], [2, 3]]
with variable_scope.variable_scope('test'):
v = _get_const_var('embeddings', (4, 3),
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 28d19a0445..53c8ae5d08 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -1100,9 +1100,9 @@ class _EmbeddingColumn(
raise ValueError("Must specify both `ckpt_to_load_from` and "
"`tensor_name_in_ckpt` or none of them.")
if initializer is None:
- logging.warn("The default stddev value of initializer will change from "
- "\"1/sqrt(vocab_size)\" to \"1/sqrt(dimension)\" after "
- "2017/02/25.")
+ logging.warn("The default stddev value of initializer was changed from "
+ "\"1/sqrt(vocab_size)\" to \"1/sqrt(dimension)\" in core "
+ "implementation (tf.feature_column.embedding_column).")
stddev = 1 / math.sqrt(sparse_id_column.length)
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=stddev)
@@ -1501,8 +1501,6 @@ class _ScatteredEmbeddingColumn(
raise ValueError("initializer must be callable if specified. "
"column_name: {}".format(column_name))
if initializer is None:
- logging.warn("The default stddev value of initializer will change from "
- "\"0.1\" to \"1/sqrt(dimension)\" after 2017/02/25.")
stddev = 0.1
initializer = init_ops.truncated_normal_initializer(
mean=0.0, stddev=stddev)
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
index e6bbd86ab7..6fb4b9ff35 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py
@@ -49,7 +49,7 @@ class TransformerTest(test.TestCase):
real_valued = feature_column.real_valued_column("price")
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops._Transformer(features).transform(real_valued)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.eval(), [[20.], [110], [-3]])
def testSparseRealValuedColumnIdentityTransformation(self):
@@ -60,7 +60,7 @@ class TransformerTest(test.TestCase):
features = {"rating": rating_tensor}
output = feature_column_ops._Transformer(features).transform(
sparse_real_valued)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.values.eval(), rating_tensor.values.eval())
self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
@@ -80,7 +80,7 @@ class TransformerTest(test.TestCase):
[sparse_real_valued])
self.assertTrue(sparse_real_valued in output_dict)
output = output_dict[sparse_real_valued]
- with self.test_session():
+ with self.cached_session():
self.assertArrayNear(output.values.eval(), [4.0, 25.0], 1e-5)
self.assertAllEqual(output.indices.eval(), rating_tensor.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
@@ -97,7 +97,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[bucket])
self.assertEqual(len(output), 1)
self.assertIn(bucket, output)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output[bucket].eval(), [[2], [3], [0]])
def testBucketizedColumnWithMultiDimensions(self):
@@ -109,7 +109,7 @@ class TransformerTest(test.TestCase):
"price": constant_op.constant([[20., 110], [110., 20], [-3, -3]])
}
output = feature_column_ops._Transformer(features).transform(bucket)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output.eval(), [[2, 3], [3, 2], [0, 0]])
def testCachedTransformation(self):
@@ -118,7 +118,7 @@ class TransformerTest(test.TestCase):
# buckets 2, 3, 0
features = {"price": constant_op.constant([[20.], [110], [-3]])}
transformer = feature_column_ops._Transformer(features)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
transformer.transform(bucket)
num_of_ops = len(sess.graph.get_operations())
# Verify that the second call to transform the same feature
@@ -138,7 +138,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -161,7 +161,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -177,7 +177,7 @@ class TransformerTest(test.TestCase):
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
- with self.test_session():
+ with self.cached_session():
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.values.dtype, dtypes.int64)
@@ -203,7 +203,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 2)
self.assertIn(hashed_sparse, output)
self.assertIn(wire_embedding, output)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(output[wire_embedding].indices.eval(),
wire_tensor.indices.eval())
self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2])
@@ -223,7 +223,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[keys_sparse])
self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
@@ -241,7 +241,7 @@ class TransformerTest(test.TestCase):
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(keys_sparse)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
@@ -264,7 +264,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
self.assertIn(hashed_sparse, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[hashed_sparse].values.dtype, dtypes.int32)
self.assertTrue(
all(x < 10 and x >= 0 for x in output[hashed_sparse].values.eval()))
@@ -282,7 +282,7 @@ class TransformerTest(test.TestCase):
wire_tensor = constant_op.constant([[100, 0], [1, 25]])
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(hashed_sparse)
- with self.test_session():
+ with self.cached_session():
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.values.dtype, dtypes.int32)
@@ -310,7 +310,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(weighted_ids, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval())
@@ -340,7 +340,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -362,7 +362,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -386,7 +386,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
@@ -408,7 +408,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[vocab_sparse])
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
@@ -440,7 +440,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[country_language])
self.assertEqual(len(output), 1)
self.assertIn(country_language, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[country_language].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[country_language].values.eval(
@@ -467,7 +467,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[country_price])
self.assertEqual(len(output), 1)
self.assertIn(country_price, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[country_price].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[country_price].values.eval()))
@@ -498,7 +498,7 @@ class TransformerTest(test.TestCase):
weights = column_to_variable[country_price][0]
grad = array_ops.squeeze(
gradients_impl.gradients(output, weights)[0].values)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertEqual(len(grad.eval()), 6)
@@ -537,7 +537,7 @@ class TransformerTest(test.TestCase):
features=features, feature_columns=[wire_country_price])
self.assertEqual(len(output), 1)
self.assertIn(wire_country_price, output)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(output[wire_country_price].values.dtype, dtypes.int64)
self.assertTrue(
all(x < 15 and x >= 0 for x in output[wire_country_price].values.eval(
@@ -600,7 +600,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
columns = [one_hot_column, embedding_column, real_valued_column]
output = feature_column_ops.input_from_feature_columns(features, columns)
output_core = fc_core.input_layer(features, columns)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
@@ -626,7 +626,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
cols_to_outs = {}
feature_column_ops.input_from_feature_columns(
features, columns, cols_to_outs=cols_to_outs)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
for column in columns:
@@ -637,7 +637,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -650,7 +650,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -662,7 +662,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
rating = np.array([[0., 1., 2., -1.],
[3., 4., 5., 6.]])
features = {"rating": constant_op.constant(rating)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(feature_column_ops.input_from_feature_columns(
features, [var_len_real_valued]))
self.assertAllClose(rating, output)
@@ -673,7 +673,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
rating = np.array([[0, 1, 2, -1],
[3, 4, 5, 6]])
features = {"rating": constant_op.constant(rating, dtype=dtypes.int64)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(feature_column_ops.input_from_feature_columns(
features, [var_len_real_valued]))
self.assertAllClose(rating.astype(np.float32), output)
@@ -684,7 +684,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -698,7 +698,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
}
output = feature_column_ops.input_from_feature_columns(features,
[real_valued])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
@@ -713,7 +713,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features = {"price": constant_op.constant([[20.], [110], [-3]])}
output = feature_column_ops.input_from_feature_columns(features, [bucket])
expected = [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), expected)
self.assertAllClose(output.eval(),
fc_core.input_layer(features, [bucket]).eval())
@@ -729,7 +729,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features, [bucket])
expected = [[0, 0, 1, 0, 0, 0, 0, 1], [0, 0, 0, 1, 0, 0, 1, 0],
[1, 0, 0, 0, 1, 0, 0, 0]]
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(output.eval(), expected)
self.assertAllClose(output.eval(),
fc_core.input_layer(features, [bucket]).eval())
@@ -752,7 +752,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_column])
output_core = fc_core.input_layer(features, [one_hot_column])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
@@ -773,7 +773,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
@@ -794,7 +794,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
@@ -816,7 +816,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
@@ -834,7 +834,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[one_hot_sparse])
output_core = fc_core.input_layer(features, [one_hot_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
@@ -852,7 +852,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [4, 10])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -878,7 +878,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
features, [embedded_sparse], weight_collections=["my_collection_core"])
weights_core = ops.get_collection("my_collection_core")
grad_core = gradients_impl.gradients(output_core, weights_core)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
gradient_values = []
gradient_values_core = []
@@ -907,7 +907,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
output_eval = output.eval()
self.assertAllEqual(output_eval.shape, [2, 10])
@@ -935,7 +935,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
# Makes sure that trying to use different initializers with the same
# embedding column explicitly fails.
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
"Duplicate feature column key found for column: wire_embedding"):
@@ -961,7 +961,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
output_core = fc_core.input_layer(features, [embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -986,7 +986,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1005,7 +1005,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
embeded_sparse = feature_column.embedding_column(crossed, 10)
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
@@ -1016,7 +1016,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"wire": wire_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating input layer for column: wire"):
variables_lib.global_variables_initializer().run()
@@ -1035,7 +1035,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"ids": ids_tensor, "weights": weights_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError,
"Error creating input layer for column: ids_weighted_by_weights"):
@@ -1053,7 +1053,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"aaa": wire_tensor, "bbb": wire_tensor}
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating input layer for column: aaa_X_bbb"):
variables_lib.global_variables_initializer().run()
@@ -1080,7 +1080,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
hashed_sparse, 10, initializer=init_ops.constant_initializer(133.7))
output = feature_column_ops.input_from_feature_columns(
features, [real_valued, bucket, embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# size of output = 3 (real_valued) + 2 * 4 (bucket) + 10 (embedding) = 21
self.assertAllEqual(output.eval().shape, [3, 21])
@@ -1099,7 +1099,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.ones_initializer())
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# score: (number of values)
self.assertAllEqual(output.eval(), [[1.], [2.], [0.]])
@@ -1119,7 +1119,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
max_norm=0.5)
output = feature_column_ops.input_from_feature_columns(features,
[embedded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
# score: (number of values * 0.5)
self.assertAllClose(output.eval(), [[0.5], [1.], [0.]])
@@ -1144,7 +1144,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
initializer=init_ops.ones_initializer())
output = feature_column_ops.input_from_feature_columns(features,
[embeded_sparse])
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# score: (sum of weights)
@@ -1236,7 +1236,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
# There should be one trainable variables for sparse_2
self.assertEqual(1, len(variables_lib.trainable_variables()))
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
output_1_eval = output_1.eval()
output_2_eval = output_2.eval()
@@ -1295,7 +1295,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [measurement_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(measurement_input, model_inputs)
@@ -1305,7 +1305,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
rating = np.array([[0., 1., 2., -1.],
[3., 4., 5., 6.]])
features = {"rating": constant_op.constant(rating)}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output = sess.run(
feature_column_ops.sequence_input_from_feature_columns(
features, [var_len_real_valued]))
@@ -1329,7 +1329,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
reshaped_measurements = np.reshape(measurement_input, expected_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(reshaped_measurements, model_inputs)
@@ -1350,7 +1350,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [measurement_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(normalizer(measurement_input), model_inputs)
@@ -1373,7 +1373,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
expected_shape = [batch_size, sequence_length, np.prod(dimensions)]
reshaped_measurements = np.reshape(measurement_input, expected_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
model_inputs = sess.run(model_input_tensor)
self.assertAllClose(normalizer(reshaped_measurements), model_inputs)
@@ -1395,7 +1395,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [one_hot_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1429,7 +1429,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [one_hot_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1459,7 +1459,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1488,7 +1488,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
model_input_tensor = feature_column_ops.sequence_input_from_feature_columns(
columns_to_tensors, [embedded_column])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1518,7 +1518,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights = ops.get_collection("my_collection")
gradient_tensor = gradients_impl.gradients(model_input_tensor,
embedding_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
@@ -1585,7 +1585,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
columns_to_tensors, model_input_columns)
self.assertEqual(dtypes.float32, model_input_tensor.dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
@@ -1622,7 +1622,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1640,7 +1640,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1654,7 +1654,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [hashed_sparse], num_outputs=5)
logits_core = fc_core.linear_model(features, [hashed_sparse], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1676,7 +1676,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_ids], num_outputs=5)
logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1695,7 +1695,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5)
logits_core = fc_core.linear_model(features, [weighted_ids], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
@@ -1716,7 +1716,7 @@ class WeightedSumTest(test.TestCase):
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [crossed], num_outputs=5)
logits_core = fc_core.linear_model(features, [crossed], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1730,7 +1730,7 @@ class WeightedSumTest(test.TestCase):
dense_shape=[2, 2])
features = {"wire": wire_tensor}
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(
ValueError, "Error creating weighted sum for column: wire_embedding"):
variables_lib.global_variables_initializer().run()
@@ -1756,7 +1756,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
logits_core = fc_core.linear_model(features, [movies])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.initialize_all_variables().run()
lookup_ops.tables_initializer().run()
@@ -1776,7 +1776,7 @@ class WeightedSumTest(test.TestCase):
}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [real_valued], num_outputs=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [3, 5])
@@ -1789,7 +1789,7 @@ class WeightedSumTest(test.TestCase):
}
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket], num_outputs=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(logits.eval().shape, [3, 5])
@@ -1814,7 +1814,7 @@ class WeightedSumTest(test.TestCase):
features, [real_valued, bucket, hashed_sparse, crossed], num_outputs=5)
output_core = fc_core.linear_model(
features, [real_valued, bucket, hashed_sparse, crossed], units=5)
- with self.test_session():
+ with self.cached_session():
variables_lib.global_variables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 5])
# Verify cross compatibility: Core builder output should equal to contrib.
@@ -1837,7 +1837,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [age, language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1877,7 +1877,7 @@ class WeightedSumTest(test.TestCase):
features, [country, language], num_outputs=1))
# Assert that only a single weight is created.
self.assertEqual(len(variables), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1941,7 +1941,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [weighted_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1969,7 +1969,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, bias = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -1992,7 +1992,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [movies], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2026,7 +2026,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2050,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [language_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2083,7 +2083,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2124,7 +2124,7 @@ class WeightedSumTest(test.TestCase):
features, [country, language, country_language],
num_outputs=1,
scope=scope))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2161,7 +2161,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country, age, incomes], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2197,7 +2197,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country, age, height, incomes], num_outputs=5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2228,7 +2228,7 @@ class WeightedSumTest(test.TestCase):
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket], num_outputs=1))
output_core = fc_core.linear_model(features, [bucket])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# Cross compatibility: Core builder output should equal to contrib.
@@ -2259,7 +2259,7 @@ class WeightedSumTest(test.TestCase):
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket, country], num_outputs=1))
output_core = fc_core.linear_model(features, [bucket, country])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
# Cross compatibility: Core builder output should equal to contrib.
@@ -2290,7 +2290,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [bucket, country], num_outputs=5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2326,7 +2326,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_price], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2365,7 +2365,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [country_language_price], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2389,7 +2389,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2404,7 +2404,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2419,7 +2419,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2440,7 +2440,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [product], num_outputs=1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
@@ -2452,7 +2452,7 @@ class WeightedSumTest(test.TestCase):
features = {"age": constant_op.constant([[10.], [20.], [30.], [40.]])}
output, _, bias = feature_column_ops.weighted_sum_from_feature_columns(
features, [feature_column.real_valued_column("age")], num_outputs=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3]))
@@ -2466,7 +2466,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
@@ -2490,7 +2490,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
@@ -2516,7 +2516,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2556,7 +2556,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2585,7 +2585,7 @@ class WeightedSumTest(test.TestCase):
output, column_to_variable, _ = (
feature_column_ops.weighted_sum_from_feature_columns(
features, [column], num_outputs=3))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
lookup_ops.tables_initializer().run()
@@ -2651,7 +2651,7 @@ class ParseExampleTest(test.TestCase):
feature_columns=[bucket, wire_cast])
self.assertIn(bucket, output)
self.assertIn(wire_cast, output)
- with self.test_session():
+ with self.cached_session():
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
@@ -2713,7 +2713,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn("measurements", seq)
self.assertIsInstance(seq["measurements"], ops.Tensor)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
location_val, wire_cast_val, measurement_val = sess.run(
[ctx["location"], seq["wire_cast"], seq["measurements"]])
diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py
index eaaf9f8d5f..d90d6ecf7f 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column_test.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py
@@ -201,7 +201,7 @@ class FeatureColumnTest(test.TestCase):
b2 = feature_column_ops.input_from_feature_columns({
b[1]: input_tensor_c2
}, [b[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
b1_value = b1.eval()
b2_value = b2.eval()
@@ -230,7 +230,7 @@ class FeatureColumnTest(test.TestCase):
e1 = feature_column_ops.input_from_feature_columns({
e[0]: input_tensor_c1
}, [e[0]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
d1_value = d1.eval()
e1_value = e1.eval()
@@ -340,7 +340,7 @@ class FeatureColumnTest(test.TestCase):
with variable_scope.variable_scope("output_rank_{}".format(output_rank)):
one_hot_output = one_hot._to_dnn_input_layer(
id_tensor, output_rank=output_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
expected_shape = (id_tensor_shape[:output_rank - 1] + [vocab_size])
self.assertEquals(expected_shape, list(one_hot_value.shape))
@@ -376,7 +376,7 @@ class FeatureColumnTest(test.TestCase):
one_hot_output_shape = one_hot_output.get_shape().as_list()
expected_shape = id_tensor_shape[:-1] + [vocab_size]
self.assertEquals(expected_shape, one_hot_output_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
self.assertEquals(expected_shape, list(one_hot_value.shape))
@@ -399,7 +399,7 @@ class FeatureColumnTest(test.TestCase):
expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0.,
0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
one_hot_value = sess.run(one_hot_output)
self.assertTrue(np.array_equal(one_hot_value, expected))
@@ -440,7 +440,7 @@ class FeatureColumnTest(test.TestCase):
}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
features, [one_hot])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
self.assertAllEqual([[2., 6., 0.]], one_hot_tensor.eval())
@@ -451,7 +451,7 @@ class FeatureColumnTest(test.TestCase):
features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])}
one_hot_tensor = feature_column_ops.input_from_feature_columns(
features, [one_hot])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(lookup_ops.tables_initializer())
self.assertAllEqual([[1., 1., 0.]], one_hot_tensor.eval())
@@ -603,7 +603,7 @@ class FeatureColumnTest(test.TestCase):
real_valued_output = real_valued_column._to_dnn_input_layer(
constant_op.constant(real_valued_input, dtype=dtypes.float32),
output_rank=output_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
real_valued_eval = sess.run(real_valued_output)
expected_shape = (
input_shape[:output_rank - 1] +
@@ -797,7 +797,7 @@ class FeatureColumnTest(test.TestCase):
sparse_column.insert_transformed_feature(features)
sparse_output = features[sparse_column]
expected_shape = [batch_size, 1]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_result = sess.run(sparse_output)
self.assertEquals(expected_shape, list(sparse_result.dense_shape))
@@ -1110,7 +1110,7 @@ class FeatureColumnTest(test.TestCase):
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
saved_embedding = embeddings.eval()
save.save(sess, checkpoint_path)
@@ -1131,7 +1131,7 @@ class FeatureColumnTest(test.TestCase):
embedding_col_initialized: input_tensor
}, [embedding_col_initialized])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loaded_embedding = pretrained_embeddings.eval()
@@ -1176,7 +1176,7 @@ class FeatureColumnTest(test.TestCase):
ckpt_dir = tempfile.mkdtemp(prefix=ckpt_dir_prefix)
checkpoint_path = os.path.join(ckpt_dir, "model.ckpt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(assign_op)
saved_col_weights = col_weights[crossed_col][0].eval()
@@ -1201,7 +1201,7 @@ class FeatureColumnTest(test.TestCase):
}, [crossed_col_initialized], 1))
col_weights_from_ckpt = col_weights[crossed_col_initialized][0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
loaded_col_weights = col_weights_from_ckpt.eval()
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 04668f112d..a82d4c1951 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -3109,7 +3109,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
inputs: Tensor input
num_units: Specifies how many features will remain after maxout
in the `axis` dimension (usually channel).
- This must be multiple of number of `axis`.
+ This must be a factor of number of features.
axis: The dimension where max pooling will be performed. Default is the
last dimension.
scope: Optional scope for variable_scope.
@@ -3128,7 +3128,7 @@ def maxout(inputs, num_units, axis=-1, scope=None):
raise ValueError('number of features({}) is not '
'a multiple of num_units({})'.format(
num_channels, num_units))
- shape[axis] = -1
+ shape[axis] = num_units
shape += [num_channels // num_units]
# Dealing with batches with arbitrary sizes
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index eee90864b4..85af9de4e4 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -281,7 +281,7 @@ class BiasAddTest(test.TestCase):
def testCreate(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.bias_add(images)
self.assertEqual(output.op.name, 'BiasAdd/BiasAdd')
@@ -289,7 +289,7 @@ class BiasAddTest(test.TestCase):
def testCreateWithActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.bias_add(images, activation_fn=nn_ops.relu)
self.assertEqual(output.op.name, 'BiasAdd/Relu')
@@ -298,7 +298,7 @@ class BiasAddTest(test.TestCase):
def testCreateDimensions(self):
dims = (2, 3, 4)
shape = [5, 2, 3, 4]
- with self.test_session():
+ with self.cached_session():
for d in dims:
input_shape = shape[:d]
inputs = random_ops.random_uniform(input_shape, seed=1)
@@ -311,7 +311,7 @@ class BiasAddTest(test.TestCase):
class ConvolutionTest(test.TestCase):
def testInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
images_2d = random_ops.random_uniform((5, 7, 9, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'Convolution expects input with rank 5, got 4'):
@@ -323,14 +323,14 @@ class ConvolutionTest(test.TestCase):
def testInvalidDataFormat(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(ValueError, 'data_format'):
layers_lib.convolution2d(images, 32, 3, data_format='CHWN')
def testCreateConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 4)).astype(np.float32)
output = layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -342,7 +342,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvNCHW(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, 4, height, width)).astype(np.float32)
output = layers_lib.convolution2d(images, 32, [3, 3], data_format='NCHW')
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -354,7 +354,7 @@ class ConvolutionTest(test.TestCase):
def testCreateSquareConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -362,7 +362,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithTensorShape(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, images.get_shape()[1:3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -370,7 +370,7 @@ class ConvolutionTest(test.TestCase):
def testCreateFullyConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
output = layers_lib.convolution2d(
images, 64, images.get_shape()[1:3], padding='VALID')
@@ -381,7 +381,7 @@ class ConvolutionTest(test.TestCase):
def testFullyConvWithCustomGetter(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
called = [0]
def custom_getter(getter, *args, **kwargs):
@@ -395,7 +395,7 @@ class ConvolutionTest(test.TestCase):
def testCreateVerticalConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 4), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 1])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -407,7 +407,7 @@ class ConvolutionTest(test.TestCase):
def testCreateHorizontalConv(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 4), seed=1)
output = layers_lib.convolution2d(images, 32, [1, 3])
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -417,7 +417,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithStride(self):
height, width = 6, 8
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], stride=2)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -427,7 +427,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 7, 9
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
@@ -436,7 +436,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithScope(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
@@ -453,14 +453,14 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithoutActivation(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], activation_fn=None)
self.assertEqual(output.op.name, 'Conv/BiasAdd')
def testCreateConvValid(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.convolution2d(images, 32, [3, 3], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), [5, 5, 7, 32])
@@ -468,7 +468,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvWithWD(self):
height, width = 7, 9
weight_decay = 0.01
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(weight_decay)
layers_lib.convolution2d(
@@ -481,7 +481,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvNoRegularizers(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(
@@ -489,7 +489,7 @@ class ConvolutionTest(test.TestCase):
def testReuseVars(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3], scope='conv1')
self.assertEqual(len(variables.get_variables()), 2)
@@ -498,7 +498,7 @@ class ConvolutionTest(test.TestCase):
def testNonReuseVars(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
layers_lib.convolution2d(images, 32, [3, 3])
self.assertEqual(len(variables.get_variables()), 2)
@@ -507,7 +507,7 @@ class ConvolutionTest(test.TestCase):
def testReuseConvWithWD(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
with arg_scope(
@@ -523,7 +523,7 @@ class ConvolutionTest(test.TestCase):
def testConvWithBatchNorm(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
with arg_scope(
[layers_lib.convolution2d],
@@ -539,7 +539,7 @@ class ConvolutionTest(test.TestCase):
def testReuseConvWithBatchNorm(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 32), seed=1)
with arg_scope(
[layers_lib.convolution2d],
@@ -557,7 +557,7 @@ class ConvolutionTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVarsWithRateTwo(self):
height, width = 7, 9
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
layers_lib.convolution2d(images, 32, [3, 3], rate=2, scope='conv1')
@@ -573,7 +573,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -587,7 +587,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -601,7 +601,7 @@ class ConvolutionTest(test.TestCase):
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=[2, 3], padding='VALID')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEquals(output.op.name, 'Conv/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -612,7 +612,7 @@ class ConvolutionTest(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 7, 9, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.convolution2d(
@@ -651,7 +651,7 @@ class ConvolutionTest(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 5, 7, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.convolution2d(
@@ -670,7 +670,7 @@ class ConvolutionTest(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.convolution2d(
images, num_filters, [3, 3], rate=2, padding='VALID', scope='conv7')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -688,7 +688,7 @@ class ConvolutionTest(test.TestCase):
padding='VALID',
activation_fn=None,
scope='conv7')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'conv7/BiasAdd')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -712,7 +712,7 @@ class Convolution2dTransposeTests(test.TestCase):
def testInvalidDataFormat(self):
height, width = 7, 9
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with self.assertRaisesRegexp(
ValueError, 'data_format has to be either NCHW or NHWC.'):
@@ -915,7 +915,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=1, padding='SAME')
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -929,7 +929,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=1, padding='VALID')
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -944,7 +944,7 @@ class Convolution2dTransposeTests(test.TestCase):
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -958,7 +958,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
self.assertListEqual(list(output.get_shape().as_list()), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -971,7 +971,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -984,7 +984,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='SAME')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -997,7 +997,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 2], stride=[2, 2], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1010,7 +1010,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 1], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1023,7 +1023,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 4], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1036,7 +1036,7 @@ class Convolution2dTransposeTests(test.TestCase):
images = random_ops.random_uniform(input_size, seed=1)
output = layers_lib.conv2d_transpose(
images, num_filters, [2, 4], stride=[2, 5], padding='VALID')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1083,7 +1083,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=[2, 2], padding='VALID')
self.assertListEqual(output.get_shape().as_list(), expected_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertEqual(output.op.name, 'Conv2d_transpose/Relu')
eval_output = output.eval({images: np.zeros(input_size, np.float32)})
@@ -1095,7 +1095,7 @@ class Convolution2dTransposeTests(test.TestCase):
expected_size = [None, None, None, num_filters]
expected_size_dynamic = [5, 18, 22, num_filters]
- with self.test_session():
+ with self.cached_session():
images = array_ops.placeholder(np.float32,
[None, None, None, input_size[3]])
output = layers_lib.conv2d_transpose(
@@ -1116,7 +1116,7 @@ class Convolution2dTransposeTests(test.TestCase):
images, num_filters, [3, 3], stride=2, padding='VALID', scope='conv7')
self.assertEqual(output.op.name, 'conv7/Relu')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1135,7 +1135,7 @@ class Convolution2dTransposeTests(test.TestCase):
scope='conv7')
self.assertEqual(output.op.name, 'conv7/BiasAdd')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertListEqual(list(output.eval().shape), expected_size)
@@ -1146,7 +1146,7 @@ class Convolution2dTransposeTests(test.TestCase):
stride = 2
padding = 'VALID'
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(input_size, seed=1)
output_deconv = layers_lib.conv2d_transpose(
images,
@@ -1184,7 +1184,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
expected = np.zeros((1, 10, 9, 1))
@@ -1201,7 +1201,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(
horz_gradients, feed_dict={
@@ -1225,7 +1225,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1245,7 +1245,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1267,7 +1267,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(horz_gradients)
@@ -1283,12 +1283,12 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(vert_gradients)
expected = np.zeros((1, 9, 10, 1))
- self.assertAllEqual(result, expected)
+ self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
def testVertConvWithVaryingImage(self):
image = np.asmatrix(('1.0 2.0 3.0;' '1.1 2.0 4.0;' '-4.3 0.0 8.9'))
@@ -1306,7 +1306,7 @@ class ConvolutionInPlaneTest(test.TestCase):
activation_fn=None)
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
result = sess.run(vert_gradients)
@@ -1314,7 +1314,7 @@ class ConvolutionInPlaneTest(test.TestCase):
def testConv1dShape(self):
width = 7
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, width, 3), seed=1)
output = layers_lib.convolution1d(images, 32, 3)
self.assertEqual(output.op.name, 'Conv/Relu')
@@ -1322,7 +1322,7 @@ class ConvolutionInPlaneTest(test.TestCase):
def testConvInferSpatialDims(self):
depth, height, width = 7, 9, 11
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, width, 4)).astype(np.float32)
output = layers_lib.convolution(images, 32, [3])
self.assertListEqual(output.get_shape().as_list(), [5, width, 32])
@@ -1344,7 +1344,7 @@ class DenseToSparseTest(test.TestCase):
sparse = _layers.dense_to_sparse(tensor)
dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape,
sparse.values)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
constant = sess.run(dense)
self.assertAllEqual(expected_constant, constant)
@@ -1353,7 +1353,7 @@ class DropoutTest(test.TestCase):
def testCreateDropout(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.dropout(images)
self.assertEqual(output.op.name, 'Dropout/dropout_1/mul')
@@ -1362,7 +1362,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithConstantTrue(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = constant_op.constant(True)
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1370,7 +1370,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithConstantFalse(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = constant_op.constant(False)
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1378,7 +1378,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutWithPlaceholder(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
is_training = array_ops.placeholder(dtype=dtypes.bool, shape=[])
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, is_training=is_training)
@@ -1387,7 +1387,7 @@ class DropoutTest(test.TestCase):
def testCollectOutputs(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = _layers.dropout(images, outputs_collections='outputs')
c_output = ops.get_collection('outputs')[0]
@@ -1396,7 +1396,7 @@ class DropoutTest(test.TestCase):
def testDropout(self):
height, width = 10, 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1409,7 +1409,7 @@ class DropoutTest(test.TestCase):
def testDropoutSeed(self):
"""Test that providing the same seed produces the same result."""
height, width = 10, 10
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output1 = _layers.dropout(images, seed=1)
@@ -1418,7 +1418,7 @@ class DropoutTest(test.TestCase):
def testCreateDropoutNoTraining(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
num_elem_initial = math_ops.reduce_mean(math_ops.to_float(images > 0))
@@ -1431,7 +1431,7 @@ class DropoutTest(test.TestCase):
def testCreateFCFollowByDropout(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.fully_connected(images, 50)
@@ -1445,7 +1445,7 @@ class DropoutTest(test.TestCase):
def testCreateFCWithDropout(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.fully_connected(
@@ -1475,7 +1475,7 @@ class FlattenTest(test.TestCase):
def testCollectOutputs(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.flatten(images, outputs_collections='outputs')
c_output = ops.get_collection('outputs')[0]
@@ -1484,7 +1484,7 @@ class FlattenTest(test.TestCase):
def testFlatten4D(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.flatten(images)
@@ -1494,7 +1494,7 @@ class FlattenTest(test.TestCase):
def testFlatten3D(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width), seed=1, name='images')
output = _layers.flatten(images)
@@ -1504,7 +1504,7 @@ class FlattenTest(test.TestCase):
def testFlattenBatchSize(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
inputs = array_ops.placeholder(dtypes.int32, (None, height, width, 3))
@@ -1516,7 +1516,7 @@ class FlattenTest(test.TestCase):
def testUnknownDims(self):
height = width = depth = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform(
(5, height, width, depth), seed=1, name='images')
inputs = array_ops.placeholder(dtypes.int32, (None, None, None, None))
@@ -1551,7 +1551,7 @@ class PartialFlattenTest(test.TestCase):
flattened_t = _layers._inner_flatten(inputs, new_rank)
static_shape = flattened_t.get_shape().as_list()
self.assertEqual(static_shape, expected_new_shape)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
flattened = sess.run(flattened_t)
np.testing.assert_array_equal(expected_flattened, flattened)
@@ -1571,7 +1571,7 @@ class PartialFlattenTest(test.TestCase):
flattened_t = _layers._inner_flatten(inputs_t, new_rank)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
flattened = sess.run(flattened_t)
np.testing.assert_array_equal(expected_indices, flattened.indices)
@@ -1641,7 +1641,7 @@ class FCTest(test.TestCase):
def testCreateFCWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
output = _layers.fully_connected(inputs, 32, scope='fc1')
self.assertEqual(output.op.name, 'fc1/Relu')
@@ -1659,7 +1659,7 @@ class FCTest(test.TestCase):
def testCreateFcCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('fc1/weights'))
self.assertFalse(variables.get_variables('fc1/biases'))
_layers.fully_connected(inputs, 32, scope='fc1')
@@ -1669,7 +1669,7 @@ class FCTest(test.TestCase):
def testReuseVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
_layers.fully_connected(inputs, 32, scope='fc1')
self.assertEqual(len(variables.get_variables('fc1')), 2)
_layers.fully_connected(inputs, 32, scope='fc1', reuse=True)
@@ -1678,7 +1678,7 @@ class FCTest(test.TestCase):
def testNonReuseVars(self):
height, width = 3, 3
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
- with self.test_session():
+ with self.cached_session():
_layers.fully_connected(inputs, 32)
self.assertEqual(len(variables.get_variables('fully_connected')), 2)
_layers.fully_connected(inputs, 32)
@@ -1713,14 +1713,14 @@ class FCTest(test.TestCase):
def testCreateFCWithoutActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
output = _layers.fully_connected(inputs, 32, activation_fn=None)
self.assertEqual(output.op.name, 'fully_connected/BiasAdd')
def testCreateFCWithWD(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(inputs, 32, weights_regularizer=weight_decay)
@@ -1732,7 +1732,7 @@ class FCTest(test.TestCase):
def testCreateFCWithBD(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
bias_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(inputs, 32, biases_regularizer=bias_decay)
@@ -1744,7 +1744,7 @@ class FCTest(test.TestCase):
def testCreateNoRegularizers(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
_layers.fully_connected(inputs, 32)
self.assertEqual(
@@ -1752,7 +1752,7 @@ class FCTest(test.TestCase):
def testReuseFCWithWD(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
inputs = random_ops.random_uniform((5, height * width * 3), seed=1)
weight_decay = regularizers.l2_regularizer(0.01)
_layers.fully_connected(
@@ -1768,7 +1768,7 @@ class FCTest(test.TestCase):
def testFCWithBatchNorm(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height * width * 3), seed=1)
with arg_scope(
[_layers.fully_connected],
@@ -1786,7 +1786,7 @@ class FCTest(test.TestCase):
def testReuseFCWithBatchNorm(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height * width * 3), seed=1)
with arg_scope(
[_layers.fully_connected],
@@ -1844,7 +1844,7 @@ class BatchNormTest(test.TestCase):
if dtype is None:
dtype = dtypes.float32
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(
dtype.as_numpy_dtype)
output = _layers.batch_norm(images, fused=fused)
@@ -1866,7 +1866,7 @@ class BatchNormTest(test.TestCase):
def _testCreateOpBetaRegularizer(self, fused=True):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused)
@@ -1883,7 +1883,7 @@ class BatchNormTest(test.TestCase):
def _testCreateOpGammaRegularizer(self, fused=True):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
reg = lambda x: 0.1 * math_ops.reduce_sum(x)
images = np.random.uniform(size=(5, height, width, 3)).astype('f')
_layers.batch_norm(
@@ -1901,7 +1901,7 @@ class BatchNormTest(test.TestCase):
def testCreateVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True)
beta = variables.get_variables_by_name('beta')[0]
@@ -1915,7 +1915,7 @@ class BatchNormTest(test.TestCase):
def testMovingAverageVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True)
self.assertEqual(len(variables.get_model_variables()), 4)
@@ -1926,7 +1926,7 @@ class BatchNormTest(test.TestCase):
def testMovingAverageVariablesZeroDebias(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(
images, scale=True, zero_debias_moving_mean=True, fused=False)
@@ -1943,7 +1943,7 @@ class BatchNormTest(test.TestCase):
def testUpdatesCollection(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, updates_collections='my_update_ops')
update_layers = ops.get_collection('my_update_ops')
@@ -1971,7 +1971,7 @@ class BatchNormTest(test.TestCase):
def testReuseVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.batch_norm(images, scale=True, scope='bn')
_layers.batch_norm(images, scale=True, scope='bn', reuse=True)
@@ -1986,7 +1986,7 @@ class BatchNormTest(test.TestCase):
def testReuseUpdateOps(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
with arg_scope([_layers.batch_norm], updates_collections='update_ops'):
_layers.batch_norm(images, scope='bn')
@@ -1996,7 +1996,7 @@ class BatchNormTest(test.TestCase):
def testCreateMovingVars(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_ = _layers.batch_norm(images)
moving_mean = variables.get_variables('BatchNorm/moving_mean')
@@ -2029,7 +2029,7 @@ class BatchNormTest(test.TestCase):
moving_variance = variables.get_variables_by_name('moving_variance')[0]
biased = variables.get_variables_by_name('biased')[0]
local_step = variables.get_variables_by_name('local_step')[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
self.assertAllClose(local_step.eval(), 0)
self.assertAllClose(moving_mean.eval(), [0] * channels)
@@ -2213,7 +2213,7 @@ class BatchNormTest(test.TestCase):
def _testEvalMovingVars(self, zero_debias_moving_mean=False):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2264,7 +2264,7 @@ class BatchNormTest(test.TestCase):
height, width = 3, 3
batch_size = 10
channels = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (batch_size, height, width, channels)
image_values = np.random.rand(*image_shape)
expected_mean = np.mean(image_values, axis=(0, 1, 2))
@@ -2435,7 +2435,7 @@ class BatchNormTest(test.TestCase):
def testNoUpdatesWhenIsTrainingFalse(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2460,7 +2460,7 @@ class BatchNormTest(test.TestCase):
def testNoneUpdatesCollectionNoTraining(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2647,7 +2647,7 @@ class BatchNormTest(test.TestCase):
def testCustomInitializer(self):
height, width = 3, 3
channels = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = (np.ones((5, height, width, channels)) * 9.0).astype('f')
beta = init_ops.constant_initializer(
(np.ones(channels) * 5.0).astype('f'))
@@ -2728,7 +2728,7 @@ class BatchNormTest(test.TestCase):
def testBatchNormBeta(self):
# Test case for 11673
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a_32 = array_ops.placeholder(dtypes.float32, shape=(10, 10, 10, 10))
_layers.batch_norm(
a_32, center=False, data_format='NCHW', zero_debias_moving_mean=True)
@@ -2739,7 +2739,7 @@ class BatchNormTest(test.TestCase):
def testVariablesAreFloat32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float16)
_layers.batch_norm(images, scale=True)
@@ -2824,7 +2824,7 @@ class LayerNormTest(test.TestCase):
def testCreateOp(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3))
output = _layers.layer_norm(images)
self.assertTrue(output.op.name.startswith('LayerNorm/batchnorm'))
@@ -2832,7 +2832,7 @@ class LayerNormTest(test.TestCase):
def testCreateVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.layer_norm(images)
beta = variables.get_variables_by_name('beta')[0]
@@ -2842,7 +2842,7 @@ class LayerNormTest(test.TestCase):
def testReuseVariables(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
_layers.layer_norm(images, scope='ln')
_layers.layer_norm(images, scope='ln', reuse=True)
@@ -2853,7 +2853,7 @@ class LayerNormTest(test.TestCase):
def testReuseVars(self):
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
image_shape = (10, height, width, 3)
image_values = np.random.rand(*image_shape)
images = constant_op.constant(
@@ -2940,7 +2940,7 @@ class GDNTest(test.TestCase):
def _runGDN(self, x, shape, inverse, data_format):
inputs = array_ops.placeholder(dtypes.float32, shape)
outputs = _layers.gdn(inputs, inverse=inverse, data_format=data_format)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
y, = sess.run([outputs], {inputs: x})
return y
@@ -3152,14 +3152,14 @@ class MaxPool3DTest(test.TestCase):
class OneHotEncodingTest(test.TestCase):
def testOneHotEncodingCreate(self):
- with self.test_session():
+ with self.cached_session():
labels = np.array([0, 1, 2])
output = _layers.one_hot_encoding(labels, num_classes=3)
self.assertEqual(output.op.name, 'OneHotEncoding/one_hot')
self.assertListEqual(output.get_shape().as_list(), [3, 3])
def testCollectOutputs(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2])
output = _layers.one_hot_encoding(
labels, num_classes=3, outputs_collections='outputs')
@@ -3168,14 +3168,14 @@ class OneHotEncodingTest(test.TestCase):
self.assertEqual(c_output, output)
def testOneHotEncoding(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2])
one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
output = _layers.one_hot_encoding(labels, num_classes=3)
self.assertAllClose(output.eval(), one_hot_labels.eval())
def testOneHotEncodingInt32(self):
- with self.test_session():
+ with self.cached_session():
labels = constant_op.constant([0, 1, 2], dtype=dtypes.int32)
one_hot_labels = constant_op.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
output = _layers.one_hot_encoding(labels, num_classes=3)
@@ -3186,7 +3186,7 @@ class RepeatTests(test.TestCase):
def testRepeat(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height, width, 3)).astype(np.float32)
output = _layers.repeat(images, 3, layers_lib.conv2d, 32, [3, 3])
self.assertEqual(output.op.name, 'Repeat/convolution2d_3/Relu')
@@ -3194,7 +3194,7 @@ class RepeatTests(test.TestCase):
def testRepeatWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.repeat(
@@ -3207,7 +3207,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvInt32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.int32, maxval=12345)
with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
@@ -3215,7 +3215,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvFloat32(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, dtype=dtypes.float32)
output = layers_lib.separable_conv2d(images, 32, [3, 3], 2)
@@ -3224,7 +3224,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConv(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(images, None, [3, 3], 2)
self.assertEqual(output.op.name, 'SeparableConv2d/Relu')
@@ -3233,7 +3233,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3245,7 +3245,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3257,7 +3257,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConvCreatesWeightsAndBiasesVars(self):
height, width = 3, 3
images = random_ops.random_uniform((5, height, width, 3), seed=1)
- with self.test_session():
+ with self.cached_session():
self.assertFalse(variables.get_variables('conv1/depthwise_weights'))
self.assertFalse(variables.get_variables('conv1/pointwise_weights'))
self.assertFalse(variables.get_variables('conv1/biases'))
@@ -3268,14 +3268,14 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(images, 32, [3, 3], 6, scope='conv1')
self.assertEqual(output.op.name, 'conv1/Relu')
def testCreateConvWithoutActivation(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 8, activation_fn=None)
@@ -3283,7 +3283,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvValid(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 2, padding='VALID')
@@ -3291,7 +3291,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousConvValid(self):
height, width = 5, 5
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, 32, [3, 3], 2, padding='VALID', rate=2)
@@ -3299,7 +3299,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateDepthwiseConvValid(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, None, [3, 3], 2, padding='VALID')
@@ -3307,7 +3307,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateAtrousDepthwiseConvValid(self):
height, width = 5, 5
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
output = layers_lib.separable_conv2d(
images, None, [3, 3], 2, padding='VALID', rate=2)
@@ -3316,7 +3316,7 @@ class SeparableConv2dTest(test.TestCase):
def testCreateConvWithWeightDecay(self):
random_seed.set_random_seed(0)
height, width = 3, 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
@@ -3360,7 +3360,7 @@ class SeparableConv2dTest(test.TestCase):
def testReuseConvWithWeightDecay(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform((5, height, width, 3), seed=1)
regularizer = regularizers.l2_regularizer(0.01)
layers_lib.separable_conv2d(
@@ -3419,7 +3419,7 @@ class SeparableConv2dTest(test.TestCase):
normalizer_params={},
scope='conv1')
init_op = variables_lib.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
images = np.random.rand(5, height, width, 3)
sess.run(init_op)
sess.run(net, feed_dict={images_placeholder: images})
@@ -3440,7 +3440,7 @@ class SeparableConv2dTest(test.TestCase):
def testSepConvNCHW(self):
for num_filters, correct_output_filters in zip((None, 5), (6, 5)):
- with self.test_session():
+ with self.cached_session():
batch, height, width = 4, 10, 12
kernel_dim, stride = 3, 2
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
@@ -3462,7 +3462,7 @@ class ScaleGradientTests(test.TestCase):
"""Simple tests of the scale_gradient function."""
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
x = np.array([42], np.float32)
gradient_scale = np.array([2], np.float32)
@@ -3513,7 +3513,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction = np.array([[self.low, self.high], [0.5, 0.5],
[self.high, self.low]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction)
self.assertAllClose(exp_prediction, prediction)
@@ -3529,7 +3529,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction[1, 1, 1] = self.low
prediction = _layers.softmax(logits)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction)
self.assertAllClose(exp_prediction, prediction)
@@ -3547,7 +3547,7 @@ class SoftmaxTests(test.TestCase):
exp_prediction[1, 1, 1] = self.low
prediction = _layers.softmax(logit_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
prediction = sess.run(prediction, feed_dict=feed_dict)
self.assertAllClose(exp_prediction, prediction)
@@ -3575,7 +3575,7 @@ class SpatialSoftmaxTests(test.TestCase):
features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
np_features = np.zeros(batch_shape, dtype=np.float32)
spatial_softmax = _layers.spatial_softmax(features)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3586,7 +3586,7 @@ class SpatialSoftmaxTests(test.TestCase):
features = array_ops.placeholder(dtypes.float32, shape=batch_shape)
np_features = np.zeros(batch_shape, dtype=np.float32)
spatial_softmax = _layers.spatial_softmax(features, data_format='NCHW')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3613,7 +3613,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3637,7 +3637,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3669,7 +3669,7 @@ class SpatialSoftmaxTests(test.TestCase):
batch_size, nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features1}
tf_keypoints1 = sess.run(spatial_softmax, feed_dict)
@@ -3696,7 +3696,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3719,7 +3719,7 @@ class SpatialSoftmaxTests(test.TestCase):
nchannels)
# Make sure expected location keypoints matches actual location keypoints.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
keypoints = sess.run(spatial_softmax, feed_dict)
@@ -3731,7 +3731,7 @@ class SpatialSoftmaxTests(test.TestCase):
spatial_softmax = _layers.spatial_softmax(features)
net = _layers.fully_connected(spatial_softmax, 10)
np_features = np.zeros(batch_shape, dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
feed_dict = {features: np_features}
sess.run(net, feed_dict)
@@ -3741,7 +3741,7 @@ class StackTests(test.TestCase):
def testStackFullyConnected(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = np.random.uniform(size=(5, height * width * 3))
output = _layers.stack(images, _layers.fully_connected, [10, 20, 30])
self.assertEqual(output.op.name, 'Stack/fully_connected_3/Relu')
@@ -3749,7 +3749,7 @@ class StackTests(test.TestCase):
def testStackFullyConnectedFailOnReuse(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope('test', reuse=True):
images = np.random.uniform(size=(5, height * width * 3))
with self.assertRaises(ValueError):
@@ -3757,7 +3757,7 @@ class StackTests(test.TestCase):
def testStackRelu(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height * width * 3), seed=1, name='images')
output = _layers.stack(images, layers_lib.relu, [10, 20, 30])
@@ -3766,7 +3766,7 @@ class StackTests(test.TestCase):
def testStackElu(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height * width * 3), seed=1, name='images')
output = _layers.stack(images, layers_lib.elu, [10, 20, 30])
@@ -3775,7 +3775,7 @@ class StackTests(test.TestCase):
def testStackConvolution2d(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.stack(
@@ -3788,7 +3788,7 @@ class StackTests(test.TestCase):
def testStackWithScope(self):
height, width = 3, 3
- with self.test_session():
+ with self.cached_session():
images = random_ops.random_uniform(
(5, height, width, 3), seed=1, name='images')
output = _layers.stack(
@@ -3817,7 +3817,7 @@ class UnitNormTests(test.TestCase):
del shape[dim]
expected = np.ones(shape)
- with self.test_session():
+ with self.cached_session():
actual = norms.eval()
self.assertAllClose(expected, actual, 1e-4, 1e-4)
@@ -3849,7 +3849,7 @@ class UnitNormTests(test.TestCase):
norms = math_ops.sqrt(
math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim))
- with self.test_session():
+ with self.cached_session():
actual = norms.eval({image: placeholder_value})
self.assertAllClose(expected, actual, 1e-4, 1e-4)
@@ -3875,7 +3875,7 @@ class PoincareNormalizeTest(test.TestCase):
x_np = np.random.random_sample(x_shape).astype(np.float32)
for dim in range(len(x_shape)):
y_np = self._PoincareNormalize(x_np, dim, epsilon)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
y_tf_eval = y_tf.eval()
@@ -3893,7 +3893,7 @@ class PoincareNormalizeTest(test.TestCase):
x_np = np.random.random_sample(x_shape).astype(np.float32)
dim = [1, 2]
y_np = self._PoincareNormalize(x_np, dim, epsilon)
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim, epsilon)
y_tf_eval = y_tf.eval()
@@ -3908,7 +3908,7 @@ class PoincareNormalizeTest(test.TestCase):
np.random.seed(1)
x_np = np.random.random_sample(x_shape).astype(np.float64)
for dim in range(len(x_shape)):
- with self.test_session():
+ with self.cached_session():
x_tf = constant_op.constant(x_np, name='x')
y_tf = _layers.poincare_normalize(x_tf, dim)
err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf,
@@ -4117,7 +4117,7 @@ class LegacyFullyConnectedTest(test.TestCase):
# Empty x is common if someone masks their input with tf.boolean_mask in
# order to drop missing entries, and in a particular batch all entries are
# missing.
- with self.test_session():
+ with self.cached_session():
x = np.array([]).reshape(0, 3)
self.assertEqual(0, array_ops.size(x).eval())
y = _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax)
@@ -4131,7 +4131,7 @@ class LegacyFullyConnectedTest(test.TestCase):
y = _layers.legacy_fully_connected(x, 1)
# in the output we still only know the 2nd and 3rd dimensions statically.
self.assertEqual(y.get_shape().as_list(), [None, 4, 1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables_lib.global_variables_initializer().run()
# we can feed in input with first dimension 2
shape_value = sess.run(
@@ -4162,7 +4162,7 @@ class LegacyFullyConnectedTest(test.TestCase):
self._unknown_dim_invalid_input(last_dim=None)
def test_1d_invalid_input(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError,
'rank of x must be at least 2 not: 1'):
x = constant_op.constant([[]], shape=[0])
diff --git a/tensorflow/contrib/layers/python/layers/normalization_test.py b/tensorflow/contrib/layers/python/layers/normalization_test.py
index 55272e5fd1..c8d3c91b10 100644
--- a/tensorflow/contrib/layers/python/layers/normalization_test.py
+++ b/tensorflow/contrib/layers/python/layers/normalization_test.py
@@ -106,7 +106,7 @@ class InstanceNormTest(test.TestCase):
images = random_ops.random_uniform(image_shape, seed=1)
output_train = normalization.instance_norm(images, scope='IN')
output_eval = normalization.instance_norm(images, scope='IN', reuse=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
@@ -130,7 +130,7 @@ class InstanceNormTest(test.TestCase):
inputs = random_ops.random_uniform(input_shape, seed=0) * sigma + mu
output_op = normalization.instance_norm(
inputs, center=False, scale=False, data_format=data_format)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
# Make sure that there are no NaNs
@@ -287,7 +287,7 @@ class GroupNormTest(test.TestCase):
output_train = normalization.group_norm(images, groups=2, scope='IN')
output_eval = normalization.group_norm(images, groups=2, scope='IN',
reuse=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
# output_train and output_eval should be the same.
train_np, eval_np = sess.run([output_train, output_eval])
@@ -349,7 +349,7 @@ class GroupNormTest(test.TestCase):
channels_axis=channels_axis,
reduction_axes=reduction_axes,
mean_close_to_zero=mean_close_to_zero)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = sess.run(output_op)
# Make sure that there are no NaNs
diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py
index 0f037e24ad..29dede2a49 100644
--- a/tensorflow/contrib/layers/python/layers/optimizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py
@@ -165,7 +165,7 @@ class OptimizersTest(test.TestCase):
def testGradientNoise(self):
random_seed.set_random_seed(42)
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -182,7 +182,7 @@ class OptimizersTest(test.TestCase):
def testGradientNoiseWithClipping(self):
random_seed.set_random_seed(42)
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -198,7 +198,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(global_step_value, 1)
def testGradientClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -213,7 +213,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(global_step_value, 1)
def testAdaptiveGradientClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
clip_gradients = optimizers_lib.adaptive_clipping_fn()
train = optimizers_lib.optimize_loss(
@@ -234,7 +234,7 @@ class OptimizersTest(test.TestCase):
self.assertEqual(2, var_count)
def testGradientMultiply(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
x, var, loss, global_step = _setup_model()
train = optimizers_lib.optimize_loss(
loss,
@@ -433,7 +433,7 @@ class OptimizersTest(test.TestCase):
class AdaptiveClipping(test.TestCase):
def testAverages(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
scale = 2.
grad = array_ops.ones([3, 4]) * scale
log_norm = np.log(np.sqrt(scale**2 * grad.get_shape().num_elements()))
@@ -463,7 +463,7 @@ class AdaptiveClipping(test.TestCase):
self.assertAlmostEqual(float(sq_mean), log_norm**2, places=4)
def testClip(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
spike = 1000.
multiplier = array_ops.placeholder(dtypes.float32, [], "multiplier")
step = array_ops.placeholder(dtypes.int32, [], "step")
diff --git a/tensorflow/contrib/layers/python/layers/regularizers_test.py b/tensorflow/contrib/layers/python/layers/regularizers_test.py
index 07191eeda7..51faba30c7 100644
--- a/tensorflow/contrib/layers/python/layers/regularizers_test.py
+++ b/tensorflow/contrib/layers/python/layers/regularizers_test.py
@@ -71,7 +71,7 @@ class RegularizerTest(test.TestCase):
with self.assertRaises(ValueError):
regularizers.l1_l2_regularizer(0.5, 0)
- with self.test_session():
+ with self.cached_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
@@ -84,7 +84,7 @@ class RegularizerTest(test.TestCase):
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(0.0, 1.0)(tensor)
- with self.test_session():
+ with self.cached_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
@@ -93,7 +93,7 @@ class RegularizerTest(test.TestCase):
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
loss = regularizers.l1_l2_regularizer(1.0, 0.0)(tensor)
- with self.test_session():
+ with self.cached_session():
self.assertEquals(loss.op.name, 'l1_l2_regularizer')
self.assertAlmostEqual(loss.eval(), num_elem, 5)
@@ -104,7 +104,7 @@ class RegularizerTest(test.TestCase):
self.assertEquals(loss, None)
def testL1L2RegularizerWithScope(self):
- with self.test_session():
+ with self.cached_session():
shape = [5, 5, 5]
num_elem = 5 * 5 * 5
tensor = constant_op.constant(1.0, shape=shape)
@@ -142,7 +142,7 @@ class RegularizerTest(test.TestCase):
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
expected = sum([2 * x for l in array_weights_list for x in l])
- with self.test_session():
+ with self.cached_session():
result = regularizers.apply_regularization(dummy_regularizer,
tensor_weights_list)
self.assertAllClose(expected, result.eval())
@@ -151,7 +151,7 @@ class RegularizerTest(test.TestCase):
regularizer = regularizers.l2_regularizer(0.0)
array_weights_list = [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
tensor_weights_list = [constant_op.constant(x) for x in array_weights_list]
- with self.test_session():
+ with self.cached_session():
result = regularizers.apply_regularization(regularizer,
tensor_weights_list)
self.assertAllClose(0.0, result.eval())
@@ -161,7 +161,7 @@ class RegularizerTest(test.TestCase):
tensor_weights_list = [
constant_op.constant(x) for x in [[1.5], [2, 3, 4.2], [10, 42, 666.6]]
]
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
regularizers.apply_regularization(non_scalar_regularizer,
tensor_weights_list)
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
index b25f11b5a6..06da32072f 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py
@@ -30,6 +30,7 @@ import functools
import re
import numpy as np
+import six
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python import ops as contrib_framework_ops
@@ -44,6 +45,7 @@ from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -471,7 +473,8 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
Args:
fn: a function that takes Tensors (all as positional arguments) and returns
- a tuple of Tensors.
+ a tuple of Tensors. Note that `fn` should not close over any other
+ Tensors or Variables.
use_data_dep: `bool`, if `True` will use a dummy data dependency to force
the recompute to happen. If `False` will use a control dependency. By
default will be `True` if in an XLA context and `False` otherwise. XLA
@@ -485,7 +488,22 @@ def recompute_grad(fn, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
A wrapped fn that is identical to fn when called, but its activations will
be discarded and recomputed on the backwards pass (i.e. on a call to
tf.gradients).
+
+ Raises:
+ ValueError: if `fn` closes over any Tensors or Variables.
"""
+ # Check for closed-over Tensors/Variables
+ if fn.__code__.co_freevars:
+ closed_over_vars = dict(zip(fn.__code__.co_freevars,
+ [c.cell_contents for c in fn.__closure__]))
+ for var_name, value in six.iteritems(closed_over_vars):
+ if isinstance(value, (framework_ops.Tensor, variables_lib.Variable)):
+ raise ValueError(
+ "fn decorated with @recompute_grad closes over Tensor %s "
+ "(local variable name: %s). The decorated fn must not close over "
+ "Tensors or Variables because gradients will NOT be computed for "
+ "them through fn. To ensure correct gradients, make the "
+ "Tensor an input to fn." % (value.name, var_name))
@_safe_wraps(fn)
def wrapped(*args):
@@ -500,6 +518,62 @@ def _is_on_tpu():
return control_flow_util.GetContainingXLAContext(ctxt) is not None
+def _recomputing_grad_fn(compute_fn,
+ original_args,
+ original_vars,
+ output_grads,
+ grad_fn_variables,
+ use_data_dep,
+ tupleize_grads,
+ arg_scope,
+ var_scope,
+ has_is_recompute_kwarg):
+ """Grad fn for recompute_grad."""
+ variables = grad_fn_variables or []
+
+ # Identity ops around the inputs ensures correct gradient graph-walking.
+ inputs = [array_ops.identity(x) for x in list(original_args)]
+
+ # Recompute outputs
+ # Use a control dependency to ensure that the recompute is not eliminated by
+ # CSE and that it happens on the backwards pass.
+ ctrl_dep_grads = [g for g in output_grads if g is not None]
+ with framework_ops.control_dependencies(ctrl_dep_grads):
+ if use_data_dep:
+ inputs = _force_data_dependency(output_grads, inputs)
+ # Re-enter scopes
+ with contrib_framework_ops.arg_scope(arg_scope):
+ with variable_scope.variable_scope(var_scope, reuse=True):
+ # Re-call the function and ensure that the touched variables are the
+ # same as in the first call.
+ with backprop.GradientTape() as tape:
+ fn_kwargs = {}
+ if has_is_recompute_kwarg:
+ fn_kwargs["is_recomputing"] = True
+ outputs = compute_fn(*inputs, **fn_kwargs)
+ recompute_vars = set(tape.watched_variables())
+ if original_vars != recompute_vars:
+ raise ValueError(_WRONG_VARS_ERR)
+
+ if not isinstance(outputs, (list, tuple)):
+ outputs = [outputs]
+ outputs = list(outputs)
+
+ # Compute gradients
+ grads = gradients_impl.gradients(outputs, inputs + variables,
+ output_grads)
+
+ if tupleize_grads:
+ if use_data_dep:
+ grads = _tuple_with_data_dep(grads)
+ else:
+ grads = control_flow_ops.tuple(grads)
+
+ grad_inputs = grads[:len(inputs)]
+ grad_vars = grads[len(inputs):]
+ return grad_inputs, grad_vars
+
+
def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
"""See recompute_grad."""
has_is_recompute_kwarg = "is_recomputing" in tf_inspect.getargspec(fn).args
@@ -510,12 +584,16 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
if use_data_dep_ == _USE_DEFAULT:
use_data_dep_ = _is_on_tpu()
+ # Use custom_gradient and return a grad_fn that recomputes on the backwards
+ # pass.
@custom_gradient.custom_gradient
def fn_with_recompute(*args):
"""Wrapper for fn."""
- # Forward pass
+ # Capture the variable and arg scopes so we can re-enter them when
+ # recomputing.
vs = variable_scope.get_variable_scope()
arg_scope = contrib_framework_ops.current_arg_scope()
+ # Track all variables touched in the function.
with backprop.GradientTape() as tape:
fn_kwargs = {}
if has_is_recompute_kwarg:
@@ -523,46 +601,25 @@ def _recompute_grad(fn, args, use_data_dep=_USE_DEFAULT, tupleize_grads=False):
outputs = fn(*args, **fn_kwargs)
original_vars = set(tape.watched_variables())
- # Backward pass
def _grad_fn(output_grads, variables=None):
- """Recompute outputs for gradient computation."""
- variables = variables or []
+ # Validate that custom_gradient passes the right variables into grad_fn.
if original_vars:
assert variables, ("Fn created variables but the variables were not "
"passed to the gradient fn.")
if set(variables) != original_vars:
raise ValueError(_WRONG_VARS_ERR)
- inputs = [array_ops.identity(x) for x in list(args)]
- # Recompute outputs
- with framework_ops.control_dependencies(output_grads):
- if use_data_dep_:
- inputs = _force_data_dependency(output_grads, inputs)
- with contrib_framework_ops.arg_scope(arg_scope):
- with variable_scope.variable_scope(vs, reuse=True):
- with backprop.GradientTape() as tape:
- fn_kwargs = {}
- if has_is_recompute_kwarg:
- fn_kwargs["is_recomputing"] = True
- outputs = fn(*inputs, **fn_kwargs)
- recompute_vars = set(tape.watched_variables())
- if original_vars != recompute_vars:
- raise ValueError(_WRONG_VARS_ERR)
-
- if not isinstance(outputs, (list, tuple)):
- outputs = [outputs]
- outputs = list(outputs)
- grads = gradients_impl.gradients(outputs, inputs + variables,
- output_grads)
-
- if tupleize_grads:
- if use_data_dep_:
- grads = _tuple_with_data_dep(grads)
- else:
- grads = control_flow_ops.tuple(grads)
- grad_inputs = grads[:len(inputs)]
- grad_vars = grads[len(inputs):]
- return grad_inputs, grad_vars
+ return _recomputing_grad_fn(
+ compute_fn=fn,
+ original_args=args,
+ original_vars=original_vars,
+ output_grads=output_grads,
+ grad_fn_variables=variables,
+ use_data_dep=use_data_dep_,
+ tupleize_grads=tupleize_grads,
+ arg_scope=arg_scope,
+ var_scope=vs,
+ has_is_recompute_kwarg=has_is_recompute_kwarg)
# custom_gradient inspects the signature of the function to determine
# whether the user expects variables passed in the grad_fn. If the function
diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
index d5971fb9d8..2c7463acc0 100644
--- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
+++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py
@@ -58,7 +58,7 @@ class RevBlockTest(test.TestCase):
y1, y2 = block.forward(x1, x2)
x1_inv, x2_inv = block.backward(y1, y2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
x1, x2, x1_inv, x2_inv = sess.run([x1, x2, x1_inv, x2_inv])
@@ -81,7 +81,7 @@ class RevBlockTest(test.TestCase):
x1, x2 = block.backward(y1, y2)
y1_inv, y2_inv = block.forward(x1, x2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y1, y2, y1_inv, y2_inv = sess.run([y1, y2, y1_inv, y2_inv])
@@ -151,7 +151,7 @@ class RevBlockTest(test.TestCase):
grads_rev = gradients_impl.gradients(loss_rev, wrt)
grads = gradients_impl.gradients(loss, wrt)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
y_val, yd_val, gd_val, g_val = sess.run([y, y_rev, grads_rev, grads])
self.assertAllClose(y_val, yd_val)
@@ -286,7 +286,7 @@ class RecomputeTest(test.TestCase):
for out, scope_vars in outputs_and_vars:
all_grads.append(gradients_impl.gradients(out, scope_vars))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
outputs = list(zip(*outputs_and_vars))[0]
outs, all_grads_val = sess.run([outputs, all_grads])
@@ -389,9 +389,19 @@ class RecomputeTest(test.TestCase):
layer_list.append(math_ops.sqrt(concat_n_wrap(*layer_list)))
grads = gradients_impl.gradients(layer_list[-1], layer_list[0])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(grads)
+ def testErrorOnClosedOverTensor(self):
+ x = random_ops.random_uniform((4, 8))
+ y = random_ops.random_uniform((4, 8))
+ z = x * y
+
+ with self.assertRaisesWithPredicateMatch(ValueError, "closes over"):
+ @rev_block_lib.recompute_grad
+ def fn_with_capture(a): # pylint: disable=unused-variable
+ return a * z
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/layers/python/layers/summaries_test.py b/tensorflow/contrib/layers/python/layers/summaries_test.py
index a1ef06feec..2ec2af9d44 100644
--- a/tensorflow/contrib/layers/python/layers/summaries_test.py
+++ b/tensorflow/contrib/layers/python/layers/summaries_test.py
@@ -29,19 +29,19 @@ from tensorflow.python.platform import test
class SummariesTest(test.TestCase):
def test_summarize_scalar_tensor(self):
- with self.test_session():
+ with self.cached_session():
scalar_var = variables.Variable(1)
summary_op = summaries_lib.summarize_tensor(scalar_var)
self.assertEquals(summary_op.op.type, 'ScalarSummary')
def test_summarize_multidim_tensor(self):
- with self.test_session():
+ with self.cached_session():
tensor_var = variables.Variable([1, 2, 3])
summary_op = summaries_lib.summarize_tensor(tensor_var)
self.assertEquals(summary_op.op.type, 'HistogramSummary')
def test_summarize_activation(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = array_ops.identity(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -52,7 +52,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_activation_relu(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = nn_ops.relu(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -64,7 +64,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_activation_relu6(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
op = nn_ops.relu6(var, name='SummaryTest')
summary_op = summaries_lib.summarize_activation(op)
@@ -77,7 +77,7 @@ class SummariesTest(test.TestCase):
self.assertIn(u'SummaryTest/activation', names)
def test_summarize_collection_regex(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable(1)
array_ops.identity(var, name='Test1')
ops.add_to_collection('foo', array_ops.identity(var, name='Test2'))
diff --git a/tensorflow/contrib/layers/python/layers/utils_test.py b/tensorflow/contrib/layers/python/layers/utils_test.py
index a9bd89532a..34f63f5d86 100644
--- a/tensorflow/contrib/layers/python/layers/utils_test.py
+++ b/tensorflow/contrib/layers/python/layers/utils_test.py
@@ -42,7 +42,7 @@ class ConstantValueTest(test.TestCase):
c = constant_op.constant(v)
value = utils.constant_value(c)
self.assertEqual(value, v)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(c.eval(), v)
def test_variable(self):
@@ -60,7 +60,7 @@ class ConstantValueTest(test.TestCase):
x = array_ops.identity(p)
value = utils.constant_value(p)
self.assertEqual(value, None)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(x.eval(feed_dict={p: v}), v)
@@ -80,7 +80,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
def test_variable(self):
@@ -89,7 +89,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(), expected(v))
@@ -99,7 +99,7 @@ class StaticCondTest(test.TestCase):
expected = lambda v: -1 if v else -2
for v in [True, False, 1, 0]:
o = utils.static_cond(v, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
@@ -119,7 +119,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
def test_variable(self):
@@ -128,7 +128,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: b'fn1' if v else b'fn2'
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(), expected(v))
@@ -138,7 +138,7 @@ class SmartCondStaticTest(test.TestCase):
expected = lambda v: -1 if v else -2
for v in [True, False, 1, 0]:
o = utils.smart_cond(constant_op.constant(v), fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(), expected(v))
@@ -151,7 +151,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_constant(self):
@@ -161,7 +161,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
def test_variable(self):
@@ -171,7 +171,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
@@ -182,7 +182,7 @@ class SmartCondDynamicTest(test.TestCase):
p = array_ops.placeholder(dtypes.bool, [])
for v in [True, False, 1, 0]:
o = utils.smart_cond(p, fn1, fn2)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(o.eval(feed_dict={p: v}), expected(v))
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
index d50750001e..b6c2cab64a 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops_test.py
@@ -42,7 +42,7 @@ def _assert_sparse_tensor_value(test_case, expected, actual):
class DenseToSparseTensorTest(test.TestCase):
def test_dense_to_sparse_tensor_1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([1, 0, 2, 0])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -53,7 +53,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_float(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([1.5, 0.0, 2.3, 0.0])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -64,7 +64,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_bool(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([True, False, True, False])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -75,7 +75,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([b'qwe', b'', b'ewq', b''])
result = sess.run(st)
self.assertEqual(result.indices.dtype, np.int64)
@@ -86,7 +86,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_1d_str_special_ignore(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor(
[b'qwe', b'', b'ewq', b''], ignore_value=b'qwe')
result = sess.run(st)
@@ -98,7 +98,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([[1, 2, 0, 0], [3, 4, 5, 0]])
result = sess.run(st)
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -107,7 +107,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_3d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor([[[1, 2, 0, 0], [3, 4, 5, 0]],
[[7, 8, 0, 0], [9, 0, 0, 0]]])
result = sess.run(st)
@@ -117,7 +117,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([2, 2, 4], result.dense_shape)
def test_dense_to_sparse_tensor_unknown_1d_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = array_ops.placeholder(shape=[None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
result = sess.run(st, feed_dict={tensor: [0, 100, 0, 3]})
@@ -126,7 +126,7 @@ class DenseToSparseTensorTest(test.TestCase):
self.assertAllEqual([4], result.dense_shape)
def test_dense_to_sparse_tensor_unknown_3d_shape(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
tensor = array_ops.placeholder(
shape=[None, None, None], dtype=dtypes.int32)
st = sparse_ops.dense_to_sparse_tensor(tensor)
@@ -142,7 +142,7 @@ class DenseToSparseTensorTest(test.TestCase):
def test_dense_to_sparse_unknown_rank(self):
ph = array_ops.placeholder(dtype=dtypes.int32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
st = sparse_ops.dense_to_sparse_tensor(ph)
result = sess.run(st, feed_dict={ph: [[1, 2, 0, 0], [3, 4, 5, 0]]})
self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]],
@@ -155,7 +155,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope(self):
expected_sparse_row_envelope = [1, 0, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
values=[0, 1, 2, 3],
@@ -167,7 +167,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_unsorted_indices(self):
expected_sparse_row_envelope = [1, 0, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[2, 0], [2, 2], [2, 1], [0, 0]],
values=[0, 1, 2, 3],
@@ -179,7 +179,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_empty_in_the_end(self):
expected_sparse_row_envelope = [1, 0, 3, 0, 0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1], [2, 2]],
values=[0, 1, 2, 3],
@@ -191,7 +191,7 @@ class SparseRowEnvelopeTest(test.TestCase):
def test_sparse_row_envelope_empty_3d(self):
expected_sparse_row_envelope = [1, 0, 3, 0, 0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sparse_input = sparse_tensor.SparseTensor(
indices=[[0, 0, 0], [0, 2, 0], [0, 2, 1], [0, 2, 2]],
values=[0, 1, 2, 3],
@@ -207,7 +207,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
def test_indicators_to_sparse_ids_1d(self):
indicators = (0, 0, 1, 0)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0,),),
values=(2,),
@@ -220,7 +220,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
(1, 0, 0, 1),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(2, 0, 3),
@@ -235,7 +235,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
((1, 0, 0, 1, 1), (0, 0, 1, 0, 0)),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=(
(0, 0, 0),
@@ -255,7 +255,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, dtype=dtypes.int16)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=np.array((2, 0, 3), dtype=np.int16),
@@ -269,7 +269,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, ignore_value=-1)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -282,7 +282,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
(('B', '', '', 'C'), ('', '', 'D', '')),
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -296,7 +296,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
sparse_ids = sparse_ops.indicators_to_sparse_ids(
indicators, ignore_value='x')
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -311,7 +311,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
indicators = array_ops.placeholder(
dtype=dtypes.int32, shape=(None, None, None))
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
@@ -325,7 +325,7 @@ class IndicatorToSparseIdsTest(test.TestCase):
)
indicators = array_ops.placeholder(dtype=dtypes.int32)
sparse_ids = sparse_ops.indicators_to_sparse_ids(indicators)
- with self.test_session():
+ with self.cached_session():
_assert_sparse_tensor_value(self, sparse_tensor.SparseTensorValue(
indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
values=(2, 0, 3, 2),
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 418b0cf392..61185f65a9 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -403,6 +403,7 @@ py_test(
srcs = ["python/learn/estimators/dnn_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
+ tags = ["notap"],
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
index 5e07b9313f..284a4f45f6 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py
@@ -147,7 +147,7 @@ class DataFeederTest(test.TestCase):
def test_unsupervised(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
inp, _ = feeder.input_builder()
feed_dict_fn = feeder.get_feed_dict_fn()
feed_dict = feed_dict_fn()
@@ -181,7 +181,7 @@ class DataFeederTest(test.TestCase):
def test_epoch(self):
def func(feeder):
- with self.test_session():
+ with self.cached_session():
feeder.input_builder()
epoch = feeder.make_epoch_variable()
feed_dict_fn = feeder.get_feed_dict_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
index 7e81f2b7d9..5e90d1fa20 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py
@@ -38,7 +38,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -68,7 +68,7 @@ class GeneratorIoTest(test.TestCase):
for index in range(2):
yield {'a': np.ones(1) * index}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -97,7 +97,7 @@ class GeneratorIoTest(test.TestCase):
'label2': np.ones(1) * index - 64,
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key=['label', 'label2'],
@@ -134,7 +134,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones((3, 3)) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator,
target_key='label',
@@ -162,7 +162,7 @@ class GeneratorIoTest(test.TestCase):
def testGeneratorInputFnWithXAsNonGeneratorFunction(self):
x = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
failing_input_fn = generator_io.generator_input_fn(
x, batch_size=2, shuffle=False, num_epochs=1)
@@ -173,7 +173,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
return np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -184,7 +184,7 @@ class GeneratorIoTest(test.TestCase):
def generator():
yield np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
@@ -201,7 +201,7 @@ class GeneratorIoTest(test.TestCase):
}
y = np.arange(32, 36)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -219,7 +219,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', np.arange(10)]
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
@@ -237,7 +237,7 @@ class GeneratorIoTest(test.TestCase):
}
y = ['label', 'target']
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
@@ -253,7 +253,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
@@ -283,7 +283,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
features = input_fn()
@@ -319,7 +319,7 @@ class GeneratorIoTest(test.TestCase):
'label': np.ones(1) * index - 32
}
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()
diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
index c738f0e8f3..396539a76a 100644
--- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
+++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io_test.py
@@ -65,7 +65,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesExpectedOutputs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -79,7 +79,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 102)
a = np.arange(2)
b = np.arange(32, 34)
@@ -107,7 +107,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
index = np.arange(100, 105)
a = np.arange(5)
b = np.arange(32, 37)
@@ -146,7 +146,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_OnlyX(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, _ = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y=None, batch_size=2, shuffle=False, num_epochs=1)
@@ -159,7 +159,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_ExcludesIndex(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=False, num_epochs=1)
@@ -182,7 +182,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_NoShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=False, num_epochs=1)
@@ -192,7 +192,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffle(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=4, shuffle=True, num_epochs=1)
@@ -202,7 +202,7 @@ class PandasIoTest(test.TestCase):
def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):
if not HAS_PANDAS:
return
- with self.test_session() as session:
+ with self.cached_session() as session:
x, y = self.makeTestDataFrame()
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)
@@ -213,7 +213,7 @@ class PandasIoTest(test.TestCase):
if not HAS_PANDAS:
return
x, y = self.makeTestDataFrame()
- with self.test_session() as session:
+ with self.cached_session() as session:
input_fn = pandas_io.pandas_input_fn(
x, y, batch_size=3, shuffle=False, num_epochs=1)
diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
index 80d4923db3..ff190110c1 100644
--- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py
@@ -33,7 +33,7 @@ class OpsTest(test.TestCase):
"""Ops tests."""
def test_softmax_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
features = array_ops.placeholder(dtypes.float32, [None, 3])
labels = array_ops.placeholder(dtypes.float32, [None, 2])
weights = constant_op.constant([[0.1, 0.1], [0.1, 0.1], [0.1, 0.1]])
@@ -52,7 +52,7 @@ class OpsTest(test.TestCase):
ids_shape = (2, 3, 4)
embeds = np.random.randn(n_embed, d_embed)
ids = np.random.randint(0, n_embed, ids_shape)
- with self.test_session():
+ with self.cached_session():
embed_np = embeds[ids]
embed_tf = ops.embedding_lookup(embeds, ids).eval()
self.assertEqual(embed_np.shape, embed_tf.shape)
@@ -60,7 +60,7 @@ class OpsTest(test.TestCase):
def test_categorical_variable(self):
random_seed.set_random_seed(42)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cat_var_idx = array_ops.placeholder(dtypes.int64, [2, 2])
embeddings = ops.categorical_variable(
cat_var_idx, n_classes=5, embedding_size=10, name="my_cat_var")
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
index 95aec61955..5a7e4ebfea 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops_test.py
@@ -31,7 +31,7 @@ class Seq2SeqOpsTest(test.TestCase):
"""Sequence-to-sequence tests."""
def test_sequence_classifier(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
decoding = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
@@ -60,7 +60,7 @@ class Seq2SeqOpsTest(test.TestCase):
def test_seq2seq_inputs(self):
inp = np.array([[[1, 0], [0, 1], [1, 0]], [[0, 1], [1, 0], [0, 1]]])
out = np.array([[[0, 1, 0], [1, 0, 0]], [[1, 0, 0], [0, 1, 0]]])
- with self.test_session() as session:
+ with self.cached_session() as session:
x = array_ops.placeholder(dtypes.float32, [2, 3, 2])
y = array_ops.placeholder(dtypes.float32, [2, 2, 3])
in_x, in_y, out_y = ops.seq2seq_inputs(x, y, 3, 2)
@@ -77,7 +77,7 @@ class Seq2SeqOpsTest(test.TestCase):
[[0, 0, 0], [0, 0, 0]]])
def test_rnn_decoder(self):
- with self.test_session():
+ with self.cached_session():
decoder_inputs = [
array_ops.placeholder(dtypes.float32, [2, 2]) for _ in range(3)
]
diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
index 423dcce8de..8390ddda90 100644
--- a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
+++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import test
class DecodeLibsvmOpTest(test.TestCase):
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
content = [
"1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503",
"2 3:2.5 2:nan 1:0.105"
@@ -48,7 +48,7 @@ class DecodeLibsvmOpTest(test.TestCase):
[0, 0.105, np.nan, 2.5, 0, 0]])
def testNDimension(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"],
["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"],
["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]]
diff --git a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
index a4f5086dde..5fe883d647 100644
--- a/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
+++ b/tensorflow/contrib/linear_optimizer/kernels/g3doc/readme.md
@@ -199,6 +199,46 @@ does.
However, in practice, convergence with $$x_0 = 0$$ always happens (tested for a
sample of generic values for the parameters).
+### Poisson log loss
+
+Poisson log loss is defined as $$ \l(u) = e^u - uy $$ for label $$y \geq 0.$$
+Its dual is
+
+$$ \l^\star(v) = (y+v) (\log(y+v) - 1) $$
+
+and is only defined for $$ y+v > 0 $$. We then have the constraint
+
+$$ y > \a+\d. $$
+
+The dual is
+
+$$ D(\d) = -(y-\a-\d) (\log(y-\a-\d) - 1) - \bar{y} \d - \frac{A}{2} \d^2 $$
+
+and its derivative is,
+
+$$ D'(\d) = \log(y-\a-\d) - \bar{y} - A\d $$
+
+Similar to the logistic loss, we perform a change of variable to handle the
+constraint on $$ \d $$
+
+$$ y - (\a+\d) = e^x $$
+
+After this change of variable, the goal is to find the zero of this function
+
+$$ H(x) = x - \bar{y} -A(y-\a-e^x) $$
+
+whose first derivative is
+
+$$ H'(x) = 1+Ae^x $$
+
+Since this function is always positive, $$H$$ is increasing and has a unique
+zero.
+
+We can start Newton algorithm at $$\d=0$$ which corresponds to $$ x =
+\log(y-\a)$$. As before the Newton step is given by
+
+$$x_{k+1} = x_k - \frac{H(x_k)}{H'(x_k)}. $$
+
### References
[1] C. Ma et al., Adding vs. Averaging in Distributed Primal-Dual Optimization,
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index ef0e08a777..1d2db1cec8 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -1192,6 +1192,57 @@ class SdcaWithSmoothHingeLossTest(SdcaModelTest):
self.assertAllClose(0.33, unregularized_loss.eval(), atol=0.02)
self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)
+class SdcaWithPoissonLossTest(SdcaModelTest):
+ """SDCA optimizer test class for poisson loss."""
+
+ def testSimple(self):
+ # Setup test data
+ example_protos = [
+ make_example_proto({
+ 'age': [0],
+ 'gender': [0]
+ }, 0),
+ make_example_proto({
+ 'age': [1],
+ 'gender': [1]
+ }, 2),
+ ]
+ example_weights = [100.0, 100.0]
+ with self._single_threaded_test_session():
+ examples = make_example_dict(example_protos, example_weights)
+ variables = make_variable_dict(1, 1)
+ options = dict(
+ symmetric_l2_regularization=1.0,
+ symmetric_l1_regularization=0,
+ loss_type='poisson_loss')
+ model = SdcaModel(examples, variables, options)
+ variables_lib.global_variables_initializer().run()
+
+ # Before minimization, the weights default to zero. There is no loss due
+ # to regularization, only unregularized loss which is 1 for each example.
+ predictions = model.predictions(examples)
+ self.assertAllClose([1.0, 1.0], predictions.eval())
+ unregularized_loss = model.unregularized_loss(examples)
+ regularized_loss = model.regularized_loss(examples)
+ approximate_duality_gap = model.approximate_duality_gap()
+ self.assertAllClose(1.0, unregularized_loss.eval())
+ self.assertAllClose(1.0, regularized_loss.eval())
+
+ # There are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender
+ # (say w3 and w4). The minimization leads to:
+ # w1=w3=-1.96487, argmin of 100*(exp(2*w)-2*w*0)+w**2.
+ # w2=w4=0.345708, argmin of 100*(exp(2*w)-2*w*2)+w**2.
+ # This gives an unregularized loss of .3167 and .3366 with regularization.
+ train_op = model.minimize()
+ for _ in range(_MAX_ITERATIONS):
+ train_op.run()
+ model.update_weights(train_op).run()
+
+ self.assertAllClose([0.0196, 1.9965], predictions.eval(), atol=1e-4)
+ self.assertAllClose(0.3167, unregularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0.3366, regularized_loss.eval(), atol=1e-4)
+ self.assertAllClose(0., approximate_duality_gap.eval(), atol=1e-6)
+
class SdcaFprintTest(SdcaModelTest):
"""Tests for the SdcaFprint op.
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
index 0047d5753a..14f59a3f64 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py
@@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as var_ops
+from tensorflow.python.ops.nn import log_poisson_loss
from tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits
from tensorflow.python.summary import summary
@@ -51,6 +52,7 @@ class SdcaModel(object):
* Squared loss
* Hinge loss
* Smooth hinge loss
+ * Poisson log loss
This class defines an optimizer API to train a linear model.
@@ -112,7 +114,7 @@ class SdcaModel(object):
raise ValueError('examples, variables and options must all be specified.')
supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',
- 'smooth_hinge_loss')
+ 'smooth_hinge_loss', 'poisson_loss')
if options['loss_type'] not in supported_losses:
raise ValueError('Unsupported loss_type: ', options['loss_type'])
@@ -315,6 +317,7 @@ class SdcaModel(object):
"""Add operations to compute predictions by the model.
If logistic_loss is being used, predicted probabilities are returned.
+ If poisson_loss is being used, predictions are exponentiated.
Otherwise, (raw) linear predictions (w*x) are returned.
Args:
@@ -335,6 +338,10 @@ class SdcaModel(object):
# Convert logits to probability for logistic loss predictions.
with name_scope('sdca/logistic_prediction'):
result = math_ops.sigmoid(result)
+ elif self._options['loss_type'] == 'poisson_loss':
+ # Exponeniate the prediction for poisson loss predictions.
+ with name_scope('sdca/poisson_prediction'):
+ result = math_ops.exp(result)
return result
def _get_partitioned_update_ops(self,
@@ -624,6 +631,11 @@ class SdcaModel(object):
logits=predictions),
weights)) / math_ops.reduce_sum(weights)
+ if self._options['loss_type'] == 'poisson_loss':
+ return math_ops.reduce_sum(math_ops.multiply(
+ log_poisson_loss(targets=labels, log_input=predictions),
+ weights)) / math_ops.reduce_sum(weights)
+
if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:
# hinge_loss = max{0, 1 - y_i w*x} where y_i \in {-1, 1}. So, we need to
# first convert 0/1 labels into -1/1 labels.
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
index a2d82cf800..553b116a3b 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable_test.py
@@ -30,7 +30,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTable(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = -1
empty_key = 0
keys = constant_op.constant([11, 12, 13], dtypes.int64)
@@ -53,7 +53,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
def testShardedMutableHashTableVectors(self):
for num_shards in [1, 3, 10]:
- with self.test_session():
+ with self.cached_session():
default_val = [-0.1, 0.2]
empty_key = [0, 1]
keys = constant_op.constant([[11, 12], [13, 14], [15, 16]],
@@ -79,7 +79,7 @@ class ShardedMutableDenseHashTableTest(TensorFlowTestCase):
output.eval())
def testExportSharded(self):
- with self.test_session():
+ with self.cached_session():
empty_key = -2
default_val = -1
num_shards = 2
diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
index 237a6812b7..51c4f68543 100644
--- a/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/ops/sparse_feature_column_test.py
@@ -36,13 +36,13 @@ class SparseFeatureColumnTest(TensorFlowTestCase):
self.assertTrue(isinstance(sfc.example_indices, ops.Tensor))
self.assertTrue(isinstance(sfc.feature_indices, ops.Tensor))
self.assertEqual(sfc.feature_values, None)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_example_indices, sfc.example_indices.eval())
self.assertAllEqual(expected_feature_indices, sfc.feature_indices.eval())
expected_feature_values = [1.0, 2.0, 3.0, 4.0]
sfc = SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],
expected_feature_values)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(expected_feature_values, sfc.feature_values.eval())
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index 0091587bf7..f320b53d94 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -36,10 +36,10 @@ cc_library(
srcs = ["arena_planner.cc"],
hdrs = ["arena_planner.h"],
deps = [
- ":context",
":graph_info",
":memory_planner",
":simple_memory_arena",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -54,6 +54,7 @@ cc_test(
deps = [
":arena_planner",
"//tensorflow/contrib/lite/testing:util",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_googletest//:gtest",
],
@@ -63,27 +64,27 @@ cc_test(
# TODO(aselle): Resolve problems preventing C99 usage.
cc_library(
name = "context",
- srcs = ["context.c"],
hdrs = ["context.h"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "graph_info",
hdrs = ["graph_info.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "memory_planner",
hdrs = ["memory_planner.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
name = "simple_memory_arena",
srcs = ["simple_memory_arena.cc"],
hdrs = ["simple_memory_arena.h"],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -91,7 +92,7 @@ cc_library(
hdrs = [
"builtin_op_data.h",
],
- deps = [":context"],
+ deps = ["//tensorflow/contrib/lite/c:c_api_internal"],
)
cc_library(
@@ -121,12 +122,12 @@ cc_library(
name = "framework",
srcs = [
"allocation.cc",
- "error_reporter.cc",
"graph_info.cc",
"interpreter.cc",
"model.cc",
- "op_resolver.cc",
+ "mutable_op_resolver.cc",
"optional_debug_tools.cc",
+ "stderr_reporter.cc",
] + select({
"//tensorflow:android": [
"nnapi_delegate.cc",
@@ -149,9 +150,11 @@ cc_library(
"graph_info.h",
"interpreter.h",
"model.h",
+ "mutable_op_resolver.h",
"nnapi_delegate.h",
"op_resolver.h",
"optional_debug_tools.h",
+ "stderr_reporter.h",
],
copts = tflite_copts(),
linkopts = [
@@ -164,14 +167,14 @@ cc_library(
}),
deps = [
":arena_planner",
- ":builtin_op_data",
- ":context",
":graph_info",
":memory_planner",
":schema_fbs_version",
":simple_memory_arena",
":string",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
@@ -210,6 +213,8 @@ cc_test(
deps = [
":framework",
":string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
@@ -259,6 +264,8 @@ cc_test(
],
deps = [
":framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -266,9 +273,9 @@ cc_test(
# Test OpResolver.
cc_test(
- name = "op_resolver_test",
+ name = "mutable_op_resolver_test",
size = "small",
- srcs = ["op_resolver_test.cc"],
+ srcs = ["mutable_op_resolver_test.cc"],
tags = ["no_oss"],
deps = [
":framework",
@@ -277,24 +284,12 @@ cc_test(
],
)
-# Test the C extension API code.
-cc_test(
- name = "context_test",
- size = "small",
- srcs = ["context_test.cc"],
- deps = [
- ":framework",
- "//tensorflow/contrib/lite/testing:util",
- "@com_google_googletest//:gtest",
- ],
-)
-
cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
- ":context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -304,7 +299,6 @@ cc_test(
srcs = ["util_test.cc"],
tags = ["no_oss"],
deps = [
- ":context",
":util",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/RELEASE.md b/tensorflow/contrib/lite/RELEASE.md
deleted file mode 100644
index 8fd63d5cee..0000000000
--- a/tensorflow/contrib/lite/RELEASE.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# Release 0.1.7
-
-* TensorFlow Lite 0.1.7 is based on tag `tflite-v0.1.7` (git commit
- fa1db5eb0da85b5baccc2a46d534fdeb3bb473d0).
-* To reproduce the iOS library, it's required to cherry pick git commit
- f1f1d5172fe5bfeaeb2cf657ffc43ba744187bee to fix a dependency issue.
-* The code is based on TensorFlow 1.8.0 release candidate and it's very close
- to TensorFlow 1.8.0 release.
diff --git a/tensorflow/contrib/lite/allocation.cc b/tensorflow/contrib/lite/allocation.cc
index 8946261814..21cb1832a7 100644
--- a/tensorflow/contrib/lite/allocation.cc
+++ b/tensorflow/contrib/lite/allocation.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <cstring>
#include <utility>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h
index 121f3d2646..182bc0977f 100644
--- a/tensorflow/contrib/lite/allocation.h
+++ b/tensorflow/contrib/lite/allocation.h
@@ -20,8 +20,8 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
#include "tensorflow/contrib/lite/string.h"
diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h
index 55003cf4e9..382577045b 100644
--- a/tensorflow/contrib/lite/arena_planner.h
+++ b/tensorflow/contrib/lite/arena_planner.h
@@ -18,7 +18,7 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/simple_memory_arena.h"
@@ -37,8 +37,8 @@ struct AllocationInfo;
// each tensor needs to be allocated and deallocated, and preallocates all the
// necessary memory (the PlanAllocations phase). It then assigns portions of
// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may
-// share some of the buffer if a tensor B is to be allocated after another tensor
-// A has been deallocated.
+// share some of the buffer if a tensor B is to be allocated after another
+// tensor A has been deallocated.
//
// If dynamic tensors are used the planning steps can be repeated during model
// execution. Since dynamic tensors don't have sizes until after the
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 45a0ded7eb..9317e2bb6e 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -49,6 +49,9 @@ def tflite_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+ # In case you wonder why there's no --icf is because the gains were
+ # negligible, and created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--no-export-dynamic", # Only inc syms referenced by dynamic obj.
@@ -56,12 +59,7 @@ def tflite_linkopts_unstripped():
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_jni_linkopts_unstripped():
@@ -73,17 +71,15 @@ def tflite_jni_linkopts_unstripped():
Returns:
a select object with proper linkopts
"""
+
+ # In case you wonder why there's no --icf is because the gains were
+ # negligible, and created potential compatibility problems.
return select({
"//tensorflow:android": [
"-Wl,--gc-sections", # Eliminate unused code and data.
"-Wl,--as-needed", # Don't link unused libs.
],
- "//tensorflow:darwin": [],
- "//tensorflow/contrib/lite:mips": [],
- "//tensorflow/contrib/lite:mips64": [],
- "//conditions:default": [
- "-Wl,--icf=all", # Identical code folding.
- ],
+ "//conditions:default": [],
})
def tflite_linkopts():
@@ -235,6 +231,7 @@ def generated_test_models():
"exp",
"expand_dims",
"floor",
+ "floor_div",
"fully_connected",
"fused_batch_norm",
"gather",
@@ -266,6 +263,7 @@ def generated_test_models():
"padv2",
"prelu",
"pow",
+ "reduce_any",
"reduce_max",
"reduce_min",
"reduce_prod",
@@ -293,6 +291,7 @@ def generated_test_models():
"topk",
"transpose",
#"transpose_conv", # disabled due to b/111213074
+ "unpack",
"where",
]
diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h
index 70178b2faa..30901bd0fa 100644
--- a/tensorflow/contrib/lite/builtin_op_data.h
+++ b/tensorflow/contrib/lite/builtin_op_data.h
@@ -12,282 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for new location of interface definitions.
+
#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
-#include <stdint.h>
-
-#include "tensorflow/contrib/lite/context.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-// TODO(aselle): Consider using "if this then that" for testing.
-
-// Possible padding types (for convolutions)
-typedef enum {
- kTfLitePaddingUnknown = 0,
- kTfLitePaddingSame,
- kTfLitePaddingValid,
-} TfLitePadding;
-
-typedef struct {
- int width;
- int height;
-} TfLitePaddingValues;
-
-// Possible fused activation functions.
-// TODO(aselle): rename to TfLiteActivation
-typedef enum {
- kTfLiteActNone = 0,
- kTfLiteActRelu,
- kTfLiteActRelu1,
- kTfLiteActRelu6,
- kTfLiteActTanh,
- kTfLiteActSignBit,
- kTfLiteActSigmoid,
-} TfLiteFusedActivation;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int dilation_width_factor;
- int dilation_height_factor;
- TfLiteFusedActivation activation;
-} TfLiteConvParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int filter_width;
- int filter_height;
- TfLiteFusedActivation activation;
- struct {
- TfLitePaddingValues padding;
- } computed;
-} TfLitePoolParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
- int depth_multiplier;
- TfLiteFusedActivation activation;
-} TfLiteDepthwiseConvParams;
-
-typedef struct {
- int rank;
- TfLiteFusedActivation activation;
-} TfLiteSVDFParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteRNNParams;
-
-typedef struct {
- bool time_major;
- TfLiteFusedActivation activation;
-} TfLiteSequenceRNNParams;
-
-typedef enum {
- kTfLiteFullyConnectedWeightsFormatDefault = 0,
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
-} TfLiteFullyConnectedWeightsFormat;
-
-typedef struct {
- // Parameters for FullyConnected version 1 or above.
- TfLiteFusedActivation activation;
-
- // Parameters for FullyConnected version 2 or above.
- TfLiteFullyConnectedWeightsFormat weights_format;
-} TfLiteFullyConnectedParams;
-
-typedef enum {
- kTfLiteLshProjectionUnknown = 0,
- kTfLiteLshProjectionSparse = 1,
- kTfLiteLshProjectionDense = 2,
-} TfLiteLSHProjectionType;
-
-typedef struct {
- TfLiteLSHProjectionType type;
-} TfLiteLSHProjectionParams;
-
-typedef struct {
- float beta;
-} TfLiteSoftmaxParams;
-
-typedef struct {
- int axis;
- TfLiteFusedActivation activation;
-} TfLiteConcatenationParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteAddParams;
-
-typedef struct {
-} TfLiteSpaceToBatchNDParams;
-
-typedef struct {
-} TfLiteBatchToSpaceNDParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteMulParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteSubParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteDivParams;
-
-typedef struct {
- TfLiteFusedActivation activation;
-} TfLiteL2NormParams;
-
-typedef struct {
- int radius;
- float bias;
- float alpha;
- float beta;
-} TfLiteLocalResponseNormParams;
-
-typedef enum {
- kTfLiteLSTMFullKernel = 0,
- kTfLiteLSTMBasicKernel
-} TfLiteLSTMKernelType;
-
-typedef struct {
- // Parameters for LSTM version 1.
- TfLiteFusedActivation activation;
- float cell_clip;
- float proj_clip;
-
- // Parameters for LSTM version 2.
- // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
- TfLiteLSTMKernelType kernel_type;
-} TfLiteLSTMParams;
-
-typedef struct {
- bool align_corners;
-} TfLiteResizeBilinearParams;
-
-typedef struct {
-} TfLitePadParams;
-
-typedef struct {
-} TfLitePadV2Params;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int shape[8];
- int num_dimensions;
-} TfLiteReshapeParams;
-
-typedef struct {
- int ngram_size;
- int max_skip_size;
- bool include_all_ngrams;
-} TfLiteSkipGramParams;
-
-typedef struct {
- int block_size;
-} TfLiteSpaceToDepthParams;
-
-typedef struct {
- TfLiteType in_data_type;
- TfLiteType out_data_type;
-} TfLiteCastParams;
-
-typedef enum {
- kTfLiteCombinerTypeSum = 0,
- kTfLiteCombinerTypeMean = 1,
- kTfLiteCombinerTypeSqrtn = 2,
-} TfLiteCombinerType;
-
-typedef struct {
- TfLiteCombinerType combiner;
-} TfLiteEmbeddingLookupSparseParams;
-
-typedef struct {
- int axis;
-} TfLiteGatherParams;
-
-typedef struct {
-} TfLiteTransposeParams;
-
-typedef struct {
- bool keep_dims;
-} TfLiteReducerParams;
-
-typedef struct {
- int num_splits;
-} TfLiteSplitParams;
-
-typedef struct {
- // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
- // For now we will fix the maximum possible number of dimensions.
- int squeeze_dims[8];
- int num_squeeze_dims;
-} TfLiteSqueezeParams;
-
-typedef struct {
- int begin_mask;
- int end_mask;
- int ellipsis_mask;
- int new_axis_mask;
- int shrink_axis_mask;
-} TfLiteStridedSliceParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMaxParams;
-
-typedef struct {
- TfLiteType output_type;
-} TfLiteArgMinParams;
-
-typedef struct {
- TfLitePadding padding;
- int stride_width;
- int stride_height;
-} TfLiteTransposeConvParams;
-
-typedef struct {
- bool validate_indices;
-} TfLiteSparseToDenseParams;
-
-typedef struct {
- TfLiteType out_type;
-} TfLiteShapeParams;
-
-typedef struct {
- // Parameters supported by version 1:
- float min;
- float max;
- int num_bits;
-
- // Parameters supported by version 2:
- bool narrow_range;
-} TfLiteFakeQuantParams;
-
-typedef struct {
- int values_count;
- int axis;
-} TfLitePackParams;
-
-typedef struct {
- int axis;
-} TfLiteOneHotParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h
index 706f64a84a..9cf4bea73e 100644
--- a/tensorflow/contrib/lite/builtin_ops.h
+++ b/tensorflow/contrib/lite/builtin_ops.h
@@ -115,6 +115,8 @@ typedef enum {
kTfLiteBuiltinLogicalNot = 87,
kTfLiteBuiltinUnpack = 88,
kTfLiteBuiltinReduceMin = 89,
+ kTfLiteBuiltinFloorDiv = 90,
+ kTfLiteBuiltinReduceAny = 91,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/contrib/lite/c/BUILD b/tensorflow/contrib/lite/c/BUILD
new file mode 100644
index 0000000000..663eb63cad
--- /dev/null
+++ b/tensorflow/contrib/lite/c/BUILD
@@ -0,0 +1,39 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "c_api_internal",
+ srcs = ["c_api_internal.c"],
+ hdrs = [
+ "builtin_op_data.h",
+ "c_api_internal.h",
+ ],
+ visibility = [
+ "//tensorflow/contrib/lite:__subpackages__",
+ ],
+)
+
+# Test the C extension API code.
+cc_test(
+ name = "c_api_internal_test",
+ size = "small",
+ srcs = ["c_api_internal_test.cc"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "builtin_op_data_test",
+ size = "small",
+ srcs = ["builtin_op_data_test.cc"],
+ copts = ["-Wno-unused-variable"],
+ deps = [
+ ":c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
new file mode 100644
index 0000000000..fa43e6a024
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -0,0 +1,298 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+// TODO(aselle): Consider using "if this then that" for testing.
+
+// Possible padding types (for convolutions)
+typedef enum {
+ kTfLitePaddingUnknown = 0,
+ kTfLitePaddingSame,
+ kTfLitePaddingValid,
+} TfLitePadding;
+
+typedef struct {
+ int width;
+ int height;
+} TfLitePaddingValues;
+
+// Possible fused activation functions.
+// TODO(aselle): rename to TfLiteActivation
+typedef enum {
+ kTfLiteActNone = 0,
+ kTfLiteActRelu,
+ kTfLiteActRelu1,
+ kTfLiteActRelu6,
+ kTfLiteActTanh,
+ kTfLiteActSignBit,
+ kTfLiteActSigmoid,
+} TfLiteFusedActivation;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int dilation_width_factor;
+ int dilation_height_factor;
+ TfLiteFusedActivation activation;
+} TfLiteConvParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int filter_width;
+ int filter_height;
+ TfLiteFusedActivation activation;
+ struct {
+ TfLitePaddingValues padding;
+ } computed;
+} TfLitePoolParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+ int depth_multiplier;
+ TfLiteFusedActivation activation;
+} TfLiteDepthwiseConvParams;
+
+typedef struct {
+ int rank;
+ TfLiteFusedActivation activation;
+} TfLiteSVDFParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteRNNParams;
+
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+} TfLiteSequenceRNNParams;
+
+typedef enum {
+ kTfLiteFullyConnectedWeightsFormatDefault = 0,
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
+} TfLiteFullyConnectedWeightsFormat;
+
+typedef struct {
+ // Parameters for FullyConnected version 1 or above.
+ TfLiteFusedActivation activation;
+
+ // Parameters for FullyConnected version 2 or above.
+ TfLiteFullyConnectedWeightsFormat weights_format;
+} TfLiteFullyConnectedParams;
+
+typedef enum {
+ kTfLiteLshProjectionUnknown = 0,
+ kTfLiteLshProjectionSparse = 1,
+ kTfLiteLshProjectionDense = 2,
+} TfLiteLSHProjectionType;
+
+typedef struct {
+ TfLiteLSHProjectionType type;
+} TfLiteLSHProjectionParams;
+
+typedef struct {
+ float beta;
+} TfLiteSoftmaxParams;
+
+typedef struct {
+ int axis;
+ TfLiteFusedActivation activation;
+} TfLiteConcatenationParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteAddParams;
+
+typedef struct {
+} TfLiteSpaceToBatchNDParams;
+
+typedef struct {
+} TfLiteBatchToSpaceNDParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteMulParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteSubParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteDivParams;
+
+typedef struct {
+ TfLiteFusedActivation activation;
+} TfLiteL2NormParams;
+
+typedef struct {
+ int radius;
+ float bias;
+ float alpha;
+ float beta;
+} TfLiteLocalResponseNormParams;
+
+typedef enum {
+ kTfLiteLSTMFullKernel = 0,
+ kTfLiteLSTMBasicKernel
+} TfLiteLSTMKernelType;
+
+typedef struct {
+ // Parameters for LSTM version 1.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // Parameters for LSTM version 2.
+ // kTfLiteLSTMBasicKernel is only supported in version 2 or above.
+ TfLiteLSTMKernelType kernel_type;
+} TfLiteLSTMParams;
+
+typedef struct {
+ bool align_corners;
+} TfLiteResizeBilinearParams;
+
+typedef struct {
+} TfLitePadParams;
+
+typedef struct {
+} TfLitePadV2Params;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int shape[8];
+ int num_dimensions;
+} TfLiteReshapeParams;
+
+typedef struct {
+ int ngram_size;
+ int max_skip_size;
+ bool include_all_ngrams;
+} TfLiteSkipGramParams;
+
+typedef struct {
+ int block_size;
+} TfLiteSpaceToDepthParams;
+
+typedef struct {
+ TfLiteType in_data_type;
+ TfLiteType out_data_type;
+} TfLiteCastParams;
+
+typedef enum {
+ kTfLiteCombinerTypeSum = 0,
+ kTfLiteCombinerTypeMean = 1,
+ kTfLiteCombinerTypeSqrtn = 2,
+} TfLiteCombinerType;
+
+typedef struct {
+ TfLiteCombinerType combiner;
+} TfLiteEmbeddingLookupSparseParams;
+
+typedef struct {
+ int axis;
+} TfLiteGatherParams;
+
+typedef struct {
+} TfLiteTransposeParams;
+
+typedef struct {
+ bool keep_dims;
+} TfLiteReducerParams;
+
+typedef struct {
+ int num_splits;
+} TfLiteSplitParams;
+
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int squeeze_dims[8];
+ int num_squeeze_dims;
+} TfLiteSqueezeParams;
+
+typedef struct {
+ int begin_mask;
+ int end_mask;
+ int ellipsis_mask;
+ int new_axis_mask;
+ int shrink_axis_mask;
+} TfLiteStridedSliceParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMaxParams;
+
+typedef struct {
+ TfLiteType output_type;
+} TfLiteArgMinParams;
+
+typedef struct {
+ TfLitePadding padding;
+ int stride_width;
+ int stride_height;
+} TfLiteTransposeConvParams;
+
+typedef struct {
+ bool validate_indices;
+} TfLiteSparseToDenseParams;
+
+typedef struct {
+ TfLiteType out_type;
+} TfLiteShapeParams;
+
+typedef struct {
+ // Parameters supported by version 1:
+ float min;
+ float max;
+ int num_bits;
+
+ // Parameters supported by version 2:
+ bool narrow_range;
+} TfLiteFakeQuantParams;
+
+typedef struct {
+ int values_count;
+ int axis;
+} TfLitePackParams;
+
+typedef struct {
+ int axis;
+} TfLiteOneHotParams;
+
+typedef struct {
+ int num;
+ int axis;
+} TfLiteUnpackParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+
+#endif // TENSORFLOW_CONTRIB_LITE_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
new file mode 100644
index 0000000000..4d0ba75e68
--- /dev/null
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+// Builtin op data is just a set of data definitions, so the only meaningful
+// test we can run is whether we can create the structs we expect to find.
+// Testing each struct's members might be possible, but it seems unnecessary
+// until we've locked down the API. The build rule has copts set to ignore the
+// unused variable warning, since this is just a compilation test.
+TEST(IntArray, CanCompileStructs) {
+ TfLitePadding padding = kTfLitePaddingSame;
+ TfLitePaddingValues padding_values;
+ TfLiteFusedActivation fused_activation = kTfLiteActRelu;
+ TfLiteConvParams conv_params;
+ TfLitePoolParams pool_params;
+ TfLiteDepthwiseConvParams depthwise_conv_params;
+ TfLiteSVDFParams svdf_params;
+ TfLiteRNNParams rnn_params;
+ TfLiteSequenceRNNParams sequence_rnn_params;
+ TfLiteFullyConnectedWeightsFormat fully_connected_weights_format =
+ kTfLiteFullyConnectedWeightsFormatDefault;
+ TfLiteFullyConnectedParams fully_connected_params;
+ TfLiteLSHProjectionType projection_type = kTfLiteLshProjectionDense;
+ TfLiteLSHProjectionParams projection_params;
+ TfLiteSoftmaxParams softmax_params;
+ TfLiteConcatenationParams concatenation_params;
+ TfLiteAddParams add_params;
+ TfLiteSpaceToBatchNDParams space_to_batch_nd_params;
+ TfLiteBatchToSpaceNDParams batch_to_space_nd_params;
+ TfLiteMulParams mul_params;
+ TfLiteSubParams sub_params;
+ TfLiteDivParams div_params;
+ TfLiteL2NormParams l2_norm_params;
+ TfLiteLocalResponseNormParams local_response_norm_params;
+ TfLiteLSTMKernelType lstm_kernel_type = kTfLiteLSTMBasicKernel;
+ TfLiteLSTMParams lstm_params;
+ TfLiteResizeBilinearParams resize_bilinear_params;
+ TfLitePadParams pad_params;
+ TfLitePadV2Params pad_v2_params;
+ TfLiteReshapeParams reshape_params;
+ TfLiteSkipGramParams skip_gram_params;
+ TfLiteSpaceToDepthParams space_to_depth_params;
+ TfLiteCastParams cast_params;
+ TfLiteCombinerType combiner_type = kTfLiteCombinerTypeSqrtn;
+ TfLiteEmbeddingLookupSparseParams lookup_sparse_params;
+ TfLiteGatherParams gather_params;
+ TfLiteTransposeParams transpose_params;
+ TfLiteReducerParams reducer_params;
+ TfLiteSplitParams split_params;
+ TfLiteSqueezeParams squeeze_params;
+ TfLiteStridedSliceParams strided_slice_params;
+ TfLiteArgMaxParams arg_max_params;
+ TfLiteArgMinParams arg_min_params;
+ TfLiteTransposeConvParams transpose_conv_params;
+ TfLiteSparseToDenseParams sparse_to_dense_params;
+ TfLiteShapeParams shape_params;
+ TfLiteFakeQuantParams fake_quant_params;
+ TfLitePackParams pack_params;
+ TfLiteOneHotParams one_hot_params;
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/context.c b/tensorflow/contrib/lite/c/c_api_internal.c
index 7f2aa316f4..1846bad4b7 100644
--- a/tensorflow/contrib/lite/context.c
+++ b/tensorflow/contrib/lite/c/c_api_internal.c
@@ -13,8 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <stdio.h>
+#include <stdlib.h>
#include <string.h>
int TfLiteIntArrayGetSizeInBytes(int size) {
@@ -76,7 +77,8 @@ void TfLiteTensorFree(TfLiteTensor* t) {
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable, TfLiteTensor* tensor) {
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor) {
TfLiteTensorFree(tensor);
tensor->type = type;
tensor->name = name;
diff --git a/tensorflow/contrib/lite/c/c_api_internal.h b/tensorflow/contrib/lite/c/c_api_internal.h
new file mode 100644
index 0000000000..48df68a654
--- /dev/null
+++ b/tensorflow/contrib/lite/c/c_api_internal.h
@@ -0,0 +1,491 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// This file defines a C API for implementing operations in tflite.
+// These operations can be defined using c++ but the interface between
+// the interpreter and the operations are C.
+//
+// Summary of abstractions
+// TF_LITE_ENSURE - Self-sufficient error checking
+// TfLiteStatus - Status reporting
+// TfLiteIntArray - stores tensor shapes (dims),
+// TfLiteContext - allows an op to access the tensors
+// TfLiteTensor - tensor (a multidimensional array)
+// TfLiteNode - a single node or operation
+// TfLiteRegistration - the implementation of a conceptual operation.
+//
+// Some abstractions in this file are created and managed by Interpreter.
+#ifndef TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+#define TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
+
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
+
+// The list of external context types known to TF Lite. This list exists solely
+// to avoid conflicts and to ensure ops can share the external contexts they
+// need. Access to the external contexts is controled by one of the
+// corresponding support files.
+typedef enum {
+ kTfLiteEigenContext = 0, // include eigen_support.h to use.
+ kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
+ kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
+ kTfLiteMaxExternalContexts = 3
+} TfLiteExternalContextType;
+
+// An external context is a collection of information unrelated to the TF Lite
+// framework, but useful to a subset of the ops. TF Lite knows very little
+// about about the actual contexts, but it keeps a list of them, and is able to
+// refresh them if configurations like the number of recommended threads
+// change.
+typedef struct {
+ TfLiteExternalContextType type;
+ TfLiteStatus (*Refresh)(struct TfLiteContext* context);
+} TfLiteExternalContext;
+
+// Forward declare so GetNode can use this is in Context.
+typedef struct _TfLiteRegistration TfLiteRegistration;
+typedef struct _TfLiteDelegate TfLiteDelegate;
+
+#define kOptionalTensor (-1)
+
+// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
+// indices
+typedef struct {
+ int size;
+// gcc 6.1+ have a bug where flexible members aren't properly handled
+// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
+#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
+ __GNUC_MINOR__ >= 1
+ int data[0];
+#else
+ int data[];
+#endif
+} TfLiteIntArray;
+
+// Given the size (number of elements) in a TfLiteIntArray, calculate its size
+// in bytes.
+int TfLiteIntArrayGetSizeInBytes(int size);
+
+// Create a array of a given `size` (uninitialized entries).
+// This returns a pointer, that you must free using TfLiteIntArrayFree().
+TfLiteIntArray* TfLiteIntArrayCreate(int size);
+
+// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
+int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
+
+// Create a copy of an array passed as `src`.
+// You are expected to free memory with TfLiteIntArrayFree
+TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
+
+// Free memory of array `v`.
+void TfLiteIntArrayFree(TfLiteIntArray* v);
+
+// Since we must not depend on any libraries, define a minimal subset of
+// error macros while avoiding names that have pre-conceived meanings like
+// assert and check.
+
+// Check whether value is true, and if not return kTfLiteError from
+// the current function (and report the error string msg).
+#define TF_LITE_ENSURE_MSG(context, value, msg) \
+ do { \
+ if (!(value)) { \
+ (context)->ReportError((context), __FILE__ " " msg); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+#define TF_LITE_ENSURE(context, a) \
+ do { \
+ if (!(a)) { \
+ (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
+ __LINE__, #a); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_STATUS(a) \
+ do { \
+ if ((a) != kTfLiteOk) { \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+// Check whether the value `a == b` is true, and if not return kTfLiteError from
+// the current function, while also reporting the location of the error.
+// `a` and `b` may be evaluated more than once, so no side effects or
+// extremely expensive computations should be done.
+#define TF_LITE_ENSURE_EQ(context, a, b) \
+ do { \
+ if ((a) != (b)) { \
+ (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
+ __LINE__, #a, #b, (a), (b)); \
+ return kTfLiteError; \
+ } \
+ } while (0)
+
+#define TF_LITE_ENSURE_OK(context, status) \
+ do { \
+ if ((status) != kTfLiteOk) { \
+ return status; \
+ } \
+ } while (0)
+
+// Single-precision complex data type compatible with the C99 definition.
+typedef struct {
+ float re, im; // real and imaginary parts, respectively.
+} TfLiteComplex64;
+
+// Types supported by tensor
+typedef enum {
+ kTfLiteNoType = 0,
+ kTfLiteFloat32 = 1,
+ kTfLiteInt32 = 2,
+ kTfLiteUInt8 = 3,
+ kTfLiteInt64 = 4,
+ kTfLiteString = 5,
+ kTfLiteBool = 6,
+ kTfLiteInt16 = 7,
+ kTfLiteComplex64 = 8,
+} TfLiteType;
+
+// Parameters for asymmetric quantization. Quantized values can be converted
+// back to float using:
+// real_value = scale * (quantized_value - zero_point);
+typedef struct {
+ float scale;
+ int32_t zero_point;
+} TfLiteQuantizationParams;
+
+// A union of pointers that points to memory for a given tensor.
+typedef union {
+ int* i32;
+ int64_t* i64;
+ float* f;
+ char* raw;
+ const char* raw_const;
+ uint8_t* uint8;
+ bool* b;
+ int16_t* i16;
+ TfLiteComplex64* c64;
+} TfLitePtrUnion;
+
+// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
+// data (or data externally allocated). kTfLiteArenaRw is arena allocated
+// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
+typedef enum {
+ kTfLiteMemNone = 0,
+ kTfLiteMmapRo,
+ kTfLiteArenaRw,
+ kTfLiteArenaRwPersistent,
+ kTfLiteDynamic,
+} TfLiteAllocationType;
+
+// The delegates should use zero or positive integers to represent handles.
+// -1 is reserved from unallocated status.
+typedef int TfLiteBufferHandle;
+const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
+
+// An tensor in the interpreter system which is a wrapper around a buffer of
+// data including a dimensionality (or NULL if not currently defined).
+typedef struct {
+ // The data type specification for data stored in `data`. This affects
+ // what member of `data` union should be used.
+ TfLiteType type;
+ // A union of data pointers. The appropriate type should be used for a typed
+ // tensor based on `type`.
+ TfLitePtrUnion data;
+ // A pointer to a structure representing the dimensionality interpretation
+ // that the buffer should have. NOTE: the product of elements of `dims`
+ // and the element datatype size should be equal to `bytes` below.
+ TfLiteIntArray* dims;
+ // Quantization information.
+ TfLiteQuantizationParams params;
+ // How memory is mapped
+ // kTfLiteMmapRo: Memory mapped read only.
+ // i.e. weights
+ // kTfLiteArenaRw: Arena allocated read write memory
+ // (i.e. temporaries, outputs).
+ TfLiteAllocationType allocation_type;
+ // The number of bytes required to store the data of this Tensor. I.e.
+ // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
+ // type is kTfLiteFloat32 and dims = {3, 2} then
+ // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
+ size_t bytes;
+
+ // An opaque pointer to a tflite::MMapAllocation
+ const void* allocation;
+
+ // Null-terminated name of this tensor.
+ const char* name;
+
+ // The delegate which knows how to handle `buffer_handle`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+
+ // An integer buffer handle that can be handled by `delegate`.
+ // The value is valid only when delegate is not null.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteBufferHandle buffer_handle;
+
+ // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
+ // responsible to set data_is_stale to true.
+ // `delegate->CopyFromBufferHandle` can be called to copy the data from
+ // delegate buffer.
+ // WARNING: This is an // experimental interface that is subject to change.
+ bool data_is_stale;
+
+ // True if the tensor is a variable.
+ bool is_variable;
+} TfLiteTensor;
+
+// Free data memory of tensor `t`;
+void TfLiteTensorDataFree(TfLiteTensor* t);
+
+// Free memory of tensor `t`;
+void TfLiteTensorFree(TfLiteTensor* t);
+
+// Set all of a tensor's fields (and free any previously allocated data).
+void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
+ TfLiteQuantizationParams quantization, char* buffer,
+ size_t size, TfLiteAllocationType allocation_type,
+ const void* allocation, bool is_variable,
+ TfLiteTensor* tensor);
+
+// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
+// types other than kTfLiteDynamic will be ignored.
+void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
+
+// A structure representing an instance of a node.
+// This structure only exhibits the inputs, outputs and user defined data, not
+// other features like the type.
+typedef struct {
+ // Inputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* inputs;
+
+ // Outputs to this node expressed as indices into the simulator's tensors.
+ TfLiteIntArray* outputs;
+
+ // Temporary tensors uses during the computations. This usually contains no
+ // tensors, but ops are allowed to change that if they need scratch space of
+ // any sort.
+ TfLiteIntArray* temporaries;
+
+ // Opaque data provided by the node implementer through `Registration.init`.
+ void* user_data;
+
+ // Opaque data provided to the node if the node is a builtin. This is usually
+ // a structure defined in builtin_op_data.h
+ void* builtin_data;
+
+ // Custom initial data. This is the opaque data provided in the flatbuffer.
+ // WARNING: This is an experimental interface that is subject to change.
+ const void* custom_initial_data;
+ int custom_initial_data_size;
+
+ // The pointer to the delegate. This is non-null only when the node is
+ // created by calling `interpreter.ModifyGraphWithDelegate`.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteDelegate* delegate;
+} TfLiteNode;
+
+typedef struct TfLiteContext {
+ // Number of tensors in the context.
+ size_t tensors_size;
+
+ // The execution plan contains a list of the node indices in execution
+ // order. execution_plan->size is the current number of nodes. And,
+ // execution_plan->data[0] is the first node that needs to be run.
+ // TfLiteDelegates can traverse the current execution plan by iterating
+ // through each member of this array and using GetNodeAndRegistration() to
+ // access details about a node. i.e.
+ // TfLiteIntArray* execution_plan;
+ // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
+ // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
+ // int node_index = execution_plan->data[exec_index];
+ // TfLiteNode* node;
+ // TfLiteRegistration* reg;
+ // context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // }
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
+ TfLiteIntArray** execution_plan);
+
+ // An array of tensors in the interpreter context (of length `tensors_size`)
+ TfLiteTensor* tensors;
+
+ // opaque full context ptr (an opaque c++ data structure)
+ void* impl_;
+
+ // Request memory pointer be resized. Updates dimensions on the tensor.
+ // NOTE: ResizeTensor takes ownership of newSize.
+ TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
+ TfLiteIntArray* new_size);
+ // Request that a error be reported with format string msg.
+ void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
+
+ // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
+ // non-null, the value pointed to by `first_new_tensor_index` will be set to
+ // the index of the first new tensor.
+ TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
+ int* first_new_tensor_index);
+
+ // Get a Tensor node by node_index.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
+ TfLiteNode** node,
+ TfLiteRegistration** registration);
+
+ // Replace ops with one or more stub delegate operations. This function
+ // does not take ownership of `nodes_to_replace`.
+ TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
+ struct TfLiteContext*, TfLiteRegistration registration,
+ const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
+
+ // Number of threads that are recommended to subsystems like gemmlowp and
+ // eigen.
+ int recommended_num_threads;
+
+ // Access external contexts by type.
+ // WARNING: This is an experimental interface that is subject to change.
+ TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
+ TfLiteExternalContextType);
+ // Set the value of a external context. Does not take ownership of the
+ // pointer.
+ // WARNING: This is an experimental interface that is subject to change.
+ void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
+ TfLiteExternalContext*);
+} TfLiteContext;
+
+typedef struct _TfLiteRegistration {
+ // Initializes the op from serialized data.
+ // If a built-in op:
+ // `buffer` is the op's params data (TfLiteLSTMParams*).
+ // `length` is zero.
+ // If custom op:
+ // `buffer` is the op's `custom_options`.
+ // `length` is the size of the buffer.
+ //
+ // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
+ // or an instance of a struct).
+ //
+ // The returned pointer will be stored with the node in the `user_data` field,
+ // accessible within prepare and invoke functions below.
+ // NOTE: if the data is already in the desired format, simply implement this
+ // function to return `nullptr` and implement the free function to be a no-op.
+ void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
+
+ // The pointer `buffer` is the data previously returned by an init invocation.
+ void (*free)(TfLiteContext* context, void* buffer);
+
+ // prepare is called when the inputs this node depends on have been resized.
+ // context->ResizeTensor() can be called to request output tensors to be
+ // resized.
+ //
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
+
+ // Execute the node (should read node->inputs and output to node->outputs).
+ // Returns kTfLiteOk on success.
+ TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
+
+ // profiling_string is called during summarization of profiling information
+ // in order to group executions together. Providing a value here will cause a
+ // given op to appear multiple times is the profiling report. This is
+ // particularly useful for custom ops that can perform significantly
+ // different calculations depending on their `user-data`.
+ const char* (*profiling_string)(const TfLiteContext* context,
+ const TfLiteNode* node);
+
+ // Builtin codes. If this kernel refers to a builtin this is the code
+ // of the builtin. This is so we can do marshaling to other frameworks like
+ // NN API.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int32_t builtin_code;
+
+ // Custom op name. If the op is a builtin, this will be null.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ // WARNING: This is an experimental interface that is subject to change.
+ const char* custom_name;
+
+ // The version of the op.
+ // Note: It is the responsibility of the registration binder to set this
+ // properly.
+ int version;
+} TfLiteRegistration;
+
+// WARNING: This is an experimental interface that is subject to change.
+typedef struct _TfLiteDelegate {
+ // Data that delegate needs to identify itself. This data is owned by the
+ // delegate. The delegate is owned in the user code, so the delegate is
+ // responsible for doing this when it is destroyed.
+ void* data_;
+
+ // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
+ // delegate a view of the current graph through TfLiteContext*. It typically
+ // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
+ // to ask the TensorFlow lite runtime to create macro-nodes to represent
+ // delegated subgraphs of the original graph.
+ TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
+
+ // Copy the data from delegate buffer handle to raw memory.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Copy the data from raw memory to delegate buffer handle.
+ // This can be null if the delegate doesn't use its own buffer.
+ TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
+ TfLiteDelegate* delegate,
+ TfLiteBufferHandle buffer_handle,
+ void* data, size_t size);
+
+ // Free the Delegate Buffer Handle. Note: This only frees the handle, but
+ // this doesn't release the underlying resource (e.g. textures). The
+ // resources are either owned by application layer or the delegate.
+ // This can be null if the delegate doesn't use its own buffer.
+ void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
+ TfLiteBufferHandle* handle);
+} TfLiteDelegate;
+
+// WARNING: This is an experimental interface that is subject to change.
+//
+// Currently, TfLiteDelegateParams has to be allocated in a way that it's
+// trivially destructable. It will be stored as `builtin_data` field in
+// `TfLiteNode` of the delegate node.
+//
+// See also the `CreateDelegateParams` function in `interpreter.cc` details.
+typedef struct {
+ TfLiteDelegate* delegate;
+ TfLiteIntArray* nodes_to_replace;
+ TfLiteIntArray* input_tensors;
+ TfLiteIntArray* output_tensors;
+} TfLiteDelegateParams;
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif // TENSORFLOW_CONTRIB_LITE_C_C_API_INTERNAL_H_
diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/c/c_api_internal_test.cc
index 20d6f69a25..af398f3207 100644
--- a/tensorflow/contrib/lite/context_test.cc
+++ b/tensorflow/contrib/lite/c/c_api_internal_test.cc
@@ -13,16 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
// NOTE: this tests only the TfLiteIntArray part of context.
-// most of context.h is provided in the context of using it with interpreter.h
-// and interpreter.cc, so interpreter_test.cc tests context structures more
-// thoroughly.
+// most of c_api_internal.h is provided in the context of using it with
+// interpreter.h and interpreter.cc, so interpreter_test.cc tests context
+// structures more thoroughly.
TEST(IntArray, TestIntArrayCreate) {
TfLiteIntArray* a = TfLiteIntArrayCreate(0);
@@ -69,7 +68,6 @@ TEST(IntArray, TestIntArrayEqual) {
} // namespace tflite
int main(int argc, char** argv) {
- ::tflite::LogToStderr();
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index c7f4df3cdc..b86c2819b8 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -12,480 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file defines a C API for implementing operations in tflite.
-// These operations can be defined using c++ but the interface between
-// the interpreter and the operations are C.
-//
-// Summary of abstractions
-// TF_LITE_ENSURE - Self-sufficient error checking
-// TfLiteStatus - Status reporting
-// TfLiteIntArray - stores tensor shapes (dims),
-// TfLiteContext - allows an op to access the tensors
-// TfLiteTensor - tensor (a multidimensional array)
-// TfLiteNode - a single node or operation
-// TfLiteRegistration - the implementation of a conceptual operation.
-//
-// Some abstractions in this file are created and managed by Interpreter.
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
-#include <stdbool.h>
-#include <stdint.h>
-#include <stdlib.h>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
-#ifdef __cplusplus
-extern "C" {
-#endif // __cplusplus
-
-typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus;
-
-// The list of external context types known to TF Lite. This list exists solely
-// to avoid conflicts and to ensure ops can share the external contexts they
-// need. Access to the external contexts is controled by one of the
-// corresponding support files.
-typedef enum {
- kTfLiteEigenContext = 0, // include eigen_support.h to use.
- kTfLiteGemmLowpContext = 1, // include gemm_support.h to use.
- kTfLiteEdgeTpuContext = 2, // Placeholder for Edge TPU support.
- kTfLiteMaxExternalContexts = 3
-} TfLiteExternalContextType;
-
-// An external context is a collection of information unrelated to the TF Lite
-// framework, but useful to a subset of the ops. TF Lite knows very little
-// about about the actual contexts, but it keeps a list of them, and is able to
-// refresh them if configurations like the number of recommended threads
-// change.
-typedef struct {
- TfLiteExternalContextType type;
- TfLiteStatus (*Refresh)(struct TfLiteContext* context);
-} TfLiteExternalContext;
-
-// Forward declare so GetNode can use this is in Context.
-typedef struct _TfLiteRegistration TfLiteRegistration;
-typedef struct _TfLiteDelegate TfLiteDelegate;
-
-#define kOptionalTensor (-1)
-
-// Fixed size list of integers. Used for dimensions and inputs/outputs tensor
-// indices
-typedef struct {
- int size;
-// gcc 6.1+ have a bug where flexible members aren't properly handled
-// https://github.com/google/re2/commit/b94b7cd42e9f02673cd748c1ac1d16db4052514c
-#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ == 6 && \
- __GNUC_MINOR__ >= 1
- int data[0];
-#else
- int data[];
-#endif
-} TfLiteIntArray;
-
-// Given the size (number of elements) in a TfLiteIntArray, calculate its size
-// in bytes.
-int TfLiteIntArrayGetSizeInBytes(int size);
-
-// Create a array of a given `size` (uninitialized entries).
-// This returns a pointer, that you must free using TfLiteIntArrayFree().
-TfLiteIntArray* TfLiteIntArrayCreate(int size);
-
-// Check if two tensors are equal. Returns 1 if they are equal, 0 otherwise.
-int TfLiteIntArrayEqual(TfLiteIntArray* a, TfLiteIntArray* b);
-
-// Create a copy of an array passed as `src`.
-// You are expected to free memory with TfLiteIntArrayFree
-TfLiteIntArray* TfLiteIntArrayCopy(TfLiteIntArray* src);
-
-// Free memory of array `v`.
-void TfLiteIntArrayFree(TfLiteIntArray* v);
-
-// Since we must not depend on any libraries, define a minimal subset of
-// error macros while avoiding names that have pre-conceived meanings like
-// assert and check.
-
-// Check whether value is true, and if not return kTfLiteError from
-// the current function (and report the error string msg).
-#define TF_LITE_ENSURE_MSG(context, value, msg) \
- do { \
- if (!(value)) { \
- (context)->ReportError((context), __FILE__ " " msg); \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-#define TF_LITE_ENSURE(context, a) \
- do { \
- if (!(a)) { \
- (context)->ReportError((context), "%s:%d %s was not true.", __FILE__, \
- __LINE__, #a); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_STATUS(a) \
- do { \
- if ((a) != kTfLiteOk) { \
- return kTfLiteError; \
- } \
- } while (0)
-
-// Check whether the value `a == b` is true, and if not return kTfLiteError from
-// the current function, while also reporting the location of the error.
-// `a` and `b` may be evaluated more than once, so no side effects or
-// extremely expensive computations should be done.
-#define TF_LITE_ENSURE_EQ(context, a, b) \
- do { \
- if ((a) != (b)) { \
- (context)->ReportError((context), "%s:%d %s != %s (%d != %d)", __FILE__, \
- __LINE__, #a, #b, (a), (b)); \
- return kTfLiteError; \
- } \
- } while (0)
-
-#define TF_LITE_ENSURE_OK(context, status) \
- do { \
- if ((status) != kTfLiteOk) { \
- return status; \
- } \
- } while (0)
-
-// Single-precision complex data type compatible with the C99 definition.
-typedef struct {
- float re, im; // real and imaginary parts, respectively.
-} TfLiteComplex64;
-
-// Types supported by tensor
-typedef enum {
- kTfLiteNoType = 0,
- kTfLiteFloat32 = 1,
- kTfLiteInt32 = 2,
- kTfLiteUInt8 = 3,
- kTfLiteInt64 = 4,
- kTfLiteString = 5,
- kTfLiteBool = 6,
- kTfLiteInt16 = 7,
- kTfLiteComplex64 = 8,
-} TfLiteType;
-
-// Parameters for asymmetric quantization. Quantized values can be converted
-// back to float using:
-// real_value = scale * (quantized_value - zero_point);
-typedef struct {
- float scale;
- int32_t zero_point;
-} TfLiteQuantizationParams;
-
-// A union of pointers that points to memory for a given tensor.
-typedef union {
- int* i32;
- int64_t* i64;
- float* f;
- char* raw;
- const char* raw_const;
- uint8_t* uint8;
- bool* b;
- int16_t* i16;
- TfLiteComplex64* c64;
-} TfLitePtrUnion;
-
-// Memory allocation strategies. kTfLiteMmapRo is for read-only memory-mapped
-// data (or data externally allocated). kTfLiteArenaRw is arena allocated
-// data. kTfLiteDynamic is for tensors that are allocated during evaluation.
-typedef enum {
- kTfLiteMemNone = 0,
- kTfLiteMmapRo,
- kTfLiteArenaRw,
- kTfLiteArenaRwPersistent,
- kTfLiteDynamic,
-} TfLiteAllocationType;
-
-// The delegates should use zero or positive integers to represent handles.
-// -1 is reserved from unallocated status.
-typedef int TfLiteBufferHandle;
-const TfLiteBufferHandle kTfLiteNullBufferHandle = -1;
-
-// An tensor in the interpreter system which is a wrapper around a buffer of
-// data including a dimensionality (or NULL if not currently defined).
-typedef struct {
- // The data type specification for data stored in `data`. This affects
- // what member of `data` union should be used.
- TfLiteType type;
- // A union of data pointers. The appropriate type should be used for a typed
- // tensor based on `type`.
- TfLitePtrUnion data;
- // A pointer to a structure representing the dimensionality interpretation
- // that the buffer should have. NOTE: the product of elements of `dims`
- // and the element datatype size should be equal to `bytes` below.
- TfLiteIntArray* dims;
- // Quantization information.
- TfLiteQuantizationParams params;
- // How memory is mapped
- // kTfLiteMmapRo: Memory mapped read only.
- // i.e. weights
- // kTfLiteArenaRw: Arena allocated read write memory
- // (i.e. temporaries, outputs).
- TfLiteAllocationType allocation_type;
- // The number of bytes required to store the data of this Tensor. I.e.
- // (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
- // type is kTfLiteFloat32 and dims = {3, 2} then
- // bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
- size_t bytes;
-
- // An opaque pointer to a tflite::MMapAllocation
- const void* allocation;
-
- // Null-terminated name of this tensor.
- const char* name;
-
- // The delegate which knows how to handle `buffer_handle`.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteDelegate* delegate;
-
- // An integer buffer handle that can be handled by `delegate`.
- // The value is valid only when delegate is not null.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteBufferHandle buffer_handle;
-
- // If the delegate uses its own buffer (e.g. GPU memory), the delegate is
- // responsible to set data_is_stale to true.
- // `delegate->CopyFromBufferHandle` can be called to copy the data from
- // delegate buffer.
- // WARNING: This is an // experimental interface that is subject to change.
- bool data_is_stale;
-
- // True if the tensor is a variable.
- bool is_variable;
-} TfLiteTensor;
-
-// Free data memory of tensor `t`;
-void TfLiteTensorDataFree(TfLiteTensor* t);
-
-// Free memory of tensor `t`;
-void TfLiteTensorFree(TfLiteTensor* t);
-
-// Set all of a tensor's fields (and free any previously allocated data).
-void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
- TfLiteQuantizationParams quantization, char* buffer,
- size_t size, TfLiteAllocationType allocation_type,
- const void* allocation, bool is_variable,
- TfLiteTensor* tensor);
-
-// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
-// types other than kTfLiteDynamic will be ignored.
-void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
-
-// A structure representing an instance of a node.
-// This structure only exhibits the inputs, outputs and user defined data, not
-// other features like the type.
-typedef struct {
- // Inputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* inputs;
-
- // Outputs to this node expressed as indices into the simulator's tensors.
- TfLiteIntArray* outputs;
-
- // Temporary tensors uses during the computations. This usually contains no
- // tensors, but ops are allowed to change that if they need scratch space of
- // any sort.
- TfLiteIntArray* temporaries;
-
- // Opaque data provided by the node implementer through `Registration.init`.
- void* user_data;
-
- // Opaque data provided to the node if the node is a builtin. This is usually
- // a structure defined in builtin_op_data.h
- void* builtin_data;
-
- // Custom initial data. This is the opaque data provided in the flatbuffer.
- // WARNING: This is an experimental interface that is subject to change.
- const void* custom_initial_data;
- int custom_initial_data_size;
-
- // The pointer to the delegate. This is non-null only when the node is
- // created by calling `interpreter.ModifyGraphWithDelegate`.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteDelegate* delegate;
-} TfLiteNode;
-
-typedef struct TfLiteContext {
- // Number of tensors in the context.
- size_t tensors_size;
-
- // The execution plan contains a list of the node indices in execution
- // order. execution_plan->size is the current number of nodes. And,
- // execution_plan->data[0] is the first node that needs to be run.
- // TfLiteDelegates can traverse the current execution plan by iterating
- // through each member of this array and using GetNodeAndRegistration() to
- // access details about a node. i.e.
- // TfLiteIntArray* execution_plan;
- // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan));
- // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) {
- // int node_index = execution_plan->data[exec_index];
- // TfLiteNode* node;
- // TfLiteRegistration* reg;
- // context->GetNodeAndRegistration(context, node_index, &node, &reg);
- // }
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context,
- TfLiteIntArray** execution_plan);
-
- // An array of tensors in the interpreter context (of length `tensors_size`)
- TfLiteTensor* tensors;
-
- // opaque full context ptr (an opaque c++ data structure)
- void* impl_;
-
- // Request memory pointer be resized. Updates dimensions on the tensor.
- // NOTE: ResizeTensor takes ownership of newSize.
- TfLiteStatus (*ResizeTensor)(struct TfLiteContext*, TfLiteTensor* tensor,
- TfLiteIntArray* new_size);
- // Request that a error be reported with format string msg.
- void (*ReportError)(struct TfLiteContext*, const char* msg, ...);
-
- // Add `tensors_to_add` tensors, preserving pre-existing Tensor entries. If
- // non-null, the value pointed to by `first_new_tensor_index` will be set to
- // the index of the first new tensor.
- TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add,
- int* first_new_tensor_index);
-
- // Get a Tensor node by node_index.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index,
- TfLiteNode** node,
- TfLiteRegistration** registration);
-
- // Replace ops with one or more stub delegate operations. This function
- // does not take ownership of `nodes_to_replace`.
- TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)(
- struct TfLiteContext*, TfLiteRegistration registration,
- const TfLiteIntArray* nodes_to_replace, TfLiteDelegate* delegate);
-
- // Number of threads that are recommended to subsystems like gemmlowp and
- // eigen.
- int recommended_num_threads;
-
- // Access external contexts by type.
- // WARNING: This is an experimental interface that is subject to change.
- TfLiteExternalContext* (*GetExternalContext)(struct TfLiteContext*,
- TfLiteExternalContextType);
- // Set the value of a external context. Does not take ownership of the
- // pointer.
- // WARNING: This is an experimental interface that is subject to change.
- void (*SetExternalContext)(struct TfLiteContext*, TfLiteExternalContextType,
- TfLiteExternalContext*);
-} TfLiteContext;
-
-typedef struct _TfLiteRegistration {
- // Initializes the op from serialized data.
- // If a built-in op:
- // `buffer` is the op's params data (TfLiteLSTMParams*).
- // `length` is zero.
- // If custom op:
- // `buffer` is the op's `custom_options`.
- // `length` is the size of the buffer.
- //
- // Returns a type-punned (i.e. void*) opaque data (e.g. a primitive pointer
- // or an instance of a struct).
- //
- // The returned pointer will be stored with the node in the `user_data` field,
- // accessible within prepare and invoke functions below.
- // NOTE: if the data is already in the desired format, simply implement this
- // function to return `nullptr` and implement the free function to be a no-op.
- void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
-
- // The pointer `buffer` is the data previously returned by an init invocation.
- void (*free)(TfLiteContext* context, void* buffer);
-
- // prepare is called when the inputs this node depends on have been resized.
- // context->ResizeTensor() can be called to request output tensors to be
- // resized.
- //
- // Returns kTfLiteOk on success.
- TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
-
- // Execute the node (should read node->inputs and output to node->outputs).
- // Returns kTfLiteOk on success.
- TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
-
- // profiling_string is called during summarization of profiling information
- // in order to group executions together. Providing a value here will cause a
- // given op to appear multiple times is the profiling report. This is
- // particularly useful for custom ops that can perform significantly
- // different calculations depending on their `user-data`.
- const char* (*profiling_string)(const TfLiteContext* context,
- const TfLiteNode* node);
-
- // Builtin codes. If this kernel refers to a builtin this is the code
- // of the builtin. This is so we can do marshaling to other frameworks like
- // NN API.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int32_t builtin_code;
-
- // Custom op name. If the op is a builtin, this will be null.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- // WARNING: This is an experimental interface that is subject to change.
- const char* custom_name;
-
- // The version of the op.
- // Note: It is the responsibility of the registration binder to set this
- // properly.
- int version;
-} TfLiteRegistration;
-
-// WARNING: This is an experimental interface that is subject to change.
-typedef struct _TfLiteDelegate {
- // Data that delegate needs to identify itself. This data is owned by the
- // delegate. The delegate is owned in the user code, so the delegate is
- // responsible for doing this when it is destroyed.
- void* data_;
-
- // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the
- // delegate a view of the current graph through TfLiteContext*. It typically
- // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels()
- // to ask the TensorFlow lite runtime to create macro-nodes to represent
- // delegated subgraphs of the original graph.
- TfLiteStatus (*Prepare)(TfLiteContext* context, TfLiteDelegate* delegate);
-
- // Copy the data from delegate buffer handle to raw memory.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyFromBufferHandle)(TfLiteContext* context,
- TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Copy the data from raw memory to delegate buffer handle.
- // This can be null if the delegate doesn't use its own buffer.
- TfLiteStatus (*CopyToBufferHandle)(TfLiteContext* context,
- TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, size_t size);
-
- // Free the Delegate Buffer Handle. Note: This only frees the handle, but
- // this doesn't release the underlying resource (e.g. textures). The
- // resources are either owned by application layer or the delegate.
- // This can be null if the delegate doesn't use its own buffer.
- void (*FreeBufferHandle)(TfLiteContext* context, TfLiteDelegate* delegate,
- TfLiteBufferHandle* handle);
-} TfLiteDelegate;
-
-// WARNING: This is an experimental interface that is subject to change.
-//
-// Currently, TfLiteDelegateParams has to be allocated in a way that it's
-// trivially destructable. It will be stored as `builtin_data` field in
-// `TfLiteNode` of the delegate node.
-//
-// See also the `CreateDelegateParams` function in `interpreter.cc` details.
-typedef struct {
- TfLiteDelegate* delegate;
- TfLiteIntArray* nodes_to_replace;
- TfLiteIntArray* input_tensors;
- TfLiteIntArray* output_tensors;
-} TfLiteDelegateParams;
-
-#ifdef __cplusplus
-} // extern "C"
-#endif // __cplusplus
#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_
diff --git a/tensorflow/contrib/lite/context_util.h b/tensorflow/contrib/lite/context_util.h
index abe802e342..ccda4c7393 100644
--- a/tensorflow/contrib/lite/context_util.h
+++ b/tensorflow/contrib/lite/context_util.h
@@ -17,7 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_CONTEXT_UTIL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/core/api/BUILD b/tensorflow/contrib/lite/core/api/BUILD
new file mode 100644
index 0000000000..e4500534f3
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/BUILD
@@ -0,0 +1,57 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "api",
+ srcs = [
+ "error_reporter.cc",
+ "flatbuffer_conversions.cc",
+ "op_resolver.cc",
+ ],
+ hdrs = [
+ "error_reporter.h",
+ "flatbuffer_conversions.h",
+ "op_resolver.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+cc_test(
+ name = "error_reporter_test",
+ size = "small",
+ srcs = ["error_reporter_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "op_resolver_test",
+ size = "small",
+ srcs = ["op_resolver_test.cc"],
+ deps = [
+ ":api",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+cc_test(
+ name = "flatbuffer_conversions_test",
+ size = "small",
+ srcs = ["flatbuffer_conversions_test.cc"],
+ deps = [
+ ":api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.cc b/tensorflow/contrib/lite/core/api/error_reporter.cc
new file mode 100644
index 0000000000..423f83b1a9
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include <cstdarg>
+
+namespace tflite {
+
+int ErrorReporter::Report(const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+// TODO(aselle): Make the name of ReportError on context the same, so
+// we can use the ensure functions w/o a context and w/ a reporter.
+int ErrorReporter::ReportError(void*, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ int code = Report(format, args);
+ va_end(args);
+ return code;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/error_reporter.h b/tensorflow/contrib/lite/core/api/error_reporter.h
new file mode 100644
index 0000000000..a2f780b003
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter.h
@@ -0,0 +1,45 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+// A functor that reports error to supporting system. Invoked similar to
+// printf.
+//
+// Usage:
+// ErrorReporter foo;
+// foo.Report("test %d", 5);
+// or
+// va_list args;
+// foo.Report("test %d", args); // where args is va_list
+//
+// Subclass ErrorReporter to provide another reporting destination.
+// For example, if you have a GUI program, you might redirect to a buffer
+// that drives a GUI error log box.
+class ErrorReporter {
+ public:
+ virtual ~ErrorReporter() {}
+ virtual int Report(const char* format, va_list args) = 0;
+ int Report(const char* format, ...);
+ int ReportError(void*, const char* format, ...);
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/core/api/error_reporter_test.cc b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
new file mode 100644
index 0000000000..0463eee6be
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/error_reporter_test.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+#include <cstdio>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ int Report(const char* format, va_list args) override {
+ vsnprintf(buffer_, kBufferSize, format, args);
+ return 0;
+ }
+ char* GetBuffer() { return buffer_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+};
+
+TEST(ErrorReporter, TestReport) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ reporter->Report("Error: %d", 23);
+ EXPECT_EQ(0, strcmp(mock_reporter.GetBuffer(), "Error: 23"));
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
new file mode 100644
index 0000000000..1420fbcdc6
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -0,0 +1,622 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstdlib>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+
+namespace {
+
+// Copies the contents from the flatbuffer int vector `flatbuffer` into the
+// int array `buffer`. `flat_vector` and `buffer` represent the same
+// configuration operation for a given operation.
+void FlatBufferIntVectorToArray(int max_size_of_buffer,
+ const flatbuffers::Vector<int32_t>* flat_vector,
+ int* buffer, ErrorReporter* error_reporter) {
+ if (!flat_vector) {
+ error_reporter->Report("Input array not provided for operation.\n");
+ } else {
+ int num_dimensions = flat_vector->Length();
+ if (num_dimensions > max_size_of_buffer / sizeof(int)) {
+ error_reporter->Report(
+ "Found too many dimensions in the operation's input array.\n");
+ } else {
+ for (int i = 0; i < num_dimensions; ++i) {
+ buffer[i] = flat_vector->Get(i);
+ }
+ }
+ }
+}
+
+// Allocate a structure using malloc, but make sure the structure is a POD
+// structure that doesn't require constructors to run. The reason we do this,
+// is that Interpreter's C extension part will take ownership so destructors
+// will not be run during deallocation.
+template <class T>
+T* MallocPOD() {
+ static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+ return static_cast<T*>(malloc(sizeof(T)));
+}
+
+} // namespace
+
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter) {
+ switch (tensor_type) {
+ case TensorType_FLOAT32:
+ *type = kTfLiteFloat32;
+ break;
+ case TensorType_INT16:
+ *type = kTfLiteInt16;
+ break;
+ case TensorType_INT32:
+ *type = kTfLiteInt32;
+ break;
+ case TensorType_UINT8:
+ *type = kTfLiteUInt8;
+ break;
+ case TensorType_INT64:
+ *type = kTfLiteInt64;
+ break;
+ case TensorType_STRING:
+ *type = kTfLiteString;
+ break;
+ case TensorType_BOOL:
+ *type = kTfLiteBool;
+ break;
+ case TensorType_COMPLEX64:
+ *type = kTfLiteComplex64;
+ break;
+ default:
+ error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
+ EnumNameTensorType(tensor_type), tensor_type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
+// need to be released by calling `free`.`
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data) {
+ auto parse_padding = [](Padding padding) {
+ switch (padding) {
+ case Padding_SAME:
+ return kTfLitePaddingSame;
+ case Padding_VALID:
+ return kTfLitePaddingValid;
+ }
+ return kTfLitePaddingUnknown;
+ };
+ auto parse_activation = [](ActivationFunctionType activation) {
+ switch (activation) {
+ case ActivationFunctionType_NONE:
+ return kTfLiteActNone;
+ case ActivationFunctionType_RELU:
+ return kTfLiteActRelu;
+ case ActivationFunctionType_RELU_N1_TO_1:
+ return kTfLiteActRelu1;
+ case ActivationFunctionType_RELU6:
+ return kTfLiteActRelu6;
+ case ActivationFunctionType_TANH:
+ return kTfLiteActTanh;
+ case ActivationFunctionType_SIGN_BIT:
+ return kTfLiteActSignBit;
+ }
+ return kTfLiteActNone;
+ };
+ auto parseLSHProjectionType = [](LSHProjectionType type) {
+ switch (type) {
+ case LSHProjectionType_SPARSE:
+ return kTfLiteLshProjectionSparse;
+ case LSHProjectionType_DENSE:
+ return kTfLiteLshProjectionDense;
+ default:
+ return kTfLiteLshProjectionUnknown;
+ }
+ };
+ auto parseCombinerType = [](CombinerType type) {
+ switch (type) {
+ case CombinerType_MEAN:
+ return kTfLiteCombinerTypeMean;
+ case CombinerType_SQRTN:
+ return kTfLiteCombinerTypeSqrtn;
+ case CombinerType_SUM:
+ default:
+ return kTfLiteCombinerTypeSum;
+ }
+ };
+
+ *builtin_data = nullptr;
+ switch (op_type) {
+ case BuiltinOperator_CONV_2D: {
+ TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
+ if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+
+ params->dilation_width_factor = conv_params->dilation_w_factor();
+ params->dilation_height_factor = conv_params->dilation_h_factor();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CAST: {
+ TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
+ if (auto* schema_params = op->builtin_options_as_CastOptions()) {
+ auto in_status =
+ ConvertTensorType(schema_params->in_data_type(),
+ &params->in_data_type, error_reporter);
+ auto out_status =
+ ConvertTensorType(schema_params->out_data_type(),
+ &params->out_data_type, error_reporter);
+ if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
+ free(params);
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LSH_PROJECTION: {
+ TfLiteLSHProjectionParams* params =
+ MallocPOD<TfLiteLSHProjectionParams>();
+ if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
+ params->type = parseLSHProjectionType(lshParams->type());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_AVERAGE_POOL_2D:
+ case BuiltinOperator_MAX_POOL_2D:
+ case BuiltinOperator_L2_POOL_2D: {
+ TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
+ if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
+ params->padding = parse_padding(pool_params->padding());
+ params->stride_width = pool_params->stride_w();
+ params->stride_height = pool_params->stride_h();
+ params->filter_width = pool_params->filter_width();
+ params->filter_height = pool_params->filter_height();
+ params->activation =
+ parse_activation(pool_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DEPTHWISE_CONV_2D: {
+ TfLiteDepthwiseConvParams* params =
+ MallocPOD<TfLiteDepthwiseConvParams>();
+ if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
+ params->padding = parse_padding(conv_params->padding());
+ params->stride_width = conv_params->stride_w();
+ params->stride_height = conv_params->stride_h();
+ params->depth_multiplier = conv_params->depth_multiplier();
+ params->activation =
+ parse_activation(conv_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SVDF: {
+ TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
+ if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
+ params->rank = svdf_params->rank();
+ params->activation =
+ parse_activation(svdf_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
+ TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
+ if (auto* sequence_rnn_params =
+ op->builtin_options_as_SequenceRNNOptions()) {
+ params->activation =
+ parse_activation(sequence_rnn_params->fused_activation_function());
+ params->time_major = sequence_rnn_params->time_major();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RNN: {
+ TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
+ if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
+ params->activation =
+ parse_activation(rnn_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
+ TfLiteEmbeddingLookupSparseParams* params =
+ MallocPOD<TfLiteEmbeddingLookupSparseParams>();
+ if (auto* embedding_params =
+ op->builtin_options_as_EmbeddingLookupSparseOptions()) {
+ params->combiner = parseCombinerType(embedding_params->combiner());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_FULLY_CONNECTED: {
+ TfLiteFullyConnectedParams* params =
+ MallocPOD<TfLiteFullyConnectedParams>();
+ if (auto* fully_connected_params =
+ op->builtin_options_as_FullyConnectedOptions()) {
+ params->activation = parse_activation(
+ fully_connected_params->fused_activation_function());
+ switch (fully_connected_params->weights_format()) {
+ case FullyConnectedOptionsWeightsFormat_DEFAULT:
+ params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
+ break;
+ case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
+ params->weights_format =
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
+ break;
+ default:
+ error_reporter->Report("Unhandled fully-connected weights format.");
+ return kTfLiteError;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_HASHTABLE_LOOKUP:
+ // no-op.
+ break;
+ case BuiltinOperator_SOFTMAX: {
+ TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
+ if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
+ params->beta = softmax_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_CONCATENATION: {
+ TfLiteConcatenationParams* params =
+ MallocPOD<TfLiteConcatenationParams>();
+ if (auto* concatenation_params =
+ op->builtin_options_as_ConcatenationOptions()) {
+ params->activation =
+ parse_activation(concatenation_params->fused_activation_function());
+ params->axis = concatenation_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MUL: {
+ auto* params = MallocPOD<TfLiteMulParams>();
+ if (auto* schema_params = op->builtin_options_as_MulOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ADD: {
+ auto* params = MallocPOD<TfLiteAddParams>();
+ if (auto* schema_params = op->builtin_options_as_AddOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DIV: {
+ auto* params = MallocPOD<TfLiteDivParams>();
+ if (auto* schema_params = op->builtin_options_as_DivOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SUB: {
+ auto* params = MallocPOD<TfLiteSubParams>();
+ if (auto* schema_params = op->builtin_options_as_SubOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_L2_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteL2NormParams>();
+ if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
+ params->activation =
+ parse_activation(schema_params->fused_activation_function());
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
+ auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_LocalResponseNormalizationOptions()) {
+ params->radius = schema_params->radius();
+ params->bias = schema_params->bias();
+ params->alpha = schema_params->alpha();
+ params->beta = schema_params->beta();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
+ case BuiltinOperator_LSTM: {
+ TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
+ if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
+ params->activation =
+ parse_activation(lstm_params->fused_activation_function());
+ params->cell_clip = lstm_params->cell_clip();
+ params->proj_clip = lstm_params->proj_clip();
+ switch (lstm_params->kernel_type()) {
+ case LSTMKernelType_FULL:
+ params->kernel_type = kTfLiteLSTMFullKernel;
+ break;
+ case LSTMKernelType_BASIC:
+ params->kernel_type = kTfLiteLSTMBasicKernel;
+ break;
+ }
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESIZE_BILINEAR: {
+ auto* params = MallocPOD<TfLiteResizeBilinearParams>();
+ if (auto* schema_params =
+ op->builtin_options_as_ResizeBilinearOptions()) {
+ params->align_corners = schema_params->align_corners();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_RESHAPE: {
+ auto* params = MallocPOD<TfLiteReshapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
+ auto* new_shape = schema_params->new_shape();
+ FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
+ params->shape, error_reporter);
+ params->num_dimensions = new_shape->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SKIP_GRAM: {
+ TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
+ if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
+ params->ngram_size = skip_gram_params->ngram_size();
+ params->max_skip_size = skip_gram_params->max_skip_size();
+ params->include_all_ngrams = skip_gram_params->include_all_ngrams();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPACE_TO_DEPTH: {
+ auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
+ if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
+ params->block_size = schema_params->block_size();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_GATHER: {
+ TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
+ params->axis = 0;
+ if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
+ params->axis = gather_params->axis();
+ }
+
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_MEAN:
+ case BuiltinOperator_REDUCE_MAX:
+ case BuiltinOperator_REDUCE_MIN:
+ case BuiltinOperator_REDUCE_PROD:
+ case BuiltinOperator_REDUCE_ANY:
+ case BuiltinOperator_SUM: {
+ auto* params = MallocPOD<TfLiteReducerParams>();
+ if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
+ params->keep_dims = schema_params->keep_dims();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPLIT: {
+ auto* params = MallocPOD<TfLiteSplitParams>();
+ if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
+ params->num_splits = schema_params->num_splits();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SQUEEZE: {
+ auto* params = MallocPOD<TfLiteSqueezeParams>();
+ if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
+ const auto& squeeze_dims = schema_params->squeeze_dims();
+ FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
+ params->squeeze_dims, error_reporter);
+ params->num_squeeze_dims = squeeze_dims->Length();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_STRIDED_SLICE: {
+ auto* params = MallocPOD<TfLiteStridedSliceParams>();
+ if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
+ params->begin_mask = schema_params->begin_mask();
+ params->end_mask = schema_params->end_mask();
+ params->ellipsis_mask = schema_params->ellipsis_mask();
+ params->new_axis_mask = schema_params->new_axis_mask();
+ params->shrink_axis_mask = schema_params->shrink_axis_mask();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MAX: {
+ auto* params = MallocPOD<TfLiteArgMaxParams>();
+ if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ARG_MIN: {
+ auto* params = MallocPOD<TfLiteArgMinParams>();
+ if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
+ ConvertTensorType(schema_params->output_type(), &params->output_type,
+ error_reporter);
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_TRANSPOSE_CONV: {
+ TfLiteTransposeConvParams* params =
+ MallocPOD<TfLiteTransposeConvParams>();
+ if (auto* transpose_conv_params =
+ op->builtin_options_as_TransposeConvOptions()) {
+ params->padding = parse_padding(transpose_conv_params->padding());
+ params->stride_width = transpose_conv_params->stride_w();
+ params->stride_height = transpose_conv_params->stride_h();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SPARSE_TO_DENSE: {
+ TfLiteSparseToDenseParams* params =
+ MallocPOD<TfLiteSparseToDenseParams>();
+ if (auto* sparse_to_dense_params =
+ op->builtin_options_as_SparseToDenseOptions()) {
+ params->validate_indices = sparse_to_dense_params->validate_indices();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_SHAPE: {
+ auto* params = MallocPOD<TfLiteShapeParams>();
+ if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
+ ConvertTensorType(schema_params->out_type(), &params->out_type,
+ error_reporter);
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_PACK: {
+ TfLitePackParams* params = MallocPOD<TfLitePackParams>();
+ if (auto* pack_params = op->builtin_options_as_PackOptions()) {
+ params->values_count = pack_params->values_count();
+ params->axis = pack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_DELEGATE: {
+ // TODO(ycling): Revisit when supporting saving delegated models.
+ error_reporter->Report("DELEGATE op shouldn't exist in model.");
+ return kTfLiteError;
+ }
+ case BuiltinOperator_FAKE_QUANT: {
+ auto* params = MallocPOD<TfLiteFakeQuantParams>();
+ if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
+ params->min = schema_params->min();
+ params->max = schema_params->max();
+ params->num_bits = schema_params->num_bits();
+ params->narrow_range = schema_params->narrow_range();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_ONE_HOT: {
+ auto* params = MallocPOD<TfLiteOneHotParams>();
+ if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
+ params->axis = schema_params->axis();
+ }
+ *builtin_data = static_cast<void*>(params);
+ break;
+ }
+ case BuiltinOperator_UNPACK: {
+ TfLiteUnpackParams* params = MallocPOD<TfLiteUnpackParams>();
+ if (auto* unpack_params = op->builtin_options_as_UnpackOptions()) {
+ params->num = unpack_params->num();
+ params->axis = unpack_params->axis();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
+ // Below are the ops with no builtin_data strcture.
+ case BuiltinOperator_BATCH_TO_SPACE_ND:
+ // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
+ // ok for now, since there is no call implementation either.
+ case BuiltinOperator_CALL:
+ case BuiltinOperator_CONCAT_EMBEDDINGS:
+ case BuiltinOperator_CUSTOM:
+ case BuiltinOperator_DEQUANTIZE:
+ case BuiltinOperator_EMBEDDING_LOOKUP:
+ case BuiltinOperator_EQUAL:
+ case BuiltinOperator_EXP:
+ case BuiltinOperator_EXPAND_DIMS:
+ case BuiltinOperator_FLOOR:
+ case BuiltinOperator_GREATER:
+ case BuiltinOperator_GREATER_EQUAL:
+ case BuiltinOperator_LESS:
+ case BuiltinOperator_LESS_EQUAL:
+ case BuiltinOperator_LOG:
+ case BuiltinOperator_LOGISTIC:
+ case BuiltinOperator_LOG_SOFTMAX:
+ case BuiltinOperator_MAXIMUM:
+ case BuiltinOperator_MINIMUM:
+ case BuiltinOperator_NEG:
+ case BuiltinOperator_NOT_EQUAL:
+ case BuiltinOperator_PAD:
+ case BuiltinOperator_PADV2:
+ case BuiltinOperator_PRELU:
+ case BuiltinOperator_RELU:
+ case BuiltinOperator_RELU6:
+ case BuiltinOperator_RELU_N1_TO_1:
+ case BuiltinOperator_RSQRT:
+ case BuiltinOperator_SELECT:
+ case BuiltinOperator_SIN:
+ case BuiltinOperator_SLICE:
+ case BuiltinOperator_SPACE_TO_BATCH_ND:
+ case BuiltinOperator_SQRT:
+ case BuiltinOperator_TANH:
+ case BuiltinOperator_TILE:
+ case BuiltinOperator_TOPK_V2:
+ case BuiltinOperator_TRANSPOSE:
+ case BuiltinOperator_POW:
+ case BuiltinOperator_LOGICAL_OR:
+ case BuiltinOperator_LOGICAL_AND:
+ case BuiltinOperator_LOGICAL_NOT:
+ case BuiltinOperator_FLOOR_DIV:
+ break;
+ }
+ return kTfLiteOk;
+} // NOLINT[readability/fn_size]
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
new file mode 100644
index 0000000000..4dec6f9cfc
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.h
@@ -0,0 +1,48 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
+
+// These functions transform codes and data structures that are defined in the
+// flatbuffer serialization format into in-memory values that are used by the
+// runtime API and interpreter.
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Parse the appropriate data out of the op.
+//
+// This handles builtin data explicitly as there are flatbuffer schemas.
+// If it returns kTfLiteOk, it passes the data out with `builtin_data`. The
+// calling function has to pass in an allocator object, and this allocator
+// will be called to reserve space for the output data. If the calling
+// function's allocator reserves memory on the heap, then it's the calling
+// function's responsibility to free it.
+// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
+TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
+ ErrorReporter* error_reporter, void** builtin_data);
+
+// Converts the tensor data type used in the flat buffer to the representation
+// used by the runtime.
+TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
+ ErrorReporter* error_reporter);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_FLATBUFFER_CONVERSIONS_H_
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
new file mode 100644
index 0000000000..b12bdf43b2
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions_test.cc
@@ -0,0 +1,104 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+
+namespace tflite {
+namespace {
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(FlatbufferConversions, TestParseOpDataConv) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> conv_options =
+ CreateConv2DOptions(builder, Padding_SAME, 1, 2,
+ ActivationFunctionType_RELU, 3, 4)
+ .Union();
+ flatbuffers::Offset<Operator> conv_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_Conv2DOptions, conv_options,
+ nullptr, CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const Operator* conv_op = flatbuffers::GetRoot<Operator>(conv_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(conv_op, BuiltinOperator_CONV_2D, reporter,
+ &output_data));
+ EXPECT_NE(nullptr, output_data);
+ TfLiteConvParams* params = reinterpret_cast<TfLiteConvParams*>(output_data);
+ EXPECT_EQ(kTfLitePaddingSame, params->padding);
+ EXPECT_EQ(1, params->stride_width);
+ EXPECT_EQ(2, params->stride_height);
+ EXPECT_EQ(kTfLiteActRelu, params->activation);
+ EXPECT_EQ(3, params->dilation_width_factor);
+ EXPECT_EQ(4, params->dilation_height_factor);
+ free(output_data);
+}
+
+TEST(FlatbufferConversions, TestParseOpDataCustom) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<void> null_options;
+ flatbuffers::Offset<Operator> custom_offset = CreateOperatorDirect(
+ builder, 0, nullptr, nullptr, BuiltinOptions_NONE, null_options, nullptr,
+ CustomOptionsFormat_FLEXBUFFERS, nullptr);
+ builder.Finish(custom_offset);
+ void* custom_pointer = builder.GetBufferPointer();
+ const Operator* custom_op = flatbuffers::GetRoot<Operator>(custom_pointer);
+ void* output_data = nullptr;
+ EXPECT_EQ(kTfLiteOk, ParseOpData(custom_op, BuiltinOperator_CUSTOM, reporter,
+ &output_data));
+ EXPECT_EQ(nullptr, output_data);
+}
+
+TEST(FlatbufferConversions, TestConvertTensorType) {
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+ TfLiteType type;
+ EXPECT_EQ(kTfLiteOk, ConvertTensorType(TensorType_FLOAT32, &type, reporter));
+ EXPECT_EQ(kTfLiteFloat32, type);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.cc b/tensorflow/contrib/lite/core/api/op_resolver.cc
new file mode 100644
index 0000000000..55ee924843
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+namespace tflite {
+
+TfLiteStatus GetRegistrationFromOpCode(
+ const OperatorCode* opcode, const OpResolver& op_resolver,
+ ErrorReporter* error_reporter, const TfLiteRegistration** registration) {
+ TfLiteStatus status = kTfLiteOk;
+ *registration = nullptr;
+ auto builtin_code = opcode->builtin_code();
+ int version = opcode->version();
+
+ if (builtin_code > BuiltinOperator_MAX ||
+ builtin_code < BuiltinOperator_MIN) {
+ error_reporter->Report(
+ "Op builtin_code out of range: %d. Are you using old TFLite binary "
+ "with newer model?",
+ builtin_code);
+ status = kTfLiteError;
+ } else if (builtin_code != BuiltinOperator_CUSTOM) {
+ *registration = op_resolver.FindOp(builtin_code, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find op for builtin opcode '%s' version '%d'\n",
+ EnumNameBuiltinOperator(builtin_code), version);
+ status = kTfLiteError;
+ }
+ } else if (!opcode->custom_code()) {
+ error_reporter->Report(
+ "Operator with CUSTOM builtin_code has no custom_code.\n");
+ status = kTfLiteError;
+ } else {
+ const char* name = opcode->custom_code()->c_str();
+ *registration = op_resolver.FindOp(name, version);
+ if (*registration == nullptr) {
+ error_reporter->Report(
+ "Didn't find custom op for name '%s' with version %d\n", name,
+ version);
+ status = kTfLiteError;
+ }
+ }
+ return status;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/op_resolver.h b/tensorflow/contrib/lite/core/api/op_resolver.h
new file mode 100644
index 0000000000..5f5e6b2736
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Abstract interface that returns TfLiteRegistrations given op codes or custom
+// op names. This is the mechanism that ops being referenced in the flatbuffer
+// model are mapped to executable function pointers (TfLiteRegistrations).
+class OpResolver {
+ public:
+ // Finds the op registration for a builtin operator by enum code.
+ virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const = 0;
+ // Finds the op registration of a custom operator by op name.
+ virtual const TfLiteRegistration* FindOp(const char* op,
+ int version) const = 0;
+ virtual ~OpResolver() {}
+};
+
+// Handles the logic for converting between an OperatorCode structure extracted
+// from a flatbuffer and information about a registered operator implementation.
+TfLiteStatus GetRegistrationFromOpCode(const OperatorCode* opcode,
+ const OpResolver& op_resolver,
+ ErrorReporter* error_reporter,
+ const TfLiteRegistration** registration);
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_CORE_API_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/core/api/op_resolver_test.cc b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
new file mode 100644
index 0000000000..167463110e
--- /dev/null
+++ b/tensorflow/contrib/lite/core/api/op_resolver_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+
+#include <cstring>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+class MockOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
+ if (op == BuiltinOperator_CONV_2D) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ if (strcmp(op, "mock_custom") == 0) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+class MockErrorReporter : public ErrorReporter {
+ public:
+ MockErrorReporter() : buffer_size_(0) {}
+ int Report(const char* format, va_list args) override {
+ buffer_size_ = vsnprintf(buffer_, kBufferSize, format, args);
+ return buffer_size_;
+ }
+ char* GetBuffer() { return buffer_; }
+ int GetBufferSize() { return buffer_size_; }
+
+ private:
+ static constexpr int kBufferSize = 256;
+ char buffer_[kBufferSize];
+ int buffer_size_;
+};
+
+} // namespace
+
+TEST(OpResolver, TestResolver) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+
+ const TfLiteRegistration* registration =
+ resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp(BuiltinOperator_CAST, 0);
+ EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("mock_custom", 0);
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp("nonexistent_custom", 0);
+ EXPECT_EQ(nullptr, registration);
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeConv) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CONV_2D, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCast) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset =
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CAST, nullptr, 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "mock_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteOk, GetRegistrationFromOpCode(conv_code, *resolver, reporter,
+ &registration));
+ EXPECT_NE(nullptr, registration);
+ EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+ EXPECT_EQ(0, mock_reporter.GetBufferSize());
+}
+
+TEST(OpResolver, TestGetRegistrationFromOpCodeNonexistentCustom) {
+ MockOpResolver mock_resolver;
+ OpResolver* resolver = &mock_resolver;
+ MockErrorReporter mock_reporter;
+ ErrorReporter* reporter = &mock_reporter;
+
+ flatbuffers::FlatBufferBuilder builder;
+ flatbuffers::Offset<OperatorCode> conv_offset = CreateOperatorCodeDirect(
+ builder, BuiltinOperator_CUSTOM, "nonexistent_custom", 0);
+ builder.Finish(conv_offset);
+ void* conv_pointer = builder.GetBufferPointer();
+ const OperatorCode* conv_code =
+ flatbuffers::GetRoot<OperatorCode>(conv_pointer);
+ const TfLiteRegistration* registration = nullptr;
+ EXPECT_EQ(kTfLiteError, GetRegistrationFromOpCode(conv_code, *resolver,
+ reporter, &registration));
+ EXPECT_EQ(nullptr, registration);
+ EXPECT_NE(0, mock_reporter.GetBufferSize());
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/delegates/eager/BUILD b/tensorflow/contrib/lite/delegates/eager/BUILD
index 8abc828578..bf5d91899c 100644
--- a/tensorflow/contrib/lite/delegates/eager/BUILD
+++ b/tensorflow/contrib/lite/delegates/eager/BUILD
@@ -16,6 +16,7 @@ cc_library(
deps = [
":util",
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
@@ -54,6 +55,7 @@ cc_library(
":delegate_data",
":kernel",
":util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:util",
] + select({
@@ -104,6 +106,7 @@ tf_cc_test(
":delegate_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
@@ -117,6 +120,7 @@ cc_library(
":delegate_data",
":util",
"@flatbuffers",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite:string",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -132,6 +136,8 @@ cc_library(
],
"//conditions:default": [
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:tensorflow",
],
}),
)
@@ -168,6 +174,7 @@ cc_library(
hdrs = ["util.h"],
deps = [
"//tensorflow/c:c_api_internal",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
diff --git a/tensorflow/contrib/lite/delegates/eager/buffer_map.h b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
index a28329ae7d..aaaa045840 100644
--- a/tensorflow/contrib/lite/delegates/eager/buffer_map.h
+++ b/tensorflow/contrib/lite/delegates/eager/buffer_map.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate.h b/tensorflow/contrib/lite/delegates/eager/delegate.h
index 6d15ba47dc..70f3c15af4 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate.h
+++ b/tensorflow/contrib/lite/delegates/eager/delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
index b3a0ffcec1..def063309f 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_data_test.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/testing/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
index eb47f46c0b..984f8bbc98 100644
--- a/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/delegate_test.cc
@@ -72,6 +72,26 @@ TEST_F(DelegateTest, FullGraph) {
ASSERT_THAT(GetShape(8), ElementsAre(2, 1));
ASSERT_THAT(GetValues(8), ElementsAre(14.52f, 38.72f));
+ ASSERT_EQ(GetType(8), kTfLiteFloat32);
+}
+
+TEST_F(DelegateTest, NonFloatTypeInference) {
+ AddTensors(3, {0, 1}, {2}, kTfLiteInt32, {2});
+
+ AddTfOp(testing::kAdd, {0, 1}, {2});
+
+ ConfigureDelegate();
+
+ SetShape(0, {2, 2});
+ SetTypedValues<int>(0, {1, 2, 3, 4});
+ SetShape(1, {2, 2});
+ SetTypedValues<int>(1, {4, 3, 2, 1});
+
+ ASSERT_TRUE(Invoke());
+
+ ASSERT_THAT(GetShape(2), ElementsAre(2, 2));
+ ASSERT_THAT(GetTypedValues<int>(2), ElementsAre(5, 5, 5, 5));
+ ASSERT_EQ(GetType(2), kTfLiteInt32);
}
TEST_F(DelegateTest, MixedGraph) {
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.cc b/tensorflow/contrib/lite/delegates/eager/kernel.cc
index 1082b78725..274c3c082a 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.cc
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/delegates/eager/kernel.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/eager/delegate_data.h"
#include "tensorflow/contrib/lite/delegates/eager/util.h"
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
// Note: this is part of TF Lite's Eager delegation code which is to be
// completed soon.
@@ -189,6 +190,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
}
}
+ // Fill NodeDef with defaults if it's a valid op.
+ const tensorflow::OpRegistrationData* op_reg_data;
+ auto tf_status = tensorflow::OpRegistry::Global()->LookUp(
+ node_data.nodedef.op(), &op_reg_data);
+ if (tf_status.ok()) {
+ AddDefaultsToNodeDef(op_reg_data->op_def, &node_data.nodedef);
+ }
+
for (auto input_index : TfLiteIntArrayView(node->inputs)) {
node_data.inputs.push_back(input_index);
}
@@ -269,7 +278,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* tensor = &context->tensors[tensor_index];
TF_LITE_ENSURE_OK(
context,
- CopyShape(context, buffer_map->GetTensor(tensor_index), tensor));
+ CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
tensor->buffer_handle = tensor_index;
tensor->data_is_stale = true;
}
diff --git a/tensorflow/contrib/lite/delegates/eager/kernel.h b/tensorflow/contrib/lite/delegates/eager/kernel.h
index 100672c82d..2478abccaa 100644
--- a/tensorflow/contrib/lite/delegates/eager/kernel.h
+++ b/tensorflow/contrib/lite/delegates/eager/kernel.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_KERNEL_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace eager {
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.cc b/tensorflow/contrib/lite/delegates/eager/test_util.cc
index 26d96acc82..8584999ace 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/delegates/eager/test_util.h"
#include "absl/memory/memory.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
@@ -25,19 +25,6 @@ namespace testing {
bool EagerModelTest::Invoke() { return interpreter_->Invoke() == kTfLiteOk; }
-void EagerModelTest::SetValues(int tensor_index,
- const std::vector<float>& values) {
- float* v = interpreter_->typed_tensor<float>(tensor_index);
- for (float f : values) {
- *v++ = f;
- }
-}
-
-std::vector<float> EagerModelTest::GetValues(int tensor_index) {
- TfLiteTensor* o = interpreter_->tensor(tensor_index);
- return std::vector<float>(o->data.f, o->data.f + o->bytes / sizeof(float));
-}
-
void EagerModelTest::SetShape(int tensor_index,
const std::vector<int>& values) {
ASSERT_EQ(interpreter_->ResizeInputTensor(tensor_index, values), kTfLiteOk);
@@ -54,13 +41,21 @@ std::vector<int> EagerModelTest::GetShape(int tensor_index) {
return result;
}
+TfLiteType EagerModelTest::GetType(int tensor_index) {
+ return interpreter_->tensor(tensor_index)->type;
+}
+
void EagerModelTest::AddTensors(int num_tensors, const std::vector<int>& inputs,
const std::vector<int>& outputs,
- const TfLiteType& type,
- const std::vector<int>& dims) {
+ TfLiteType type, const std::vector<int>& dims) {
interpreter_->AddTensors(num_tensors);
for (int i = 0; i < num_tensors; ++i) {
TfLiteQuantizationParams quant;
+ // Suppress explicit output type specification to ensure type inference
+ // works properly.
+ if (std::find(outputs.begin(), outputs.end(), i) != outputs.end()) {
+ type = kTfLiteFloat32;
+ }
CHECK_EQ(interpreter_->SetTensorParametersReadWrite(i, type,
/*name=*/"",
/*dims=*/dims, quant),
@@ -101,18 +96,26 @@ void EagerModelTest::AddTfOp(TfOpType op, const std::vector<int>& inputs,
return " attr{ key: '" + key + "' value {" + value + "}}";
};
+ // Crude type attribution, will need fleshing out as more tests are added.
+ // TODO(b/113613439): Use nodedef string utilities to properly handle
+ // all types.
+ string type_attribute = attr("T", "type: DT_FLOAT");
+ if (interpreter_->tensor(inputs[0])->type == kTfLiteInt32) {
+ type_attribute = attr("T", "type: DT_INT32");
+ }
+
if (op == kUnpack) {
- string attributes = attr("T", "type: DT_FLOAT") + attr("num", "i: 2") +
- attr("axis", "i: 0");
+ string attributes =
+ type_attribute + attr("num", "i: 2") + attr("axis", "i: 0");
AddTfOp("EagerUnpack", "Unpack", attributes, inputs, outputs);
} else if (op == kIdentity) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerIdentity", "Identity", attributes, inputs, outputs);
} else if (op == kAdd) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerAdd", "Add", attributes, inputs, outputs);
} else if (op == kMul) {
- string attributes = attr("T", "type: DT_FLOAT");
+ string attributes = type_attribute;
AddTfOp("EagerMul", "Mul", attributes, inputs, outputs);
} else if (op == kNonExistent) {
AddTfOp("NonExistentOp", "NonExistentOp", "", inputs, outputs);
diff --git a/tensorflow/contrib/lite/delegates/eager/test_util.h b/tensorflow/contrib/lite/delegates/eager/test_util.h
index 0eab9e1135..816db41931 100644
--- a/tensorflow/contrib/lite/delegates/eager/test_util.h
+++ b/tensorflow/contrib/lite/delegates/eager/test_util.h
@@ -44,11 +44,30 @@ class EagerModelTest : public ::testing::Test {
bool Invoke();
+ // Sets the (typed) tensor's values at the given index.
+ template <typename T>
+ void SetTypedValues(int tensor_index, const std::vector<T>& values) {
+ memcpy(interpreter_->typed_tensor<T>(tensor_index), values.data(),
+ values.size() * sizeof(T));
+ }
+
+ // Returns the (typed) tensor's values at the given index.
+ template <typename T>
+ std::vector<T> GetTypedValues(int tensor_index) {
+ const TfLiteTensor* t = interpreter_->tensor(tensor_index);
+ const T* tdata = interpreter_->typed_tensor<T>(tensor_index);
+ return std::vector<T>(tdata, tdata + t->bytes / sizeof(T));
+ }
+
// Sets the tensor's values at the given index.
- void SetValues(int tensor_index, const std::vector<float>& values);
+ void SetValues(int tensor_index, const std::vector<float>& values) {
+ SetTypedValues<float>(tensor_index, values);
+ }
// Returns the tensor's values at the given index.
- std::vector<float> GetValues(int tensor_index);
+ std::vector<float> GetValues(int tensor_index) {
+ return GetTypedValues<float>(tensor_index);
+ }
// Sets the tensor's shape at the given index.
void SetShape(int tensor_index, const std::vector<int>& values);
@@ -56,13 +75,16 @@ class EagerModelTest : public ::testing::Test {
// Returns the tensor's shape at the given index.
std::vector<int> GetShape(int tensor_index);
+ // Returns the tensor's type at the given index.
+ TfLiteType GetType(int tensor_index);
+
const TestErrorReporter& error_reporter() const { return error_reporter_; }
// Adds `num_tensor` tensors to the model. `inputs` contains the indices of
// the input tensors and `outputs` contains the indices of the output
// tensors. All tensors are set to have `type` and `dims`.
void AddTensors(int num_tensors, const std::vector<int>& inputs,
- const std::vector<int>& outputs, const TfLiteType& type,
+ const std::vector<int>& outputs, TfLiteType type,
const std::vector<int>& dims);
// Adds a TFLite Mul op. `inputs` contains the indices of the input tensors
diff --git a/tensorflow/contrib/lite/delegates/eager/util.cc b/tensorflow/contrib/lite/delegates/eager/util.cc
index 4426c653e6..051246bf86 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util.cc
@@ -26,8 +26,17 @@ TfLiteStatus ConvertStatus(TfLiteContext* context,
return kTfLiteOk;
}
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor) {
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor) {
+ tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
+ if (tensor->type == kTfLiteNoType) {
+ context->ReportError(context,
+ "TF Lite does not support TensorFlow data type: %s",
+ DataTypeString(src.dtype()).c_str());
+ return kTfLiteError;
+ }
+
int num_dims = src.dims();
TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
for (int j = 0; j < num_dims; ++j) {
@@ -68,5 +77,28 @@ TF_DataType GetTensorFlowDataType(TfLiteType type) {
}
}
+TfLiteType GetTensorFlowLiteType(TF_DataType type) {
+ switch (type) {
+ case TF_FLOAT:
+ return kTfLiteFloat32;
+ case TF_INT16:
+ return kTfLiteInt16;
+ case TF_INT32:
+ return kTfLiteInt32;
+ case TF_UINT8:
+ return kTfLiteUInt8;
+ case TF_INT64:
+ return kTfLiteInt64;
+ case TF_COMPLEX64:
+ return kTfLiteComplex64;
+ case TF_STRING:
+ return kTfLiteString;
+ case TF_BOOL:
+ return kTfLiteBool;
+ default:
+ return kTfLiteNoType;
+ }
+}
+
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util.h b/tensorflow/contrib/lite/delegates/eager/util.h
index a9407be071..930cb99cb9 100644
--- a/tensorflow/contrib/lite/delegates/eager/util.h
+++ b/tensorflow/contrib/lite/delegates/eager/util.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_EAGER_UTIL_H_
#include "tensorflow/c/c_api_internal.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -28,14 +28,19 @@ namespace eager {
TfLiteStatus ConvertStatus(TfLiteContext* context,
const tensorflow::Status& status);
-// Copies the given shape of the given 'src' into a TF Lite 'tensor'. Logs an
-// error and returns kTfLiteError if the shape can't be converted.
-TfLiteStatus CopyShape(TfLiteContext* context, const tensorflow::Tensor& src,
- TfLiteTensor* tensor);
+// Copies the given shape and type of the TensorFlow 'src' tensor into a TF Lite
+// 'tensor'. Logs an error and returns kTfLiteError if the shape or type can't
+// be converted.
+TfLiteStatus CopyShapeAndType(TfLiteContext* context,
+ const tensorflow::Tensor& src,
+ TfLiteTensor* tensor);
// Returns the TF C API Data type that corresponds to the given TfLiteType.
TF_DataType GetTensorFlowDataType(TfLiteType type);
+// Returns the TfLiteType that corresponds to the given TF C API Data type.
+TfLiteType GetTensorFlowLiteType(TF_DataType);
+
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/eager/util_test.cc b/tensorflow/contrib/lite/delegates/eager/util_test.cc
index 53378a1eaf..aebc91149c 100644
--- a/tensorflow/contrib/lite/delegates/eager/util_test.cc
+++ b/tensorflow/contrib/lite/delegates/eager/util_test.cc
@@ -26,6 +26,7 @@ namespace eager {
namespace {
using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
using tensorflow::Tensor;
using ::testing::ElementsAre;
@@ -71,27 +72,41 @@ TEST(UtilTest, ConvertStatus) {
EXPECT_TRUE(context.error.empty());
}
-TEST(UtilTest, CopyShape) {
+TEST(UtilTest, CopyShapeAndType) {
TestContext context;
context.ReportError = ReportError;
context.ResizeTensor = ResizeTensor;
TfLiteTensor dst;
- EXPECT_EQ(CopyShape(&context, Tensor(), &dst), kTfLiteOk);
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(), &dst), kTfLiteOk);
EXPECT_THAT(context.new_size, ElementsAre(0));
+ EXPECT_EQ(dst.type, kTfLiteFloat32);
- EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1, 2}), &dst), kTfLiteOk);
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1, 2}), &dst),
+ kTfLiteOk);
EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+ EXPECT_EQ(dst.type, kTfLiteFloat32);
- EXPECT_EQ(CopyShape(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_INT32, {1, 2}), &dst),
+ kTfLiteOk);
+ EXPECT_THAT(context.new_size, ElementsAre(1, 2));
+ EXPECT_EQ(dst.type, kTfLiteInt32);
+
+ EXPECT_EQ(CopyShapeAndType(&context, Tensor(DT_FLOAT, {1LL << 44, 2}), &dst),
kTfLiteError);
EXPECT_EQ(context.error,
"Dimension value in TensorFlow shape is larger than supported by "
"TF Lite");
+
+ EXPECT_EQ(
+ CopyShapeAndType(&context, Tensor(tensorflow::DT_HALF, {1, 2}), &dst),
+ kTfLiteError);
+ EXPECT_EQ(context.error,
+ "TF Lite does not support TensorFlow data type: half");
}
-TEST(UtilTest, TypeConversions) {
+TEST(UtilTest, TypeConversionsFromTFLite) {
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteNoType));
EXPECT_EQ(TF_FLOAT, GetTensorFlowDataType(kTfLiteFloat32));
EXPECT_EQ(TF_INT16, GetTensorFlowDataType(kTfLiteInt16));
@@ -103,6 +118,19 @@ TEST(UtilTest, TypeConversions) {
EXPECT_EQ(TF_BOOL, GetTensorFlowDataType(kTfLiteBool));
}
+TEST(UtilTest, TypeConversionsFromTensorFlow) {
+ EXPECT_EQ(kTfLiteFloat32, GetTensorFlowLiteType(TF_FLOAT));
+ EXPECT_EQ(kTfLiteInt16, GetTensorFlowLiteType(TF_INT16));
+ EXPECT_EQ(kTfLiteInt32, GetTensorFlowLiteType(TF_INT32));
+ EXPECT_EQ(kTfLiteUInt8, GetTensorFlowLiteType(TF_UINT8));
+ EXPECT_EQ(kTfLiteInt64, GetTensorFlowLiteType(TF_INT64));
+ EXPECT_EQ(kTfLiteComplex64, GetTensorFlowLiteType(TF_COMPLEX64));
+ EXPECT_EQ(kTfLiteString, GetTensorFlowLiteType(TF_STRING));
+ EXPECT_EQ(kTfLiteBool, GetTensorFlowLiteType(TF_BOOL));
+ EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_RESOURCE));
+ EXPECT_EQ(kTfLiteNoType, GetTensorFlowLiteType(TF_VARIANT));
+}
+
} // namespace
} // namespace eager
} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
index 954955f24b..4e7b2948fb 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/BUILD
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -13,6 +13,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:kernel_util",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
],
@@ -29,6 +30,7 @@ tf_cc_test(
deps = [
":nnapi_delegate",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index e6cc3dd99c..e3eebac4da 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/allocation.h"
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/builtin_ops.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -238,7 +238,7 @@ class NNAPIOpBuilder {
tensor->params.zero_point};
CHECK_NN(context_,
ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
- augmented_inputs_.push_back(ann_index);
+ augmented_outputs_.push_back(ann_index);
*ann_tensor_index_out = ann_index;
return kTfLiteOk;
@@ -370,8 +370,8 @@ struct NNAPIOpMappingArgs {
TfLiteContext* context;
NNAPIOpBuilder* builder;
TfLiteNode* node;
- std::vector<int>* model_state_inputs;
- std::vector<int>* model_state_tfl_outputs;
+ std::vector<int>* model_state_outputs;
+ std::vector<int>* model_state_tfl_inputs;
};
// The kernel that represents the subgraph of TF Lite being run on NN API.
@@ -781,8 +781,7 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinRnn:
// NNAPI only support float32 weights.
- // TODO(miaowang): check the number of inputs before accessing it.
- if (version == 1 &&
+ if (version == 1 && node->inputs->size == 5 &&
context->tensors[node->inputs->data[/*kWeightsTensor*/ 1]].type ==
kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
@@ -790,11 +789,11 @@ class NNAPIDelegateKernel {
// NNAPI need both state_in and state_out.
int ann_index;
mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0],
+ mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4],
&ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kHiddenStateTensor*/ 0]);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs->data[/*kHiddenStateTensor*/ 4]);
auto builtin = reinterpret_cast<TfLiteRNNParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
@@ -806,7 +805,7 @@ class NNAPIDelegateKernel {
break;
case kTfLiteBuiltinSvdf:
// NNAPI only support float32 weights.
- if (version == 1 &&
+ if (version == 1 && node->inputs->size == 5 &&
context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]]
.type == kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
@@ -814,11 +813,13 @@ class NNAPIDelegateKernel {
// NNAPI need both state_in and state_out.
int ann_index;
mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kStateTensor*/ 0],
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 4],
&ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kStateTensor*/ 0]);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 4]);
auto builtin = reinterpret_cast<TfLiteSVDFParams*>(
mapping_args.node->builtin_data);
@@ -833,28 +834,12 @@ class NNAPIDelegateKernel {
case kTfLiteBuiltinLstm:
// NNAPI only support float32 weights.
// TODO(miaowang): add loggings to indicate why the op is rejected.
- if (version == 1 && node->inputs->size == 18 &&
+ if (version == 1 && node->inputs->size == 20 &&
context->tensors[node->inputs
->data[/*kInputToOutputWeightsTensor*/ 4]]
.type == kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
- // NNAPI need both state_in and state_out for cell_state and
- // output_state.
- int ann_index;
- mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0],
- &ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kOutputStateTensor*/ 0]);
- mapping_args.builder->AddStateFloat32Tensor(
- mapping_args.node->outputs->data[/*kCellStateTensor*/ 1],
- &ann_index);
- mapping_args.model_state_inputs->push_back(ann_index);
- mapping_args.model_state_tfl_outputs->push_back(
- mapping_args.node->outputs->data[/*kCellStateTensor*/ 1]);
-
auto builtin = reinterpret_cast<TfLiteLSTMParams*>(
mapping_args.node->builtin_data);
mapping_args.builder->AddScalarInt32Operand(builtin->activation);
@@ -864,6 +849,25 @@ class NNAPIDelegateKernel {
// Current NNAPI implementation requires the sratch_buffer as
// output.
mapping_args.builder->AddAdditionalFloat32OutputTensor(2);
+
+ // NNAPI need both state_in and state_out for cell_state and
+ // output_state.
+ int ann_index;
+ mapping_args.builder->AddStateFloat32Tensor(
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 18],
+ &ann_index);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs
+ ->data[/*kInputActivationStateTensor*/ 18]);
+ mapping_args.builder->AddStateFloat32Tensor(
+ mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19],
+ &ann_index);
+ mapping_args.model_state_outputs->push_back(ann_index);
+ mapping_args.model_state_tfl_inputs->push_back(
+ mapping_args.node->inputs->data[/*kInputCellStateTensor*/ 19]);
+
return ANEURALNETWORKS_LSTM;
};
} else {
@@ -950,12 +954,10 @@ class NNAPIDelegateKernel {
// Set the input tensor buffers. Note: we access tflite tensors using
// absolute indices but NN api indices inputs by relative indices.
int relative_input_index = 0;
- int num_optional_tensors = 0;
size_t input_offset = 0;
for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
if (absolute_input_index == kOptionalTensor) {
- num_optional_tensors++;
continue;
}
TfLiteTensor* tensor = &context->tensors[absolute_input_index];
@@ -989,16 +991,16 @@ class NNAPIDelegateKernel {
// The state_out of previous invocation need to be mapped to state_in of
// current invocation.
- for (size_t i = 0; i < model_state_tfl_outputs_.size(); i++) {
- int state_tensor_idx = model_state_tfl_outputs_[i];
+ for (size_t i = 0; i < model_state_tfl_inputs_.size(); i++) {
+ int state_tensor_idx = model_state_tfl_inputs_[i];
TfLiteTensor* tensor = &context->tensors[state_tensor_idx];
// Here we are using a deep copy for state_in tensors so that we are not
// reading and writing into the same buffer during a invocation.
// TODO(110369471): using double shared buffer to minimize the copies.
- CHECK_NN(context,
- ANeuralNetworksExecution_setInput(
- execution, i + node->inputs->size - num_optional_tensors,
- nullptr, tensor->data.raw, tensor->bytes));
+ CHECK_NN(context, ANeuralNetworksExecution_setOutput(
+ execution, relative_output_index, nullptr,
+ tensor->data.raw, tensor->bytes));
+ relative_output_index++;
}
// Invoke ANN in blocking fashion.
ANeuralNetworksEvent* event = nullptr;
@@ -1030,8 +1032,8 @@ class NNAPIDelegateKernel {
// Track indices we use
OperandMapping operand_mapping_;
- std::vector<int> model_state_inputs_;
- std::vector<int> model_state_tfl_outputs_;
+ std::vector<int> model_state_outputs_;
+ std::vector<int> model_state_tfl_inputs_;
std::unique_ptr<NNMemory> nn_input_memory_;
std::unique_ptr<NNMemory> nn_output_memory_;
@@ -1063,9 +1065,9 @@ class NNAPIDelegateKernel {
}
}
// Get op type and operands
- int nn_op_type = Map(context, reg->builtin_code, reg->version,
- node)({context, &builder, node, &model_state_inputs_,
- &model_state_tfl_outputs_});
+ int nn_op_type = Map(context, reg->builtin_code, reg->version, node)(
+ {context, &builder, node, &model_state_outputs_,
+ &model_state_tfl_inputs_});
// Map outputs to NN API tensor indices.
for (auto output_index : TfLiteIntArrayView(node->outputs)) {
TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
@@ -1098,17 +1100,17 @@ class NNAPIDelegateKernel {
}
}
- // Add state input tensors as model inputs
- for (int i : model_state_inputs_) {
- inputs.push_back(i);
- }
-
size_t total_output_byte_size = 0;
for (int i : TfLiteIntArrayView(output_tensors)) {
outputs.push_back(operand_mapping_.lite_index_to_ann(i));
total_output_byte_size += context->tensors[i].bytes;
}
+ // Add state output tensors as model inputs
+ for (int i : model_state_outputs_) {
+ outputs.push_back(i);
+ }
+
// Tell ANN to declare inputs/outputs
CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
nn_model_.get(), inputs.size(), inputs.data(),
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
index 44cca2fd28..4852b76974 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index 3224b23a0c..4b01aefd6a 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -1773,15 +1773,16 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ BuildInterpreter({{batches_, input_size_}, // input tensor
+ {units_, input_size_}, // weights tensor
+ {units_, units_}, // recurrent weights tensor
+ {units_}, // bias tensor
+ {batches_, units_}}); // hidden state tensor
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -1802,14 +1803,6 @@ class RNNOpModel : public SingleOpModelWithNNAPI {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -1835,7 +1828,6 @@ TEST(NNAPIDelegate, RnnBlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
@@ -1968,16 +1960,20 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
weights_feature_ = AddInput(weights_feature_type);
weights_time_ = AddInput(weights_time_type);
bias_ = AddNullInput();
- state_ = AddOutput(TensorType_FLOAT32);
+ const int num_filters = units * rank;
+ activation_state_ = AddInput(
+ TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
+ /*is_variable=*/true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
CreateSVDFOptions(builder_, rank, ActivationFunctionType_NONE).Union());
BuildInterpreter({
- {batches_, input_size_}, // Input tensor
- {units_ * rank, input_size_}, // weights_feature tensor
- {units_ * rank, memory_size_}, // weights_time tensor
- {units_} // bias tensor
+ {batches_, input_size_}, // input tensor
+ {units_ * rank, input_size_}, // weights_feature tensor
+ {units_ * rank, memory_size_}, // weights_time tensor
+ {units_}, // bias tensor
+ {batches, memory_size * num_filters} // activation_state tensor
});
}
@@ -1996,15 +1992,6 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
PopulateTensor(input_, offset, begin, end);
}
- // Resets the state of SVDF op by filling it with 0's.
- void ResetState() {
- const int zero_buffer_size = rank_ * units_ * batches_ * memory_size_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
// Extracts the output tensor from the SVDF op.
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -2017,7 +2004,7 @@ class BaseSVDFOpModel : public SingleOpModelWithNNAPI {
int weights_feature_;
int weights_time_;
int bias_;
- int state_;
+ int activation_state_;
int output_;
int batches_;
@@ -2081,7 +2068,6 @@ TEST(NNAPIDelegate, SVDFBlackBoxTestRank1) {
-0.10781813, 0.27201805, 0.14324132, -0.23681851, -0.27115166,
-0.01580888, -0.14943552, 0.15465137, 0.09784451, -0.0337657});
- svdf.ResetState();
svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_1, sizeof(svdf_input));
}
@@ -2120,7 +2106,6 @@ TEST(NNAPIDelegate, SVDFBlackBoxTestRank2) {
0.27179423, -0.04710215, 0.31069002, 0.22672787, 0.09580326,
0.08682203, 0.1258215, 0.1851041, 0.29228821, 0.12366763});
- svdf.ResetState();
svdf.VerifyGoldens(svdf_input, svdf_golden_output_rank_2, sizeof(svdf_input));
}
@@ -2192,8 +2177,12 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_output_}}, true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_batch_, n_cell_}}, true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@@ -2271,22 +2260,6 @@ class LSTMOpModel : public SingleOpModelWithNNAPI {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -2495,10 +2468,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -2602,10 +2571,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -3266,10 +3231,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h
index 3c5f805f12..5c20eedc25 100644
--- a/tensorflow/contrib/lite/error_reporter.h
+++ b/tensorflow/contrib/lite/error_reporter.h
@@ -12,43 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
-#include <cstdarg>
-#include "tensorflow/contrib/lite/context.h"
-
-namespace tflite {
-
-// A functor that reports error to supporting system. Invoked similar to
-// printf.
-//
-// Usage:
-// ErrorReporter foo;
-// foo.Report("test %d", 5);
-// or
-// va_list args;
-// foo.Report("test %d", args); // where args is va_list
-//
-// Subclass ErrorReporter to provide another reporting destination.
-// For example, if you have a GUI program, you might redirect to a buffer
-// that drives a GUI error log box.
-class ErrorReporter {
- public:
- virtual ~ErrorReporter();
- virtual int Report(const char* format, va_list args) = 0;
- int Report(const char* format, ...);
- int ReportError(void*, const char* format, ...);
-};
-
-// An error reporter that simplify writes the message to stderr.
-struct StderrReporter : public ErrorReporter {
- int Report(const char* format, va_list args) override;
-};
-
-// Return the default error reporter (output to stderr).
-ErrorReporter* DefaultErrorReporter();
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/examples/android/app/build.gradle b/tensorflow/contrib/lite/examples/android/app/build.gradle
index eb7fd705e1..35e7887852 100644
--- a/tensorflow/contrib/lite/examples/android/app/build.gradle
+++ b/tensorflow/contrib/lite/examples/android/app/build.gradle
@@ -9,7 +9,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -51,10 +50,5 @@ apply from: "download-models.gradle"
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-
- testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile
index 8084307ac7..f460693122 100644
--- a/tensorflow/contrib/lite/examples/ios/camera/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_camera_example'
- pod 'TensorFlowLite', '1.10.0'
+ pod 'TensorFlowLite', '1.10.1'
diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile
index eea7ecb759..ddb77088d9 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/Podfile
+++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile
@@ -2,4 +2,4 @@ platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_simple_example'
- pod 'TensorFlowLite', '1.10.0'
+ pod 'TensorFlowLite', '1.10.1'
diff --git a/tensorflow/contrib/lite/experimental/c/BUILD b/tensorflow/contrib/lite/experimental/c/BUILD
index 8fc07e8eb7..ea4a543252 100644
--- a/tensorflow/contrib/lite/experimental/c/BUILD
+++ b/tensorflow/contrib/lite/experimental/c/BUILD
@@ -78,6 +78,7 @@ cc_test(
data = ["//tensorflow/contrib/lite:testdata/add.bin"],
deps = [
":c_api",
+ "//tensorflow/contrib/lite:context",
"//tensorflow/contrib/lite:kernel_api",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.cc b/tensorflow/contrib/lite/experimental/c/c_api.cc
index a4ab0e8c30..c589cf71ea 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api.cc
@@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/experimental/c/c_api.h"
+#include <memory>
+
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/experimental/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
@@ -29,12 +31,14 @@ extern "C" {
TFL_Model* TFL_NewModel(const void* model_data, size_t model_size) {
auto model = tflite::FlatBufferModel::BuildFromBuffer(
static_cast<const char*>(model_data), model_size);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
TFL_Model* TFL_NewModelFromFile(const char* model_path) {
auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
- return model ? new TFL_Model{std::move(model)} : nullptr;
+ std::shared_ptr<const tflite::FlatBufferModel> shared_model(model.release());
+ return shared_model ? new TFL_Model{std::move(shared_model)} : nullptr;
}
void TFL_DeleteModel(TFL_Model* model) { delete model; }
@@ -72,7 +76,7 @@ TFL_Interpreter* TFL_NewInterpreter(
}
}
- return new TFL_Interpreter{std::move(interpreter)};
+ return new TFL_Interpreter{model->impl, std::move(interpreter)};
}
void TFL_DeleteInterpreter(TFL_Interpreter* interpreter) { delete interpreter; }
@@ -129,6 +133,8 @@ void* TFL_TensorData(const TFL_Tensor* tensor) {
return static_cast<void*>(tensor->data.raw);
}
+const char* TFL_TensorName(const TFL_Tensor* tensor) { return tensor->name; }
+
TFL_Status TFL_TensorCopyFromBuffer(TFL_Tensor* tensor, const void* input_data,
size_t input_data_size) {
if (tensor->bytes != input_data_size) {
diff --git a/tensorflow/contrib/lite/experimental/c/c_api.h b/tensorflow/contrib/lite/experimental/c/c_api.h
index 3757349b55..b429e76870 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api.h
@@ -93,7 +93,8 @@ typedef struct TFL_Interpreter TFL_Interpreter;
// failure.
//
// * `model` must be a valid model instance. The caller retains ownership of the
-// object, and can destroy it immediately after creating the interpreter.
+// object, and can destroy it immediately after creating the interpreter; the
+// interpreter will maintain its own reference to the underlying model data.
// * `optional_options` may be null. The caller retains ownership of the object,
// and can safely destroy it immediately after creating the interpreter.
//
@@ -145,6 +146,11 @@ TFL_CAPI_EXPORT extern int32_t TFL_InterpreterGetOutputTensorCount(
// Returns the tensor associated with the output index.
// REQUIRES: 0 <= input_index < TFL_InterpreterGetOutputTensorCount(tensor)
+//
+// NOTE: The shape and underlying data buffer for output tensors may be not
+// be available until after the output tensor has been both sized and allocated.
+// In general, best practice is to interact with the output tensor *after*
+// calling TFL_InterpreterInvoke().
TFL_CAPI_EXPORT extern const TFL_Tensor* TFL_InterpreterGetOutputTensor(
const TFL_Interpreter* interpreter, int32_t output_index);
@@ -172,12 +178,15 @@ TFL_CAPI_EXPORT extern size_t TFL_TensorByteSize(const TFL_Tensor* tensor);
// Returns a pointer to the underlying data buffer.
//
-// Note: The result may be null if tensors have not yet been allocated, e.g.,
+// NOTE: The result may be null if tensors have not yet been allocated, e.g.,
// if the Tensor has just been created or resized and `TFL_AllocateTensors()`
// has yet to be called, or if the output tensor is dynamically sized and the
// interpreter hasn't been invoked.
TFL_CAPI_EXPORT extern void* TFL_TensorData(const TFL_Tensor* tensor);
+// Returns the (null-terminated) name of the tensor.
+TFL_CAPI_EXPORT extern const char* TFL_TensorName(const TFL_Tensor* tensor);
+
// Copies from the provided input buffer into the tensor's buffer.
// REQUIRES: input_data_size == TFL_TensorByteSize(tensor)
TFL_CAPI_EXPORT extern TFL_Status TFL_TensorCopyFromBuffer(
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_internal.h b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
index c5c612a4c6..60c2e4e2cd 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_internal.h
+++ b/tensorflow/contrib/lite/experimental/c/c_api_internal.h
@@ -24,7 +24,8 @@ limitations under the License.
// not be depended on.
struct TFL_Model {
- std::unique_ptr<tflite::FlatBufferModel> impl;
+ // Sharing is safe as FlatBufferModel is const.
+ std::shared_ptr<const tflite::FlatBufferModel> impl;
};
struct TFL_InterpreterOptions {
@@ -35,6 +36,9 @@ struct TFL_InterpreterOptions {
};
struct TFL_Interpreter {
+ // Taking a reference to the (const) model data avoids lifetime-related issues
+ // and complexity with the TFL_Model's existence.
+ std::shared_ptr<const tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> impl;
};
diff --git a/tensorflow/contrib/lite/experimental/c/c_api_test.cc b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
index a631dae890..649dac8d1a 100644
--- a/tensorflow/contrib/lite/experimental/c/c_api_test.cc
+++ b/tensorflow/contrib/lite/experimental/c/c_api_test.cc
@@ -55,6 +55,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(input_tensor), 1);
EXPECT_EQ(TFL_TensorDim(input_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(input_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(input_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(input_tensor), "input");
std::array<float, 2> input = {1.f, 3.f};
ASSERT_EQ(TFL_TensorCopyFromBuffer(input_tensor, input.data(),
@@ -70,6 +72,8 @@ TEST(CApiSimple, Smoke) {
EXPECT_EQ(TFL_TensorNumDims(output_tensor), 1);
EXPECT_EQ(TFL_TensorDim(output_tensor, 0), 2);
EXPECT_EQ(TFL_TensorByteSize(output_tensor), sizeof(float) * 2);
+ EXPECT_NE(TFL_TensorData(output_tensor), nullptr);
+ EXPECT_STREQ(TFL_TensorName(output_tensor), "output");
std::array<float, 2> output;
ASSERT_EQ(TFL_TensorCopyToBuffer(output_tensor, output.data(),
diff --git a/tensorflow/contrib/lite/experimental/kernels/BUILD b/tensorflow/contrib/lite/experimental/kernels/BUILD
index 9c06c4ebd9..4786cc62f9 100644
--- a/tensorflow/contrib/lite/experimental/kernels/BUILD
+++ b/tensorflow/contrib/lite/experimental/kernels/BUILD
@@ -53,6 +53,7 @@ cc_library(
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels:kernel_util",
@@ -61,8 +62,8 @@ cc_library(
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
index c658e43092..7c5099235a 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h
@@ -257,6 +257,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -288,10 +298,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -326,6 +336,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+ // We may compare logits instead of log probabilities,
+ // since the difference is the same in both cases.
if (logit < label_selection_input_min) {
continue;
}
@@ -339,7 +351,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
index 834d1ebd66..8442c4d46c 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/experimental/kernels/ctc_beam_search.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
index 9d1e6a562f..aa42b495bd 100644
--- a/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
+++ b/tensorflow/contrib/lite/experimental/kernels/ctc_beam_search_decoder_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
@@ -117,7 +117,7 @@ TEST(CTCBeamSearchTest, SimpleTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(1, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.32134813})));
+ ElementsAreArray(ArrayFloatNear({-0.357094})));
}
TEST(CTCBeamSearchTest, MultiBatchTest) {
@@ -148,9 +148,8 @@ TEST(CTCBeamSearchTest, MultiBatchTest) {
EXPECT_THAT(decoded_outputs[1], ElementsAre(1, 0, 0, 0));
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 2));
// Check log probabilities output.
- EXPECT_THAT(
- m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0.46403232, 0.49500442, 0.40443572})));
+ EXPECT_THAT(m.GetLogProbabilitiesOutput(),
+ ElementsAreArray(ArrayFloatNear({-1.88343, -1.41188, -1.20958})));
}
TEST(CTCBeamSearchTest, MultiPathsTest) {
@@ -188,8 +187,8 @@ TEST(CTCBeamSearchTest, MultiPathsTest) {
EXPECT_THAT(decoded_outputs[5], ElementsAre(2, 2));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear(
- {0.91318405, 0.9060272, 1.0780245, 0.64358956})));
+ ElementsAreArray(
+ ArrayFloatNear({-2.65148, -2.65864, -2.17914, -2.61357})));
}
TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
@@ -223,7 +222,7 @@ TEST(CTCBeamSearchTest, NonEqualSequencesTest) {
EXPECT_THAT(decoded_outputs[2], ElementsAre(3, 1));
// Check log probabilities output.
EXPECT_THAT(m.GetLogProbabilitiesOutput(),
- ElementsAreArray(ArrayFloatNear({0., 1.0347567, 0.7833005})));
+ ElementsAreArray(ArrayFloatNear({-0.97322, -1.16334, -2.15553})));
}
} // namespace
diff --git a/tensorflow/contrib/lite/experimental/writer/BUILD b/tensorflow/contrib/lite/experimental/writer/BUILD
new file mode 100644
index 0000000000..82d39c00ab
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/BUILD
@@ -0,0 +1,66 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+cc_binary(
+ name = "option_writer_generator",
+ srcs = ["option_writer_generator.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ "@flatbuffers",
+ ],
+)
+
+cc_library(
+ name = "writer_lib",
+ srcs = [
+ "enum_mapping.h",
+ "writer_lib.cc",
+ ],
+ hdrs = [
+ "writer_lib.h",
+ ],
+ data = [
+ ":option_writer_gen",
+ ],
+ textual_hdrs = ["option_writer_generated.h"],
+ deps = [
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/schema:schema_fbs_with_reflection",
+ ],
+)
+
+cc_binary(
+ name = "writer",
+ srcs = ["writer.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
+
+cc_test(
+ name = "writer_lib_test",
+ size = "small",
+ srcs = ["writer_lib_test.cc"],
+ deps = [
+ ":writer_lib",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+genrule(
+ name = "option_writer_gen",
+ outs = ["option_writer_generated.h"],
+ cmd = "$(location :option_writer_generator) $(@)",
+ tools = [":option_writer_generator"],
+)
diff --git a/tensorflow/contrib/lite/experimental/writer/enum_mapping.h b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
new file mode 100644
index 0000000000..8bc464fd71
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/enum_mapping.h
@@ -0,0 +1,116 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+// TODO(aselle): Ideally extract this from the schema.
+
+namespace tflite {
+
+inline ActivationFunctionType TfLiteActivationToSchemaActivation(
+ TfLiteFusedActivation act) {
+ switch (act) {
+ case kTfLiteActNone:
+ return ActivationFunctionType_NONE;
+ case kTfLiteActRelu:
+ return ActivationFunctionType_RELU;
+ case kTfLiteActRelu1:
+ return ActivationFunctionType_RELU_N1_TO_1;
+ case kTfLiteActRelu6:
+ return ActivationFunctionType_RELU6;
+ case kTfLiteActTanh:
+ return ActivationFunctionType_TANH;
+ case kTfLiteActSignBit:
+ return ActivationFunctionType_SIGN_BIT;
+ case kTfLiteActSigmoid:
+ return ActivationFunctionType_NONE; // TODO(aselle): Add to schema
+ }
+ return ActivationFunctionType_NONE;
+}
+
+inline Padding TfLitePaddingToSchemaPadding(TfLitePadding padding) {
+ switch (padding) {
+ case kTfLitePaddingUnknown:
+ return Padding_SAME; // TODO(aselle): Consider an error.
+ case kTfLitePaddingSame:
+ return Padding_SAME;
+ case kTfLitePaddingValid:
+ return Padding_VALID;
+ }
+ return Padding_SAME; // TODO(aselle): Consider an error.
+}
+
+inline TensorType TfLiteTypeToSchemaType(TfLiteType type) {
+ switch (type) {
+ // case kTfLiteNoType: return TensorType_NONE;
+ case kTfLiteNoType:
+ return TensorType_FLOAT32; // TODO(aselle): Consider an error.
+ case kTfLiteFloat32:
+ return TensorType_FLOAT32;
+ case kTfLiteInt32:
+ return TensorType_INT32;
+ case kTfLiteUInt8:
+ return TensorType_UINT8;
+ case kTfLiteInt64:
+ return TensorType_INT64;
+ case kTfLiteString:
+ return TensorType_STRING;
+ case kTfLiteBool:
+ return TensorType_BOOL;
+ case kTfLiteInt16:
+ return TensorType_INT16;
+ case kTfLiteComplex64:
+ return TensorType_COMPLEX64;
+ }
+ // TODO(aselle): consider an error
+}
+
+inline FullyConnectedOptionsWeightsFormat
+FullyConnectedOptionsWeightsFormatToSchema(
+ TfLiteFullyConnectedWeightsFormat format) {
+ switch (format) {
+ case kTfLiteFullyConnectedWeightsFormatDefault:
+ return FullyConnectedOptionsWeightsFormat_DEFAULT;
+ case kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8:
+ return FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8;
+ }
+}
+
+inline LSTMKernelType LSTMKernelTypeToSchema(TfLiteLSTMKernelType type) {
+ switch (type) {
+ case kTfLiteLSTMFullKernel:
+ return LSTMKernelType_FULL;
+ case kTfLiteLSTMBasicKernel:
+ return LSTMKernelType_BASIC;
+ }
+}
+
+inline LSHProjectionType LSHProjectionTypeToSchema(
+ TfLiteLSHProjectionType type) {
+ switch (type) {
+ case kTfLiteLshProjectionUnknown:
+ return LSHProjectionType_UNKNOWN;
+ case kTfLiteLshProjectionSparse:
+ return LSHProjectionType_SPARSE;
+ case kTfLiteLshProjectionDense:
+ return LSHProjectionType_DENSE;
+ }
+}
+
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_ENUM_MAPPING_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
new file mode 100644
index 0000000000..e6d5a776b3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/option_writer_generator.cc
@@ -0,0 +1,370 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <ctype.h>
+#include <iostream>
+#include <unordered_map>
+#include <unordered_set>
+#include "flatbuffers/minireflect.h" // flatbuffers
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+
+namespace tflite {
+namespace {
+// This is generated by grepping
+// cat third_party/tensorflow/contrib/lite/builtin_op_data.h
+//| grep "^} TfLite" | sed 's/^} TfLite\(.*\)Params;/\1Params/g' | grep -v "^}"
+static const char* param_structs[] = {"TfLiteConvParams",
+ "TfLitePoolParams",
+ "TfLiteDepthwiseConvParams",
+ "TfLiteSVDFParams",
+ "TfLiteRNNParams",
+ "TfLiteSequenceRNNParams",
+ "TfLiteFullyConnectedParams",
+ "TfLiteLSHProjectionParams",
+ "TfLiteSoftmaxParams",
+ "TfLiteConcatenationParams",
+ "TfLiteAddParams",
+ "TfLiteSpaceToBatchNDParams",
+ "TfLiteBatchToSpaceNDParams",
+ "TfLiteMulParams",
+ "TfLiteSubParams",
+ "TfLiteDivParams",
+ "TfLiteL2NormParams",
+ "TfLiteLocalResponseNormParams",
+ "TfLiteLSTMParams",
+ "TfLiteResizeBilinearParams",
+ "TfLitePadParams",
+ "TfLitePadV2Params",
+ "TfLiteReshapeParams",
+ "TfLiteSkipGramParams",
+ "TfLiteSpaceToDepthParams",
+ "TfLiteCastParams",
+ "TfLiteEmbeddingLookupSparseParams",
+ "TfLiteGatherParams",
+ "TfLiteTransposeParams",
+ "TfLiteReducerParams",
+ "TfLiteSplitParams",
+ "TfLiteSqueezeParams",
+ "TfLiteStridedSliceParams",
+ "TfLiteArgMaxParams",
+ "TfLiteArgMinParams",
+ "TfLiteTransposeConvParams",
+ "TfLiteSparseToDenseParams",
+ "TfLiteShapeParams",
+ "TfLiteFakeQuantParams",
+ "TfLitePackParams",
+ "TfLiteOneHotParams",
+ nullptr};
+} // namespace
+
+// Get rid of all underscores and make everything lower case to make name
+// matching work for stuff like 3D vs 3d or RNN vs Rnn.
+std::string ToCollapsed(const std::string& in) {
+ const char* s = in.c_str();
+ bool first = true;
+ std::string out;
+ while (*s != '\0') {
+ if (*s == '_') {
+ first = true;
+ } else if (first) {
+ out.push_back(tolower(*s));
+ first = false;
+ } else {
+ out.push_back(tolower(*s));
+ }
+ s++;
+ }
+ return out;
+}
+
+// A collection of information about builtin ops.
+class OpOptionData {
+ public:
+ OpOptionData() {
+ BuildOpList();
+ BuildOptionToTypeFunctionMap();
+ BuildOpToOptionMap();
+ }
+
+ // A list of builtin operations
+ const std::vector<std::string>& ops() const { return ops_; }
+ // Maps from operation name to option name (i.e. 'ADD' to 'AddOptions')
+ const std::unordered_map<std::string, std::string>& op_to_option() {
+ return op_to_option_;
+ }
+ // Maps from option to to C struct i.e. 'AddOptions' -> 'TfLiteAddOptions'
+ const std::unordered_map<std::string, std::string>& option_to_struct() {
+ return option_to_struct_;
+ }
+ // Maps from option to a flatbuffer type function that describes that option.
+ const std::unordered_map<std::string, flatbuffers::TypeFunction>&
+ option_to_type_function() {
+ return option_to_type_function_;
+ }
+
+ private:
+ void BuildOpList() {
+ for (const char* const* curr = EnumNamesBuiltinOperator(); *curr != nullptr;
+ ++curr) {
+ if (strlen(*curr) != 0) ops_.push_back(*curr);
+ }
+ }
+
+ void BuildOptionToTypeFunctionMap() {
+ auto d = tflite::BuiltinOptionsTypeTable();
+ for (int i = 0; i < d->num_elems; i++) {
+ flatbuffers::TypeCode code = d->type_codes[i];
+ if (code.sequence_ref != -1) {
+ option_to_type_function_.insert(
+ std::make_pair(d->names[i], d->type_refs[code.sequence_ref]));
+ }
+ }
+ }
+
+ void BuildOpToOptionMap() {
+ // Manually specified mappings between ops and options
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_MIN"] = "ReducerOptions";
+ op_to_option_["REDUCE_ANY"] = "ReducerOptions";
+ op_to_option_["UNPACK"] = "";
+ op_to_option_["SUM"] = "ReducerOptions";
+ op_to_option_["REDUCE_MAX"] = "ReducerOptions";
+ op_to_option_["REDUCE_PROD"] = "ReducerOptions";
+ op_to_option_["MEAN"] = "ReducerOptions";
+ op_to_option_["L2_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["AVERAGE_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["MAX_POOL_2D"] = "Pool2DOptions";
+ op_to_option_["L2_NORMALIZATION"] = "L2NormOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_LSTM"] = "LSTMOptions";
+ op_to_option_["BIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ op_to_option_["UNIDIRECTIONAL_SEQUENCE_RNN"] = "SequenceRNNOptions";
+ // Manually specified mappings between ops and options (none)
+ op_to_option_["EMBEDDING_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["FLOOR"] = "";
+ op_to_option_["HASHTABLE_LOOKUP"] =
+ ""; // TODO(aselle): maybe something else.
+ op_to_option_["LOGISTIC"] = "";
+ op_to_option_["RELU"] = "";
+ op_to_option_["RELU_N1_TO_1"] = "";
+ op_to_option_["RELU6"] = "";
+ op_to_option_["TANH"] = "";
+ op_to_option_["CUSTOM"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["DELEGATE"] = ""; // TODO(aselle): maybe something else.
+ op_to_option_["PRELU"] = "";
+ op_to_option_["MAXIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["MINIMUM"] = ""; // TODO(aselle): MaximumMinimumOptions
+ op_to_option_["SIN"] = "";
+ op_to_option_["LOG"] = "";
+ op_to_option_["SQRT"] = "";
+ op_to_option_["RSQRT"] = "";
+
+ // TODO(aselle): These are undesirable hacks. Consider changing C structs
+ option_to_struct_["Pool2DOptions"] = "TfLitePoolParams";
+ option_to_struct_["Conv2DOptions"] = "TfLiteConvParams";
+ option_to_struct_["DepthwiseConv2DOptions"] = "TfLiteDepthwiseConvParams";
+ option_to_struct_["LocalResponseNormalizationOptions"] =
+ "TfLiteLocalResponseNormParams";
+ // Now for every op, try to find an option.
+ bool fatal = false;
+ for (auto op_name : ops_) {
+ bool found_option = false;
+ auto d = tflite::BuiltinOptionsTypeTable();
+ std::string collapsed_option_name_guess =
+ ToCollapsed(op_name) + "options";
+ // O(n^2) but not that big of n.
+ for (int i = 0; i < d->num_elems; i++) {
+ std::string option_name = d->names[i];
+ std::string collapsed_option_name = ToCollapsed(option_name);
+ if (collapsed_option_name_guess == collapsed_option_name) {
+ op_to_option_.insert(std::make_pair(op_name, option_name));
+ found_option = true;
+ break;
+ }
+ }
+ auto it = op_to_option_.find(op_name);
+ if (it == op_to_option_.end()) {
+ std::cerr << "Didn't find option for " << op_name << std::endl;
+ fatal = true;
+ } else if (!it->second.empty()) {
+ std::string option_name = it->second;
+
+ if (option_to_struct_.find(option_name) == option_to_struct_.end()) {
+ bool param_struct_found = false;
+ std::string params_guess = std::string("TfLite") + option_name;
+ size_t start = params_guess.find("Options");
+ size_t len = strlen("Options");
+ params_guess.replace(start, len, "Params");
+ for (auto* param = param_structs; *param != nullptr; param++) {
+ if (*param == params_guess) {
+ param_struct_found = true;
+ break;
+ }
+ }
+ if (!param_struct_found) {
+ std::cerr << "Failed to get param struct for option " << option_name
+ << std::endl;
+ fatal = true;
+ } else {
+ option_to_struct_.insert(std::make_pair(option_name, params_guess));
+ }
+ }
+ }
+ }
+ }
+
+ private:
+ std::vector<std::string> ops_;
+ std::unordered_map<std::string, std::string> op_to_option_;
+ std::unordered_map<std::string, std::string> option_to_struct_;
+ std::unordered_map<std::string, flatbuffers::TypeFunction>
+ option_to_type_function_;
+};
+
+void GenerateImportForOp(FILE* fp, const std::string& op_name,
+ const std::string& option_name,
+ const std::string& option_type,
+ const flatbuffers::TypeTable* options,
+ const std::string& struct_name) {
+ // Skip tricky ones for now
+ if (struct_name == "TfLiteResizeBilinearParams") return;
+ if (struct_name == "TfLiteSqueezeParams") return;
+ if (struct_name == "TfLiteEmbeddingLookupSparseParams") return;
+ if (struct_name == "TfLiteReshapeParams") return;
+
+ fprintf(fp, " case BuiltinOperator_%s: {\n", op_name.c_str());
+ fprintf(fp,
+ " const auto* params = reinterpret_cast<const "
+ "%s*>(builtin_op_data);\n",
+ struct_name.c_str());
+
+ for (size_t i = 0; i < options->num_elems; i++) {
+ std::string elem_name = options->names[i];
+ // TODO(aselle): Irregular naming in builtins
+ if (elem_name == "fused_activation_function")
+ elem_name = "activation";
+ else if (elem_name == "stride_w")
+ elem_name = "stride_width";
+ else if (elem_name == "stride_h")
+ elem_name = "stride_height";
+ else if (elem_name == "dilation_h_factor")
+ elem_name = "dilation_height_factor";
+ else if (elem_name == "dilation_w_factor")
+ elem_name = "dilation_width_factor";
+ else if (elem_name == "new_shape")
+ elem_name = "shape";
+
+ flatbuffers::TypeCode code = options->type_codes[i];
+ auto contained_type = code.sequence_ref != -1
+ ? options->type_refs[code.sequence_ref]
+ : nullptr;
+ std::string mapper = "";
+ if (contained_type == TensorTypeTypeTable) {
+ mapper = "TfLiteTypeToSchemaType";
+ } else if (contained_type == ActivationFunctionTypeTypeTable) {
+ mapper = "TfLiteActivationToSchemaActivation";
+ } else if (contained_type == PaddingTypeTable) {
+ mapper = "TfLitePaddingToSchemaPadding";
+ } else if (contained_type == FullyConnectedOptionsWeightsFormatTypeTable) {
+ mapper = "FullyConnectedOptionsWeightsFormatToSchema";
+ } else if (contained_type == LSTMKernelTypeTypeTable) {
+ mapper = "LSTMKernelTypeToSchema";
+ } else if (contained_type == LSHProjectionTypeTypeTable) {
+ mapper = "LSHProjectionTypeToSchema";
+ }
+
+ fprintf(fp,
+ " auto val%zu = "
+ "%s(params->%s);\n",
+ i, mapper.c_str(), elem_name.c_str());
+ }
+ fprintf(fp, " auto union_type = Create%s(*fbb", option_name.c_str());
+ for (size_t i = 0; i < options->num_elems; i++) {
+ fprintf(fp, ", val%zu", i);
+ }
+ fprintf(fp, ").Union();\n");
+ fprintf(fp, " return std::make_pair(%s, union_type);\n",
+ option_type.c_str());
+ fprintf(fp, " }\n break;\n");
+}
+
+void GenerateImport(OpOptionData* option, FILE* fp) {
+ std::unordered_set<std::string> ignores;
+ ignores.insert("CONCAT_EMBEDDINGS");
+ ignores.insert("CALL");
+
+ // Allow any op that doesn't have an options struct to be blocked
+ // together
+ for (const auto& op_name : option->ops()) {
+ auto option_it = option->op_to_option().find(op_name);
+ if (!option_it->second.empty() && ignores.find(op_name) == ignores.end())
+ continue;
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ }
+ fprintf(fp,
+ " return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+
+ // Iterate over each ops
+ for (const auto& op_name : option->ops()) {
+ if (ignores.find(op_name) != ignores.end()) continue;
+ // Get to the option and struct names, continuing if not found.
+ auto option_it = option->op_to_option().find(op_name);
+ if (option_it->second.empty()) continue;
+ std::string option_name = option_it->second;
+ std::string option_type = "BuiltinOptions_" + option_name;
+ auto option_func_it = option->option_to_type_function().find(option_name);
+ if (option_func_it == option->option_to_type_function().end()) continue;
+ auto struct_name_it = option->option_to_struct().find(option_name);
+ if (struct_name_it == option->option_to_struct().end()) {
+ // If no C struct, then it better have no arguments.
+ auto type_info = option_func_it->second();
+ if (type_info->num_elems != 0) {
+ // We have non-zero arguments in the schema, this means there
+ // should be a struct.
+ fprintf(stderr,
+ "Op %s uses option struct %s which has no builtin struct\n",
+ op_name.c_str(), option_name.c_str());
+ exit(1);
+ }
+ fprintf(fp, " case BuiltinOperator_%s:\n", op_name.c_str());
+ fprintf(fp, " return std::make_pair(%s, Create%s(*fbb).Union());",
+ option_type.c_str(), option_name.c_str());
+ } else {
+ // If C struct, then we need to assign all properties
+ auto struct_name = struct_name_it->second;
+ GenerateImportForOp(fp, op_name, option_name, option_type,
+ option_func_it->second(), struct_name);
+ }
+ }
+ // TODO(aselle): Handle unhandled cases more gracefully.
+ fprintf(fp,
+ "default: return std::make_pair(BuiltinOptions_NONE, "
+ "flatbuffers::Offset<void>());\n break;\n");
+}
+
+} // namespace tflite
+
+int main(int argc, char* argv[]) {
+ tflite::OpOptionData option;
+ if (argc != 2) {
+ fprintf(stderr, "Usage: %s <fname out>\n", argv[0]);
+ return 1;
+ }
+ FILE* fp = fopen(argv[1], "w");
+ tflite::GenerateImport(&option, fp);
+ fclose(fp);
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer.cc b/tensorflow/contrib/lite/experimental/writer/writer.cc
new file mode 100644
index 0000000000..20ede214fb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Just does a read/write loop of tflite file format using the interpreter as
+// an intermediate.
+//
+// Usage:
+// writer <input tflite> <output tflite>
+
+#include <iostream>
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+
+int main(int argc, char* argv[]) {
+ if (argc != 3) {
+ fprintf(stderr, "Usage: %s input_file output_file\n", argv[0]);
+ return 1;
+ }
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromFile(argv[1]);
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ tflite::ops::builtin::BuiltinOpResolver builtin_op_resolver;
+ tflite::InterpreterBuilder(*model, builtin_op_resolver)(&interpreter);
+ tflite::InterpreterWriter writer(interpreter.get());
+ writer.Write(argv[2]);
+
+ return 0;
+}
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
new file mode 100644
index 0000000000..52b17faf82
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <cstdlib>
+#include <cstring>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+template <class T>
+using Offset = flatbuffers::Offset<T>;
+template <class T>
+using Vector = flatbuffers::Vector<T>;
+using FlatBufferBuilder = flatbuffers::FlatBufferBuilder;
+
+std::pair<BuiltinOptions, Offset<void>> CreateBuiltinUnion(
+ FlatBufferBuilder* fbb, enum BuiltinOperator op, void* builtin_op_data) {
+ switch (op) {
+#include "tensorflow/contrib/lite/experimental/writer/option_writer_generated.h"
+ }
+ return std::make_pair(BuiltinOptions_NONE, Offset<void>());
+}
+
+template <class T_OUTPUT, class T_INPUT>
+Offset<Vector<T_OUTPUT>> InterpreterWriter::ExportVector(FlatBufferBuilder* fbb,
+ const T_INPUT& v) {
+ std::vector<T_OUTPUT> inputs(v.begin(), v.end());
+ return fbb->template CreateVector<T_OUTPUT>(inputs);
+}
+
+Offset<Vector<Offset<Operator>>> InterpreterWriter::ExportOperators(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Operator>> operators;
+
+ std::vector<int> operator_to_opcode;
+ // TODO(aselle): Augment this once we put execution plan in schema.
+ operator_to_opcode.resize(interpreter_->nodes_size(), -1);
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteRegistration* registration = &node_and_registration->second;
+ if (!registration->custom_name) {
+ operator_to_opcode[op_index] =
+ GetOpCodeForBuiltin(registration->builtin_code);
+ } else {
+ operator_to_opcode[op_index] =
+ GetOpCodeForCustom(registration->custom_name);
+ }
+ }
+ // second pass serialize operators
+ for (int op_index : interpreter_->execution_plan()) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ const TfLiteNode& node = node_and_registration->first;
+ const TfLiteRegistration& registration = node_and_registration->second;
+ Offset<void> builtin_options;
+ BuiltinOptions builtin_options_type = BuiltinOptions_NONE;
+ // Custom data
+ // TODO(aselle): Custom options format is not known by default. Just assume
+ // for now.
+ auto custom_options_format = CustomOptionsFormat_FLEXBUFFERS;
+ Offset<Vector<uint8_t>> custom_options = 0;
+
+ if (!registration.custom_name) {
+ // builtin
+ auto builtin_options_and_type = CreateBuiltinUnion(
+ fbb, static_cast<enum BuiltinOperator>(registration.builtin_code),
+ node.builtin_data);
+ builtin_options = builtin_options_and_type.second;
+ builtin_options_type = builtin_options_and_type.first;
+ } else {
+ auto custom_writer = custom_op_to_writer_.find(registration.custom_name);
+ if (custom_writer != custom_op_to_writer_.end() &&
+ custom_writer->second) {
+ // delegate to custom writer if it exists
+ custom_writer->second(fbb, interpreter_, op_index, &custom_options,
+ &custom_options_format);
+ } else {
+ // use the custom data as fact
+ custom_options = fbb->CreateVector(
+ reinterpret_cast<const uint8_t*>(node.custom_initial_data),
+ node.custom_initial_data_size);
+ }
+ }
+
+ int opcode_index = operator_to_opcode[op_index];
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.inputs));
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(TfLiteIntArrayView(node.outputs));
+ auto inputs = ExportVector<int32_t>(fbb, written_inputs);
+ auto outputs = ExportVector<int32_t>(fbb, written_outputs);
+ operators.push_back(CreateOperator(*fbb, opcode_index, inputs, outputs,
+ builtin_options_type, builtin_options,
+ custom_options, custom_options_format));
+ }
+
+ return fbb->template CreateVector<Offset<Operator>>(operators);
+}
+
+Offset<Vector<Offset<Tensor>>> InterpreterWriter::ExportTensors(
+ FlatBufferBuilder* fbb) {
+ tensor_to_written_tensor_.resize(interpreter_->tensors_size(), -1);
+
+ std::vector<Offset<Tensor>> tensors;
+
+ // Make a map from tensor index to whether the tensor is a temporary.
+ std::vector<bool> tensor_is_temporary(interpreter_->tensors_size(), false);
+ for (int op_index = 0; op_index < interpreter_->nodes_size(); ++op_index) {
+ const auto* node_and_registration =
+ interpreter_->node_and_registration(op_index);
+ for (auto tensor_index :
+ TfLiteIntArrayView(node_and_registration->first.temporaries))
+ tensor_is_temporary[tensor_index] = true;
+ }
+
+ // Now we need to remap all used tensor indices
+ int curr_output_index = 0;
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ tensor_index++) {
+ if (!tensor_is_temporary[tensor_index]) {
+ tensor_to_written_tensor_[tensor_index] = curr_output_index++;
+ }
+ }
+
+ for (int tensor_index = 0; tensor_index < interpreter_->tensors_size();
+ ++tensor_index) {
+ // Skip temporaries.
+ if (tensor_is_temporary[tensor_index]) continue;
+
+ if (TfLiteTensor* tensor = interpreter_->tensor(tensor_index)) {
+ // We only need to convert non temporaries
+ if (tensor->allocation_type != kTfLiteArenaRw &&
+ tensor->allocation_type != kTfLiteMmapRo &&
+ tensor->allocation_type != kTfLiteArenaRwPersistent)
+ continue;
+ // Allocate a buffer index
+ int buffer_index = 0; // This is null
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ buffer_index = buffers_.size();
+ buffers_.push_back(std::make_pair(
+ reinterpret_cast<const uint8_t*>(tensor->data.raw), tensor->bytes));
+ }
+ // Primitive type.
+ TensorType type = TfLiteTypeToSchemaType(tensor->type);
+ // Handle quantization
+ const Offset<Vector<float>> null_array;
+ Offset<Vector<float>> scale_array;
+ Offset<Vector<int64_t>> zero_point_array;
+ if (tensor->params.scale != 0.f) {
+ // We have quantization, make a single arugment array (multi channel
+ // quant needs updating here).
+ scale_array = fbb->CreateVector<float>({tensor->params.scale});
+ zero_point_array =
+ fbb->CreateVector<int64_t>({tensor->params.zero_point});
+ }
+ Offset<QuantizationParameters> quantization_params =
+ CreateQuantizationParameters(*fbb, null_array, null_array,
+ scale_array, zero_point_array);
+ // Shape
+ TfLiteIntArrayView shape_view(tensor->dims);
+ std::vector<int> shape =
+ std::vector<int>(shape_view.begin(), shape_view.end());
+
+ tensors.push_back(CreateTensor(*fbb, ExportVector<int32_t>(fbb, shape),
+ type, buffer_index,
+ fbb->CreateString(tensor->name),
+ quantization_params, tensor->is_variable));
+ }
+ }
+ return fbb->template CreateVector<Offset<Tensor>>(tensors);
+}
+
+Offset<Vector<Offset<Buffer>>> InterpreterWriter::ExportBuffers(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<Buffer>> buffer_vector;
+ for (auto buffer : buffers_) {
+ auto data_offset = fbb->CreateVector(buffer.first, buffer.second);
+ buffer_vector.push_back(CreateBuffer(*fbb, data_offset));
+ }
+ return fbb->template CreateVector<Offset<Buffer>>(buffer_vector);
+}
+
+Offset<Vector<Offset<OperatorCode>>> InterpreterWriter::CreateOpCodeTable(
+ FlatBufferBuilder* fbb) {
+ std::vector<Offset<OperatorCode>> codes;
+ for (auto it : opcodes_) {
+ const char* custom_name = it.custom.empty() ? nullptr : it.custom.c_str();
+ codes.push_back(CreateOperatorCodeDirect(
+ *fbb, static_cast<BuiltinOperator>(it.builtin), custom_name));
+ }
+ return fbb->template CreateVector<Offset<OperatorCode>>(codes);
+}
+
+template <class T>
+std::vector<int> InterpreterWriter::RemapTensorIndicesToWritten(
+ const T& input) {
+ std::vector<int> output;
+ output.reserve(input.size());
+ for (int x : input) {
+ output.push_back(tensor_to_written_tensor_[x]);
+ }
+ return output;
+}
+
+TfLiteStatus InterpreterWriter::GetBuffer(std::unique_ptr<uint8_t[]>* out,
+ size_t* size) {
+ if (!out || !size) return kTfLiteError;
+ FlatBufferBuilder builder(/*initial_size=*/10240);
+
+ std::vector<Offset<SubGraph>> subgraphs_as_vector;
+ { // subgraph specific stuff
+ auto tensors = ExportTensors(&builder);
+ std::vector<int> written_inputs =
+ RemapTensorIndicesToWritten(interpreter_->inputs());
+ std::vector<int> written_outputs =
+ RemapTensorIndicesToWritten(interpreter_->outputs());
+ auto inputs = ExportVector<int32_t>(&builder, written_inputs);
+ auto outputs = ExportVector<int32_t>(&builder, written_outputs);
+
+ auto ops = ExportOperators(&builder);
+ subgraphs_as_vector.push_back(
+ CreateSubGraph(builder, tensors, inputs, outputs, ops, /* name */ 0));
+ }
+ Offset<Vector<Offset<Buffer>>> buffers = ExportBuffers(&builder);
+
+ auto description = builder.CreateString("Exported from Interpreter.");
+
+ auto op_codes = CreateOpCodeTable(&builder);
+ auto model = CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
+ builder.CreateVector(subgraphs_as_vector),
+ description, buffers);
+ ::tflite::FinishModelBuffer(builder, model);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ *size = builder.GetSize();
+ (*out).reset(new uint8_t[*size]);
+ memcpy(out->get(), buffer, *size);
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::Write(const std::string& filename) {
+ std::unique_ptr<uint8_t[]> buffer;
+ size_t size;
+ TF_LITE_ENSURE_STATUS(GetBuffer(&buffer, &size));
+
+ FILE* fp = fopen(filename.c_str(), "wb");
+ if (!fp) return kTfLiteError;
+
+ if (fwrite(buffer.get(), 1, size, fp) != size) return kTfLiteError;
+ if (fclose(fp)) return kTfLiteError;
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus InterpreterWriter::RegisterCustomWriter(
+ const std::string& custom_name, CustomWriter custom_writer) {
+ if (custom_op_to_writer_.find(custom_name) != custom_op_to_writer_.end()) {
+ return kTfLiteError;
+ }
+ custom_op_to_writer_.insert(std::make_pair(custom_name, custom_writer));
+ return kTfLiteOk;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib.h b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
new file mode 100644
index 0000000000..a98108b496
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib.h
@@ -0,0 +1,126 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Writes a flatbuffer of a currently loaded TensorFlow Lite interpreter.
+//
+// Usage:
+// From command line:
+// bazel run third_party/tensorflow/contrib/lite/experimental/writer:writer
+// -- foo.tflite foo.out.tflite
+//
+// From C++
+// std::unique_ptr<Interpreter> interpreter;
+// // Build Interpreter however
+// // ... <omitted>
+// InterpreterWriter(interpreter.get()).Write("output.tflite");
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
+#include <iostream>
+#include <unordered_map>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/experimental/writer/enum_mapping.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/schema/reflection/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+namespace tflite {
+
+// Handles writing TensorFlow Lite running interpreter to a serialized TF lite
+// file format.
+class InterpreterWriter {
+ public:
+ typedef flatbuffers::Offset<Operator> (*CustomWriter)(
+ flatbuffers::FlatBufferBuilder* fbb, Interpreter* interpreter,
+ int node_index,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>>* output_options,
+ CustomOptionsFormat* custom_options_format);
+
+ // Construct an interpreter writer for the specified `interpreter`. Then,
+ // a uses .Write() or .GetBuffer(...) to extract the data.
+ explicit InterpreterWriter(Interpreter* interpreter)
+ : interpreter_(interpreter) {
+ buffers_.push_back(std::make_pair(nullptr, 0));
+ }
+
+ // Get a buffer and size of a serialized flatbuffer.
+ TfLiteStatus GetBuffer(std::unique_ptr<uint8_t[]>* out, size_t* size);
+ // Write the serialized flatbuffer to the prescribed `filename`.
+ TfLiteStatus Write(const std::string& filename);
+ // Registers a custom writer for a custom op. The customization allows the
+ // caller to change the custom data.
+ TfLiteStatus RegisterCustomWriter(const std::string& custom_name,
+ CustomWriter custom_writer);
+
+ private:
+ template <class T>
+ using Offset = flatbuffers::Offset<T>;
+ template <class T_OUTPUT, class T_INPUT>
+ Offset<flatbuffers::Vector<T_OUTPUT>> ExportVector(
+ flatbuffers::FlatBufferBuilder* fbb, const T_INPUT& v);
+ Offset<flatbuffers::Vector<Offset<Tensor>>> ExportTensors(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Operator>>> ExportOperators(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<OperatorCode>>> CreateOpCodeTable(
+ flatbuffers::FlatBufferBuilder* fbb);
+ Offset<flatbuffers::Vector<Offset<Buffer>>> ExportBuffers(
+ flatbuffers::FlatBufferBuilder* fbb);
+
+ template <class T>
+ std::vector<int> RemapTensorIndicesToWritten(const T& input);
+
+ int GetOpCodeForBuiltin(int builtin_op_index) {
+ // auto it = builtin_op_to_opcode_.find(builtin_op_index);
+ std::pair<decltype(builtin_op_to_opcode_)::iterator, bool> result =
+ builtin_op_to_opcode_.insert(
+ std::make_pair(builtin_op_index, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({builtin_op_index, ""});
+ }
+ return result.first->second;
+ }
+
+ int GetOpCodeForCustom(const std::string& custom_name) {
+ std::pair<decltype(custom_op_to_opcode_)::iterator, bool> result =
+ custom_op_to_opcode_.insert(
+ std::make_pair(custom_name, opcodes_.size()));
+ if (result.second) {
+ opcodes_.push_back({BuiltinOperator_CUSTOM, custom_name});
+ }
+ return result.first->second;
+ }
+
+ // The interpreter we are writing
+ Interpreter* interpreter_;
+ // Keep track of byte buffers
+ std::vector<std::pair<const uint8_t*, size_t>> buffers_;
+ // List of op codes and mappings from builtin or custom op to opcode
+ struct OpCode {
+ int builtin;
+ std::string custom;
+ };
+ // For every tensor index in the interpreter, the index in the written.
+ // This is different due to temporary tensors not being written.
+ std::vector<int> tensor_to_written_tensor_;
+ // List of used opcodes
+ std::vector<OpCode> opcodes_;
+ std::unordered_map<int, int> builtin_op_to_opcode_;
+ std::unordered_map<std::string, int> custom_op_to_opcode_;
+ std::unordered_map<std::string, CustomWriter> custom_op_to_writer_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_WRITER_WRITER_LIB_H_
diff --git a/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
new file mode 100644
index 0000000000..49194a76c8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/writer/writer_lib_test.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/writer/writer_lib.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+// Make an interpreter that has no tensors and no nodes
+// TODO(b/113731921): add more tests.
+TEST(Writer, BasicTest) {
+ Interpreter interpreter;
+ interpreter.AddTensors(3);
+ float foo[] = {1, 2, 3};
+ interpreter.SetTensorParametersReadWrite(0, kTfLiteFloat32, "a", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetTensorParametersReadOnly(
+ 1, kTfLiteFloat32, "b", {3}, TfLiteQuantizationParams(),
+ reinterpret_cast<char*>(foo), sizeof(foo));
+ interpreter.SetTensorParametersReadWrite(2, kTfLiteFloat32, "c", {3},
+ TfLiteQuantizationParams());
+ interpreter.SetInputs({0, 1});
+ interpreter.SetOutputs({2});
+ const char* initial_data = "";
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+ TfLiteAddParams* builtin_data =
+ reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+ builtin_data->activation = kTfLiteActNone;
+ const TfLiteRegistration* reg = resolver.FindOp(BuiltinOperator_ADD, 1);
+ interpreter.AddNodeWithParameters({0, 1}, {2}, initial_data, 0,
+ reinterpret_cast<void*>(builtin_data), reg);
+
+ InterpreterWriter writer(&interpreter);
+ writer.Write("/tmp/test.tflite");
+ std::unique_ptr<FlatBufferModel> model =
+ FlatBufferModel::BuildFromFile("/tmp/test.tflite");
+ InterpreterBuilder builder(*model, resolver);
+ std::unique_ptr<Interpreter> new_interpreter;
+ builder(&new_interpreter);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/g3doc/README.md b/tensorflow/contrib/lite/g3doc/README.md
deleted file mode 100644
index e3db478481..0000000000
--- a/tensorflow/contrib/lite/g3doc/README.md
+++ /dev/null
@@ -1,4 +0,0 @@
-This is a *work-in-progress* TF Lite subsite for:
-https://www.tensorflow.org/mobile
-
-DO NOT PUBLISH
diff --git a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md b/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
deleted file mode 100644
index 70031a3c3d..0000000000
--- a/tensorflow/contrib/lite/g3doc/api_docs/python/index.md
+++ /dev/null
@@ -1,10 +0,0 @@
-Project: /mobile/_project.yaml
-Book: /mobile/_book.yaml
-page_type: reference
-<style> table img { max-width: 100%; } </style>
-<script src="/_static/js/managed/mathjax/MathJax.js?config=TeX-AMS-MML_SVG"></script>
-
-<!-- DO NOT EDIT! Automatically generated file. -->
-# All symbols in TensorFlow Lite
-
-TEMP PAGE
diff --git a/tensorflow/contrib/lite/g3doc/apis.md b/tensorflow/contrib/lite/g3doc/apis.md
index 776803da8c..69616c7b8a 100644
--- a/tensorflow/contrib/lite/g3doc/apis.md
+++ b/tensorflow/contrib/lite/g3doc/apis.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite APIs
@@ -39,7 +37,7 @@ float* output = interpreter->typed_output_tensor<float>(0);
```
### Data Alignment
-TensorFlow Lite data is usually aligned to 32-bit boundaries. It is recommended
+TensorFlow Lite data is usually aligned to 16-byte boundaries. It is recommended
that all data provided to TensorFlow Lite be aligned that way.
### Error Reporting
@@ -114,7 +112,7 @@ below. It should be noted that:
* Tensors are represented by integers, in order to avoid string comparisons
(and any fixed dependency on string libraries).
- * An interpreter must not be accessed from concurrent threads
+ * An interpreter must not be accessed from concurrent threads.
* Memory allocation for input and output tensors must be triggered
by calling AllocateTensors() right after resizing tensors.
@@ -171,7 +169,7 @@ former provides error reporting facilities and access to global objects,
including all the tensors. The latter allows implementations to access their
inputs and outputs.
-When the interpreter loads a model, it calls init() once for each node in the
+When the interpreter loads a model, it calls `init()` once for each node in the
graph. A given `init()` will be called more than once if the op is used
multiple times in the graph. For custom ops a configuration buffer will be
provided, containing a flexbuffer that maps parameter names to their values.
@@ -212,8 +210,9 @@ namespace custom {
Note that registration is not automatic and an explicit call to
`Register_MY_CUSTOM_OP` should be made somewhere. While the standard
-`:builtin_ops` takes care of the registration of builtins, custom ops will have
-to be collected in separated custom libraries.
+`BuiltinOpResolver` (available from the `:builtin_ops` target) takes care of the
+registration of builtins, custom ops will have to be collected in separate
+custom libraries.
### Customizing the kernel library
@@ -234,7 +233,7 @@ class OpResolver {
};
```
-The regular usage will require the developer to use the `BuiltinOpResolver` and
+Regular usage will require the developer to use the `BuiltinOpResolver` and
write:
```c++
@@ -310,18 +309,25 @@ an `IllegalArgumentException` will be thrown.
#### Inputs
-Each input should be an array, a multi-dimensional array, or a `ByteBuffer` of
-the supported primitive types.
+Each input should be an array or multi-dimensional array of the supported
+primitive types, or a raw `ByteBuffer` of the appropriate size. If the input is
+an array or multi-dimensional array, the associated input tensor will be
+implicitly resized to the array's dimensions at inference time. If the input is
+a ByteBuffer, the caller should first manually resize the associated input
+tensor (via `Interpreter.resizeInput()`) before running inference.
-The use of `ByteBuffer` is preferred since it allows the `Interpreter` to avoid
-unnecessary copies. Each `ByteBuffer` needs to be a direct byte buffer, and its
-order must be `ByteOrder.nativeOrder()`. After it is used for a model inference,
-it must remain unchanged until the model inference is finished.
+When using 'ByteBuffer', prefer using direct byte buffers, as this allows the
+`Interpreter` to avoid unnecessary copies. If the `ByteBuffer` is a direct byte
+buffer, its order must be `ByteOrder.nativeOrder()`. After it is used for a
+model inference, it must remain unchanged until the model inference is finished.
#### Outputs
-Each output should be an array, or a multi-dimensional array of the supported
-primitive types.
+Each output should be an array or multi-dimensional array of the supported
+primitive types, or a ByteBuffer of the appropriate size. Note that some models
+have dynamic outputs, where the shape of output tensors can vary depending on
+the input. There's no straightforward way of handling this with the existing
+Java inference API, but planned extensions will make this possible.
#### Running Model Inference
@@ -341,9 +347,10 @@ interpreter.runForMultipleInputsOutputs(inputs, map_of_indices_to_outputs);
where each entry in `inputs` corresponds to an input tensor and
`map_of_indices_to_outputs` maps indices of output tensors to the
corresponding output data. In both cases the tensor indices should correspond to
-the values given to the `TensorFlow Lite Optimized Converter` when the model was
-created. Be aware that the order of tensors in `input` must match the order
-given to the `TensorFlow Lite Optimized Converter`.
+the values given to the [TensorFlow Lite Optimized Converter](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md)
+when the model was created. Be aware that the order of tensors in `input` must
+match the order given to the `TensorFlow Lite Optimized Converter`.
+
The Java API also provides convenient functions for app developers to get the
index of any model input or output using a tensor name:
diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md
index d979353bb3..ee6150b60e 100644
--- a/tensorflow/contrib/lite/g3doc/custom_operators.md
+++ b/tensorflow/contrib/lite/g3doc/custom_operators.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# How to use custom operators
diff --git a/tensorflow/contrib/lite/g3doc/demo_android.md b/tensorflow/contrib/lite/g3doc/demo_android.md
index d79a2696b4..c38b928684 100644
--- a/tensorflow/contrib/lite/g3doc/demo_android.md
+++ b/tensorflow/contrib/lite/g3doc/demo_android.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Android Demo App
diff --git a/tensorflow/contrib/lite/g3doc/demo_ios.md b/tensorflow/contrib/lite/g3doc/demo_ios.md
index a554898899..7579ad84a0 100644
--- a/tensorflow/contrib/lite/g3doc/demo_ios.md
+++ b/tensorflow/contrib/lite/g3doc/demo_ios.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# iOS Demo App
diff --git a/tensorflow/contrib/lite/g3doc/devguide.md b/tensorflow/contrib/lite/g3doc/devguide.md
index dc9cc98c08..90e7915c52 100644
--- a/tensorflow/contrib/lite/g3doc/devguide.md
+++ b/tensorflow/contrib/lite/g3doc/devguide.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Developer Guide
diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md
index d78d373ccf..a83d2c8fec 100644
--- a/tensorflow/contrib/lite/g3doc/ios.md
+++ b/tensorflow/contrib/lite/g3doc/ios.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite for iOS
@@ -38,7 +36,7 @@ brew link libtool
Then you need to run a shell script to download the dependencies you need:
```bash
-tensorflow/contrib/lite/download_dependencies.sh
+tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
This will fetch copies of libraries and data from the web and install them in
@@ -48,14 +46,14 @@ With all of the dependencies set up, you can now build the library for all five
supported architectures on iOS:
```bash
-tensorflow/contrib/lite/build_ios_universal_lib.sh
+tensorflow/contrib/lite/tools/make/build_ios_universal_lib.sh
```
Under the hood this uses a makefile in `tensorflow/contrib/lite` to build the
different versions of the library, followed by a call to `lipo` to bundle them
into a universal file containing armv7, armv7s, arm64, i386, and x86_64
architectures. The resulting library is in
-`tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`.
+`tensorflow/contrib/lite/tools/make/gen/lib/libtensorflow-lite.a`.
If you get an error such as `no such file or directory: 'x86_64'` when running
`build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure
diff --git a/tensorflow/contrib/lite/g3doc/models.md b/tensorflow/contrib/lite/g3doc/models.md
index 4ceb9a53dc..a4267eee4c 100644
--- a/tensorflow/contrib/lite/g3doc/models.md
+++ b/tensorflow/contrib/lite/g3doc/models.md
@@ -1,66 +1,70 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# List of Hosted Models
## Image classification (Float Models)
-Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
-------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
-DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
-SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
-NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 72.2% | 90.6% | 261 ms | 389 ms
-NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.1% | 95.8% | 6697 ms | 7940 ms
-ResNet_V2_50 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_50_2018_04_27.tgz) | 102.3 Mb | 68.1% | 88.4% | 942 ms | 1008 ms
-ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/resnet_v2_101_2018_04_27.tgz) | 178.3 Mb | 70.4% | 89.6% | 1880 ms | 1970 ms
-Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 76.9% | 93.5% | 1433 ms | 1522 ms
-Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 79.6% | 94.6% | 2986 ms | 3139 ms
-Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 76.8% | 93.5% | 2731 ms | 2926 ms
-Mobilenet_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.5% | 66.3% | 6.2 ms | 13.0 ms
-Mobilenet_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.5% | 70.3% | 8.6 ms | 19.5 ms
-Mobilenet_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.7% | 72.3% | 12.1 ms | 27.8 ms
-Mobilenet_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.8% | 74.2% | 16.2 ms | 37.3 ms
-Mobilenet_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.3% | 79.4% | 18.1 ms | 29.9 ms
-Mobilenet_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.1% | 81.9% | 26.8 ms | 45.9 ms
-Mobilenet_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.6% | 35.6 ms | 65.3 ms
-Mobilenet_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.3% | 84.9% | 47.6 ms | 164.2 ms
-Mobilenet_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.1% | 83.9% | 34.6 ms | 48.7 ms
-Mobilenet_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.3% | 86.0% | 51.3 ms | 75.2 ms
-Mobilenet_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.2% | 87.3% | 71.7 ms | 107.0 ms
-Mobilenet_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.4% | 88.2% | 95.7 ms | 143.4 ms
-Mobilenet_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.8% | 57.4 ms | 76.8 ms
-Mobilenet_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
-Mobilenet_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 70.0% | 89.2% | 118.6 ms | 167.3 ms
-Mobilenet_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 70.9% | 89.9% | 160.1 ms | 224.3 ms
+Model Name | Paper_Model_Files^ | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance^^ | Tensorflow Performance
+--------------------- | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | --------------------: | ---------------------:
+DenseNet | [paper](https://arxiv.org/abs/1608.06993), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/densenet_2018_04_27.tgz) | 43.6 Mb | 64.2% | 85.6% | 894 ms | 1262 ms
+SqueezeNet | [paper](https://arxiv.org/abs/1602.07360), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/squeezenet_2018_04_27.tgz) | 5.0 Mb | 49.0% | 72.9% | 224 ms | 255 ms
+NASNet mobile | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_mobile_2018_04_27.tgz) | 21.4 Mb | 73.9% | 91.5% | 261 ms | 389 ms
+NASNet large | [paper](https://arxiv.org/abs/1707.07012), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/nasnet_large_2018_04_27.tgz) | 355.3 Mb | 82.6% | 96.1% | 6697 ms | 7940 ms
+ResNet_V2_101 | [paper](https://arxiv.org/abs/1603.05027), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/resnet_v2_101.tgz) | 178.3 Mb | 76.8% | 93.6% | 1880 ms | 1970 ms
+Inception_V3 | [paper](http://arxiv.org/abs/1512.00567), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v3_2018_04_27.tgz) | 95.3 Mb | 77.9% | 93.8% | 1433 ms | 1522 ms
+Inception_V4 | [paper](http://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_v4_2018_04_27.tgz) | 170.7 Mb | 80.1% | 95.1% | 2986 ms | 3139 ms
+Inception_ResNet_V2 | [paper](https://arxiv.org/abs/1602.07261), [tflite&pb](https://storage.googleapis.com/download.tensorflow.org/models/tflite/model_zoo/upload_20180427/inception_resnet_v2_2018_04_27.tgz) | 121.0 Mb | 77.5% | 94.0% | 2731 ms | 2926 ms
+Mobilenet_V1_0.25_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_128.tgz) | 1.9 Mb | 41.4% | 66.2% | 6.2 ms | 13.0 ms
+Mobilenet_V1_0.25_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_160.tgz) | 1.9 Mb | 45.4% | 70.2% | 8.6 ms | 19.5 ms
+Mobilenet_V1_0.25_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_192.tgz) | 1.9 Mb | 47.1% | 72.0% | 12.1 ms | 27.8 ms
+Mobilenet_V1_0.25_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.25_224.tgz) | 1.9 Mb | 49.7% | 74.1% | 16.2 ms | 37.3 ms
+Mobilenet_V1_0.50_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_128.tgz) | 5.3 Mb | 56.2% | 79.3% | 18.1 ms | 29.9 ms
+Mobilenet_V1_0.50_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_160.tgz) | 5.3 Mb | 59.0% | 81.8% | 26.8 ms | 45.9 ms
+Mobilenet_V1_0.50_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_192.tgz) | 5.3 Mb | 61.7% | 83.5% | 35.6 ms | 65.3 ms
+Mobilenet_V1_0.50_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.5_224.tgz) | 5.3 Mb | 63.2% | 84.9% | 47.6 ms | 164.2 ms
+Mobilenet_V1_0.75_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_128.tgz) | 10.3 Mb | 62.0% | 83.8% | 34.6 ms | 48.7 ms
+Mobilenet_V1_0.75_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_160.tgz) | 10.3 Mb | 65.2% | 85.9% | 51.3 ms | 75.2 ms
+Mobilenet_V1_0.75_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_192.tgz) | 10.3 Mb | 67.1% | 87.2% | 71.7 ms | 107.0 ms
+Mobilenet_V1_0.75_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_0.75_224.tgz) | 10.3 Mb | 68.3% | 88.1% | 95.7 ms | 143.4 ms
+Mobilenet_V1_1.0_128 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_128.tgz) | 16.9 Mb | 65.2% | 85.7% | 57.4 ms | 76.8 ms
+Mobilenet_V1_1.0_160 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_160.tgz) | 16.9 Mb | 68.0% | 87.7% | 86.0 ms | 117.7 ms
+Mobilenet_V1_1.0_192 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_192.tgz) | 16.9 Mb | 69.9% | 89.1% | 118.6 ms | 167.3 ms
+Mobilenet_V1_1.0_224 | [paper](https://arxiv.org/pdf/1704.04861.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/mobilenet_v1_1.0_224.tgz) | 16.9 Mb | 71.0% | 89.9% | 160.1 ms | 224.3 ms
+Mobilenet_V2_1.0_224 | [paper](https://arxiv.org/pdf/1801.04381.pdf), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224.tgz) | 14.0 Mb | 71.8% | 90.6% | 117 ms |
^ The model files include both TF Lite FlatBuffer and Tensorflow frozen Graph.
^^ The performance numbers are generated in the benchmark on Pixel-2 using
single thread large core.
+^^ Accuracy numbers were computed using the
+[TFLite accuracy tool](../tools/accuracy/ilsvrc) .
+
## Image classification (Quantized Models)
-Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
-Mobilenet_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
-Mobilenet_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 43.4% | 68.5% | 5.5 ms
-Mobilenet_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 46.0% | 71.2% | 7.9 ms
-Mobilenet_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.0% | 72.8% | 10.4 ms
-Mobilenet_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.5% | 77.7% | 8.8 ms
-Mobilenet_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.7% | 80.4% | 13.0 ms
-Mobilenet_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 60.0% | 82.2% | 18.3 ms
-Mobilenet_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 60.7% | 83.2% | 24.7 ms
-Mobilenet_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.8% | 78.8% | 16.2 ms
-Mobilenet_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.3% | 83.8% | 24.3 ms
-Mobilenet_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.4% | 33.8 ms
-Mobilenet_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.8% | 87.0% | 45.4 ms
-Mobilenet_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.4% | 84.2% | 24.9 ms
-Mobilenet_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 67.2% | 86.7% | 37.4 ms
-Mobilenet_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.2% | 88.3% | 51.9 ms
-Mobilenet_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.1% | 88.9% | 70.2 ms
+Model Name | Paper_Model_Files | Model_Size | Top-1 Accuracy | Top-5 Accuracy | TF Lite Performance
+--------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------: | ---------: | -------------: | -------------: | ------------------:
+Mobilenet_V1_0.25_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_128_quant.tgz) | 0.5 Mb | 39.5% | 64.4% | 3.7 ms
+Mobilenet_V1_0.25_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_160_quant.tgz) | 0.5 Mb | 42.8% | 68.1% | 5.5 ms
+Mobilenet_V1_0.25_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_192_quant.tgz) | 0.5 Mb | 45.7% | 70.8% | 7.9 ms
+Mobilenet_V1_0.25_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.25_224_quant.tgz) | 0.5 Mb | 48.2% | 72.8% | 10.4 ms
+Mobilenet_V1_0.50_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_128_quant.tgz) | 1.4 Mb | 54.9% | 78.1% | 8.8 ms
+Mobilenet_V1_0.50_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_160_quant.tgz) | 1.4 Mb | 57.2% | 80.5% | 13.0 ms
+Mobilenet_V1_0.50_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_192_quant.tgz) | 1.4 Mb | 59.9% | 82.1% | 18.3 ms
+Mobilenet_V1_0.50_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.5_224_quant.tgz) | 1.4 Mb | 61.2% | 83.2% | 24.7 ms
+Mobilenet_V1_0.75_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_128_quant.tgz) | 2.6 Mb | 55.9% | 79.1% | 16.2 ms
+Mobilenet_V1_0.75_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_160_quant.tgz) | 2.6 Mb | 62.4% | 83.7% | 24.3 ms
+Mobilenet_V1_0.75_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_192_quant.tgz) | 2.6 Mb | 66.1% | 86.2% | 33.8 ms
+Mobilenet_V1_0.75_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_0.75_224_quant.tgz) | 2.6 Mb | 66.9% | 86.9% | 45.4 ms
+Mobilenet_V1_1.0_128_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_128_quant.tgz) | 4.3 Mb | 63.3% | 84.1% | 24.9 ms
+Mobilenet_V1_1.0_160_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_160_quant.tgz) | 4.3 Mb | 66.9% | 86.7% | 37.4 ms
+Mobilenet_V1_1.0_192_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_192_quant.tgz) | 4.3 Mb | 69.1% | 88.1% | 51.9 ms
+Mobilenet_V1_1.0_224_quant | [paper](https://arxiv.org/pdf/1712.05877.pdf), [tflite&pb](http://download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz) | 4.3 Mb | 70.0% | 89.0% | 70.2 ms
+Mobilenet_v2_1.0_224_quant | [paper](https://arxiv.org/abs/1806.08342), [tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz) | 3.4 Mb | 70.8% | 89.9% | 80.3 ms
+Inception_v3_quant | [paper](https://arxiv.org/abs/1806.08342),[tflite&pb](http://download.tensorflow.org/models/tflite_11_05_08/inception_v3_quant.tgz) | 23 Mb | 77.5% | 93.7% | 637 ms
## Other models
Model | TF Lite FlatBuffer
----------------------- | :----------------:
-Smart Reply 1.0 Android | [reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html), [tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
+[reference](https://research.googleblog.com/2017/11/on-device-conversational-modeling-with.html),
+[tflite](https://storage.googleapis.com/download.tensorflow.org/models/smartreply_1.0_2017_11_01.zip)
diff --git a/tensorflow/contrib/lite/g3doc/ops_versioning.md b/tensorflow/contrib/lite/g3doc/ops_versioning.md
index b06f4fd3b8..0d571ce547 100644
--- a/tensorflow/contrib/lite/g3doc/ops_versioning.md
+++ b/tensorflow/contrib/lite/g3doc/ops_versioning.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite Ops Versioning
diff --git a/tensorflow/contrib/lite/g3doc/overview.md b/tensorflow/contrib/lite/g3doc/overview.md
index be60d7941a..8cf43496df 100644
--- a/tensorflow/contrib/lite/g3doc/overview.md
+++ b/tensorflow/contrib/lite/g3doc/overview.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Introduction to TensorFlow Lite
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 5cd0aab44f..28cb6aba6e 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Performance
diff --git a/tensorflow/contrib/lite/g3doc/rpi.md b/tensorflow/contrib/lite/g3doc/rpi.md
index 9fcf79ba00..41a1892b6f 100644
--- a/tensorflow/contrib/lite/g3doc/rpi.md
+++ b/tensorflow/contrib/lite/g3doc/rpi.md
@@ -1,30 +1,36 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
-
# TensorFlow Lite for Raspberry Pi
## Cross compiling
-### Installing toolchian
-This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image [tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
-To cross compiling TensorFlow Lite. First you should install the toolchain and libs.
+### Installing the toolchain
+
+This has been tested on Ubuntu 16.04.3 64bit and Tensorflow devel docker image
+[tensorflow/tensorflow:nightly-devel](https://hub.docker.com/r/tensorflow/tensorflow/tags/).
+
+To cross compile TensorFlow Lite, first install the toolchain and libs.
+
```bash
sudo apt-get update
sudo apt-get install crossbuild-essential-armhf
```
-> If you are using docker, you may not use `sudo`
+
+> If you are using Docker, you may not use `sudo`.
### Building
+
Clone this Tensorflow repository, Run this script at the root of the repository to download all the dependencies:
+
> The Tensorflow repository is in `/tensorflow` if you are using `tensorflow/tensorflow:nightly-devel` docker image, just try it.
+
```bash
-./tensorflow/contrib/lite/download_dependencies.sh
+./tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
Note that you only need to do this once.
You should then be able to compile:
+
```bash
-./tensorflow/contrib/lite/build_rpi_lib.sh
+./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
```
This should compile a static library in:
@@ -33,21 +39,23 @@ This should compile a static library in:
## Native compiling
This has been tested on Raspberry Pi 3b, Raspbian GNU/Linux 9.1 (stretch), gcc version 6.3.0 20170516 (Raspbian 6.3.0-18+rpi1).
-Log in to you RPI, install the toolchain.
+Log in to you Raspberry Pi, install the toolchain.
+
```bash
sudo apt-get install build-essential
```
-First, clone this TensorFlow repository. Run this at the root of the repository:
+First, clone the TensorFlow repository. Run this at the root of the repository:
+
```bash
-./tensorflow/contrib/lite/download_dependencies.sh
+./tensorflow/contrib/lite/tools/make/download_dependencies.sh
```
Note that you only need to do this once.
You should then be able to compile:
```bash
-./tensorflow/contrib/lite/build_rpi_lib.sh
+./tensorflow/contrib/lite/tools/make/build_rpi_lib.sh
```
This should compile a static library in:
-`tensorflow/contrib/lite/gen/lib/rpi_armv7/libtensorflow-lite.a`.
+`tensorflow/contrib/lite/tools/make/gen/lib/rpi_armv7/libtensorflow-lite.a`.
diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
index aa65ec9988..8660d29855 100644
--- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# TensorFlow Lite & TensorFlow Compatibility Guide
@@ -843,6 +841,31 @@ Outputs {
}
```
+**UNPACK**
+
+```
+Inputs {
+ 0: a tensor.
+ 1: an integer.
+ 2: an integer.
+}
+Outputs {
+ 0-N: tensors of unpacked tensor.
+}
+```
+
+**FLOOR_DIV**
+
+```
+Inputs {
+ 0: a list of tensors.
+ 1: a list of tensors.
+}
+Outputs {
+ 0: A tensor of floor_div output tensors.
+}
+```
+
And these are TensorFlow Lite operations that are present but not ready for
custom models yet:
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index 76e16fc9db..c7cdee07de 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Building TensorFlow on Android
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index bd047bfcec..d003bb2f38 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Overview
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
index 6223707892..be8b4100c8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Building TensorFlow on iOS
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
index 4c2071ed05..4d4bb3bc08 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Integrating TensorFlow libraries
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
index a0192c3541..7436594fd8 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Optimizing for mobile
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
index 6b4e4a92bd..d1c67d4c61 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
@@ -1,5 +1,3 @@
-book_path: /mobile/_book.yaml
-project_path: /mobile/_project.yaml
# Preparing models for mobile deployment
diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h
index 77268d7aeb..8ee83827bb 100644
--- a/tensorflow/contrib/lite/graph_info.h
+++ b/tensorflow/contrib/lite/graph_info.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 362e588725..3f8f4d198f 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -21,9 +21,9 @@ limitations under the License.
#include <cstring>
#include "tensorflow/contrib/lite/arena_planner.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/context_util.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -476,6 +476,10 @@ TfLiteStatus Interpreter::ResetVariableTensorsToZero() {
return kTfLiteOk;
}
+void Interpreter::ReserveNodes(int count) {
+ nodes_and_registration_.reserve(count);
+}
+
TfLiteStatus Interpreter::AddNodeWithParameters(
const std::vector<int>& inputs, const std::vector<int>& outputs,
const char* init_data, size_t init_data_size, void* builtin_data,
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 7d69aa2ad3..f0cd178c19 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -23,10 +23,11 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/profiling/profiler.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
namespace tflite {
@@ -136,6 +137,11 @@ class Interpreter {
// interpreter.
TfLiteStatus SetVariables(std::vector<int> variables);
+ // Ensure the internal node storage memory allocates at least `count`
+ // spots for node. NOTE, this doesn't actually add operators. This is an
+ // efficiency optimization that is subject to change.
+ void ReserveNodes(int count);
+
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 5bcf0927d8..cdede430e2 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/interpreter.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index e3cea19e16..6a3f0651d0 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -20,9 +20,6 @@ code to merge.
- Make sure to install the latest version of Bazel. Some distributions
ship with Bazel 0.5.4, which is too old.
- Bazel requires Android Build Tools `26.0.1` or higher.
- - **Bazel is incompatible with NDK revisions 15 and above,** with revision
- 16 being a compile-breaking change. [Download an older version manually
- instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites)
- You also need to install the Android Support Repository, available
through Android Studio under `Android SDK Manager -> SDK Tools ->
Android Support Repository`.
@@ -37,8 +34,7 @@ code to merge.
- Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
you have installed.
- By default, Android Studio will install the SDK to `~/Android/Sdk` and
- the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual
- download until Bazel supports NDK 16. See bullet points under (1)).
+ the NDK to `~/Android/Sdk/ndk-bundle`.
2. Build the app with Bazel. The demo needs C++11:
diff --git a/tensorflow/contrib/lite/java/demo/app/build.gradle b/tensorflow/contrib/lite/java/demo/app/build.gradle
index 92f04c651c..05301ebf88 100644
--- a/tensorflow/contrib/lite/java/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/demo/app/build.gradle
@@ -10,7 +10,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -44,9 +43,6 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.2.0'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:25.2.0'
@@ -54,8 +50,6 @@ dependencies {
compile 'com.android.support:support-v13:25.2.0'
compile 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
-
- testCompile 'junit:junit:4.12'
}
def modelDownloadUrl = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index 06f46fb923..781289ceb2 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -35,6 +35,7 @@ java_binary(
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
],
main_class = "org.tensorflow.ovic.OvicValidator",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib_java",
],
@@ -47,6 +48,7 @@ android_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:tensorflowlite",
"//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
@@ -61,6 +63,7 @@ java_library(
"src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
javacopts = JAVACOPTS,
+ tags = ["no_oss"],
deps = [
"//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
"//tensorflow/contrib/lite/java:tensorflowlite_java",
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
index 2a08608bbb..4f3a6cdb2f 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/build.gradle
@@ -9,7 +9,6 @@ android {
targetSdkVersion 26
versionCode 1
versionName "1.0"
- testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
// Remove this block.
jackOptions {
@@ -43,9 +42,6 @@ repositories {
dependencies {
compile fileTree(dir: 'libs', include: ['*.jar'])
- androidTestCompile('androidx.test.espresso:espresso-core:3.1.0-alpha3', {
- exclude group: 'com.android.support', module: 'support-annotations'
- })
compile 'com.android.support:appcompat-v7:25.2.0'
compile 'com.android.support.constraint:constraint-layout:1.0.2'
compile 'com.android.support:design:25.2.0'
@@ -53,6 +49,4 @@ dependencies {
compile 'com.android.support:support-v13:25.2.0'
compile 'org.tensorflow:tensorflow-lite:+'
-
- testCompile 'junit:junit:4.12'
}
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 55ca47fed7..06b35d77c8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -20,7 +20,7 @@ limitations under the License.
#include <stdio.h>
#include <time.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
@@ -124,9 +124,9 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint num_threads);
+ jclass clazz,
+ jlong handle,
+ jint num_threads);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index c020f13d9c..2f73128bdf 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#include <jni.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#ifdef __cplusplus
extern "C" {
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index 1f528fdab9..40f28aeab4 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -6,7 +6,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_opts_nortti_if_android")
# Suppress warnings that are introduced by Eigen Tensor.
EXTRA_EIGEN_COPTS = select({
@@ -66,7 +66,7 @@ cc_library(
deps = [
":op_macros",
"//tensorflow/contrib/lite:arena_planner",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:optimized",
],
)
@@ -82,7 +82,7 @@ cc_library(
copts = tflite_copts(),
deps = [
":op_macros",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@gemmlowp",
],
)
@@ -93,7 +93,7 @@ cc_library(
"activation_functor.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -113,9 +113,9 @@ cc_library(
"kernel_util.h",
],
deps = [
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels/internal:round",
+ "//tensorflow/contrib/lite/kernels/internal:types",
],
)
@@ -147,7 +147,16 @@ tf_cc_test(
)
cc_library(
- name = "builtin_ops",
+ name = "padding",
+ srcs = [],
+ hdrs = ["padding.h"],
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+cc_library(
+ name = "builtin_op_kernels",
srcs = [
"activations.cc",
"add.cc",
@@ -172,10 +181,12 @@ cc_library(
"expand_dims.cc",
"fake_quant.cc",
"floor.cc",
+ "floor_div.cc",
"fully_connected.cc",
"gather.cc",
"hashtable_lookup.cc",
"l2norm.cc",
+ "layer_norm_lstm.cc",
"local_response_norm.cc",
"logical.cc",
"lsh_projection.cc",
@@ -190,7 +201,7 @@ cc_library(
"pooling.cc",
"pow.cc",
"reduce.cc",
- "register.cc",
+ "relu1.cc",
"reshape.cc",
"resize_bilinear.cc",
"select.cc",
@@ -211,35 +222,48 @@ cc_library(
"transpose_conv.cc",
"unidirectional_sequence_lstm.cc",
"unidirectional_sequence_rnn.cc",
+ "unpack.cc",
],
hdrs = [
- "padding.h",
- "register.h",
],
- copts = tflite_copts() + EXTRA_EIGEN_COPTS,
+ copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
+ visibility = ["//visibility:private"],
deps = [
":activation_functor",
":eigen_support",
":kernel_util",
":op_macros",
- "//tensorflow/contrib/lite:builtin_op_data",
+ ":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
"//tensorflow/contrib/lite/kernels/internal:kernel_utils",
"//tensorflow/contrib/lite/kernels/internal:optimized",
"//tensorflow/contrib/lite/kernels/internal:optimized_base",
"//tensorflow/contrib/lite/kernels/internal:quantization_util",
- "//tensorflow/contrib/lite/kernels/internal:reference",
"//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
"//tensorflow/contrib/lite/kernels/internal:tensor_utils",
"@farmhash_archive//:farmhash",
"@flatbuffers",
],
)
+cc_library(
+ name = "builtin_ops",
+ srcs = ["register.cc"],
+ hdrs = ["register.h"],
+ deps = [
+ ":builtin_op_kernels",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:util",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
tf_cc_test(
name = "audio_spectrogram_test",
size = "small",
@@ -292,6 +316,23 @@ tf_cc_test(
)
tf_cc_test(
+ name = "relu1_test",
+ size = "small",
+ srcs = ["relu1_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
@@ -726,8 +767,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -743,8 +784,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -902,6 +943,20 @@ tf_cc_test(
)
tf_cc_test(
+ name = "layer_norm_lstm_test",
+ size = "small",
+ srcs = ["layer_norm_lstm_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "lstm_test",
size = "small",
srcs = ["lstm_test.cc"],
@@ -999,8 +1054,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1102,8 +1157,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1119,8 +1174,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1136,8 +1191,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1153,8 +1208,8 @@ tf_cc_test(
],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1167,8 +1222,8 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
- "//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest",
],
@@ -1194,6 +1249,34 @@ tf_cc_test(
tags = ["tflite_not_portable_ios"],
deps = [
":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "unpack_test",
+ size = "small",
+ srcs = ["unpack_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
+ name = "floor_div_test",
+ size = "small",
+ srcs = ["floor_div_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
"//tensorflow/contrib/lite:builtin_op_data",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite/kernels:test_util",
diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h
index 41ec3cca33..e075dc7054 100644
--- a/tensorflow/contrib/lite/kernels/activation_functor.h
+++ b/tensorflow/contrib/lite/kernels/activation_functor.h
@@ -19,7 +19,7 @@ limitations under the License.
#include <cmath>
#include <cstdlib>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index d6d62580e2..b2d9b84979 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -200,7 +200,7 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->type, output->type);
const int num_dims = NumDimensions(input);
- TF_LITE_ENSURE(context, num_dims == 1 || num_dims == 2 || num_dims == 4);
+ TF_LITE_ENSURE(context, num_dims >= 1 && num_dims <= 4);
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
@@ -453,6 +453,19 @@ void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
Softmax(input->data.f, input_size, batch_size, params->beta, output->data.f);
}
+// Takes a 3D tensor and perform softmax along the last dimension.
+void Softmax3DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ optimized_ops::Softmax(
+ GetTensorData<float>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ params->beta, GetTensorData<float>(output),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params, OpData* data) {
// TODO(ahentz): this is arguably a dirty trick. Since the implementation
@@ -480,6 +493,19 @@ void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
GetTensorShape({batch_size, 1, 1, input_size}));
}
+void Softmax3DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ const int batch_size = input->dims->data[0];
+ const int intermediate_size = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ optimized_ops::Softmax(
+ GetTensorData<uint8_t>(input),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}),
+ data->input_multiplier, data->input_left_shift, data->diff_min,
+ GetTensorData<uint8_t>(output),
+ GetTensorShape({batch_size, intermediate_size, 1, input_size}));
+}
+
// Takes a 4D tensor and perform softmax along the forth dimension.
void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
TfLiteSoftmaxParams* params) {
@@ -515,6 +541,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
Softmax2DFloat(input, output, params);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DFloat(input, output, params);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DFloat(input, output, params);
return kTfLiteOk;
@@ -533,6 +563,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
Softmax2DQuantized(input, output, params, data);
return kTfLiteOk;
}
+ if (NumDimensions(input) == 3) {
+ Softmax3DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
if (NumDimensions(input) == 4) {
Softmax4DQuantized(input, output, params, data);
return kTfLiteOk;
@@ -590,10 +624,10 @@ TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
input->type);
return kTfLiteError;
}
- reference_ops::BroadcastBinaryFunction<float, float, float>(
- GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(alpha), GetTensorDims(alpha),
- GetTensorData<float>(output), GetTensorDims(output), ApplyPrelu<float>);
+ reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>(
+ GetTensorShape(input), GetTensorData<float>(input), GetTensorShape(alpha),
+ GetTensorData<float>(alpha), GetTensorShape(output),
+ GetTensorData<float>(output), ApplyPrelu<float>);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc
index e577e3a762..9fa47e190a 100644
--- a/tensorflow/contrib/lite/kernels/activations_test.cc
+++ b/tensorflow/contrib/lite/kernels/activations_test.cc
@@ -339,6 +339,76 @@ TEST(QuantizedActivationsOpTest, Softmax4D) {
kQuantizedTolerance)));
}
+TEST(FloatActivationsOpTest, Softmax3D) {
+ FloatActivationsOpModel m(0.1,
+ /*input=*/{TensorType_FLOAT32, {1, 2, 4}});
+ m.SetInput({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ })));
+
+ // Same input, but a different shape.
+ FloatActivationsOpModel m2(0.1,
+ /*input=*/{TensorType_FLOAT32, {4, 1, 2}});
+ m2.SetInput({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ })));
+}
+
+TEST(QuantizedActivationsOpTest, Softmax3D) {
+ QuantizedActivationsOpModel m(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {1, 2, 4}, -10, 10});
+ m.SetInput<uint8_t>({
+ 0, -6, 2, 4, // depth = 0
+ 3, -2, 10, 1, // depth = 1
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ .23463, .12877, .28658, .35003, //
+ .22528, .13664, .45365, .18443, //
+ },
+ kQuantizedTolerance)));
+
+ // Same input, but a different shape.
+ QuantizedActivationsOpModel m2(
+ 0.1,
+ /*input=*/{TensorType_UINT8, {4, 1, 2}, -10, 10});
+ m2.SetInput<uint8_t>({
+ 0, -6, //
+ 2, 4, //
+ 3, -2, //
+ 10, 1, //
+ });
+ m2.Invoke();
+ EXPECT_THAT(m2.GetDequantizedOutput<uint8_t>(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 0.645656, 0.354344, //
+ 0.450166, 0.549834, //
+ 0.622459, 0.377541, //
+ 0.710949, 0.28905, //
+ },
+ kQuantizedTolerance)));
+}
+
TEST(FloatActivationsOpTest, Softmax1D) {
FloatActivationsOpModel m(0.1,
/*input=*/{TensorType_FLOAT32, {8}});
diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc
index af9b5c7013..b4393e8097 100644
--- a/tensorflow/contrib/lite/kernels/add.cc
+++ b/tensorflow/contrib/lite/kernels/add.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/arg_min_max.cc b/tensorflow/contrib/lite/kernels/arg_min_max.cc
index 4f30d09030..b91e348c27 100644
--- a/tensorflow/contrib/lite/kernels/arg_min_max.cc
+++ b/tensorflow/contrib/lite/kernels/arg_min_max.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -96,11 +96,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
const TfLiteTensor* axis = GetInput(context, node, kAxis);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
- optimized_ops::ArgMinMax( \
- GetTensorData<axis_type>(axis), GetTensorData<data_type>(input), \
- GetTensorDims(input), GetTensorData<output_type>(output), \
- GetTensorDims(output), GetComparefunction<data_type>(is_arg_max))
+#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
+ optimized_ops::ArgMinMax( \
+ GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorData<axis_type>(axis), GetTensorShape(output), \
+ GetTensorData<output_type>(output), \
+ GetComparefunction<data_type>(is_arg_max))
if (axis->type == kTfLiteInt32) {
switch (output->type) {
case kTfLiteInt32: {
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
index 91d8dd3fa7..44ef587244 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/spectrogram.h"
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
index 8d460fdfc6..7346b9fd80 100644
--- a/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
+++ b/tensorflow/contrib/lite/kernels/audio_spectrogram_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc
index c09b15b3d2..1aa27602e5 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -31,8 +31,10 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
@@ -46,14 +48,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -65,20 +69,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
output_size_array->data[0] = batch_size;
@@ -205,7 +201,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ TfLiteTensor* hidden_state =
+ &context->tensors[node->inputs->data[kHiddenStateTensor]];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
// We already checked that weight types are consistent, so branch on one.
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
index 96465fcaf0..d179735404 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -181,15 +181,16 @@ class RNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_RNN, BuiltinOptions_RNNOptions,
CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union());
- BuildInterpreter({{batches_, input_size_},
- {units_, input_size_},
- {units_, units_},
- {units_}});
+ BuildInterpreter({{batches_, input_size_}, // input tensor
+ {units_, input_size_}, // weights tensor
+ {units_, units_}, // recurrent weights tensor
+ {units_}, // bias tensor
+ {batches_, units_}}); // hidden state tensor
}
void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
@@ -210,14 +211,6 @@ class RNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -258,7 +251,6 @@ TEST(RnnOpTest, BlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
@@ -286,7 +278,6 @@ TEST(HybridRnnOpTest, BlackBoxTest) {
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
(rnn.input_size() * rnn.num_batches());
diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
index c8cee88edf..fe2865dfb9 100644
--- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
+++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -125,14 +125,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \
- type::BatchToSpaceND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ type::BatchToSpaceND(GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.crops), \
GetTensorData<int32_t>(op_context.crops), \
- GetTensorDims(op_context.crops), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output))
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index a11a59aa05..541f320138 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
@@ -94,18 +94,54 @@ constexpr int kBwProjectionWeightsTensor = 33; // Optional
// Projection bias tensor of size {n_output}
constexpr int kBwProjectionBiasTensor = 34; // Optional
-// Output tensors.
-constexpr int kFwOutputStateTensor = 0;
-constexpr int kFwCellStateTensor = 1;
-constexpr int kFwOutputTensor = 2;
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kFwInputActivationStateTensor = 35;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kFwInputCellStateTensor = 36;
+// Activation state tensors of size {n_batch, n_output}
+constexpr int kBwInputActivationStateTensor = 37;
+// Cell state tensors of size {n_batch, n_cell}
+constexpr int kBwInputCellStateTensor = 38;
+
+// Auxiliary input and weights when stacking.
+constexpr int kAuxInputTensor = 39; // Optional
+// Forward weights.
+constexpr int kFwAuxInputToInputWeightsTensor = 40; // Optional
+constexpr int kFwAuxInputToForgetWeightsTensor = 41; // Optional
+constexpr int kFwAuxInputToCellWeightsTensor = 42; // Optional
+constexpr int kFwAuxInputToOutputWeightsTensor = 43; // Optional
+// Backward weights.
+constexpr int kBwAuxInputToInputWeightsTensor = 44; // Optional
+constexpr int kBwAuxInputToForgetWeightsTensor = 45; // Optional
+constexpr int kBwAuxInputToCellWeightsTensor = 46; // Optional
+constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional
-constexpr int kBwOutputStateTensor = 3;
-constexpr int kBwCellStateTensor = 4;
-constexpr int kBwOutputTensor = 5;
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ // Scratch buffers for input, forget, etc. gates
+ kFwScratchBuffer = 0,
+ kBwScratchBuffer = 1,
+ // Quantized tensors needed for the hybrid kernel.
+ kInputQuantized = 2,
+ kAuxInputQuantized = 3, // Quantized tensor needed for auxiliary input.
+ kFwActivationStateQuantized = 4,
+ kBwActivationStateQuantized = 5,
+ kFwCellStateQuantized = 6,
+ kBwCellStateQuantized = 7,
+ kScalingFactors = 8,
+ kProductScalingFactors = 9,
+ kRecoveredCellWeights = 10,
+ kNumTemporaryTensors = 11
+};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
- context->AddTensors(context, 2, scratch_tensor_index);
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -126,7 +162,7 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -307,19 +343,20 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
return kTfLiteOk;
}
-// Resize the output, state and scratch tensors based on the sizes of the input
+// Resize the output and scratch tensors based on the sizes of the input
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 35);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 6);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE(context, input->dims->size > 1);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -343,13 +380,63 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
n_fw_cell));
- // Get the pointer to output, state and scratch buffer tensors.
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
+ // Get (optional) auxiliary inputs and weights.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_to_cell_weights != nullptr) &&
+ (fw_aux_input_to_forget_weights != nullptr) &&
+ (fw_aux_input_to_output_weights != nullptr) &&
+ (bw_aux_input_to_cell_weights != nullptr) &&
+ (bw_aux_input_to_forget_weights != nullptr) &&
+ (bw_aux_input_to_output_weights != nullptr)) ||
+ ((fw_aux_input_to_cell_weights == nullptr) &&
+ (fw_aux_input_to_forget_weights == nullptr) &&
+ (fw_aux_input_to_output_weights == nullptr) &&
+ (bw_aux_input_to_cell_weights == nullptr) &&
+ (bw_aux_input_to_forget_weights == nullptr) &&
+ (bw_aux_input_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
+
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ }
- // Resize the output, output_state and cell_state tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_activation_state),
+ n_batch * n_fw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(fw_cell_state), n_batch * n_fw_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
@@ -357,32 +444,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
- TfLiteIntArray* fw_output_state_size = TfLiteIntArrayCreate(2);
- fw_output_state_size->data[0] = n_batch;
- fw_output_state_size->data[1] = n_fw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state,
- fw_output_state_size));
-
- TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2);
- fw_cell_size->data[0] = n_batch;
- fw_cell_size->data[1] = n_fw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, fw_cell_state, fw_cell_size));
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (fw_input_to_output_weights->type == kTfLiteUInt8);
- // Create a scratch buffer tensor.
TfLiteIntArrayFree(node->temporaries);
- node->temporaries = TfLiteIntArrayCreate(2);
- node->temporaries->data[0] = *scratch_tensor_index;
- TfLiteTensor* fw_scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(2); // the two scratch buffers.
+ }
+ // Create a scratch buffer tensor.
+ node->temporaries->data[kFwScratchBuffer] = *scratch_tensor_index;
+ TfLiteTensor* fw_scratch_buffer =
+ GetTemporary(context, node, kFwScratchBuffer);
fw_scratch_buffer->type = input->type;
fw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- fw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- fw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* fw_input_to_input_weights =
GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor);
+ if (has_aux_input) {
+ TF_LITE_ENSURE_EQ(context, fw_aux_input_to_input_weights->dims->data[0],
+ fw_input_to_input_weights->dims->data[0]);
+ }
const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2);
fw_scratch_buffer_size->data[0] = n_batch;
@@ -415,13 +498,14 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
n_bw_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ // Resize the output tensors.
TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
bw_output_size->data[0] = max_time;
bw_output_size->data[1] = n_batch;
@@ -429,30 +513,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, bw_output, bw_output_size));
- TfLiteIntArray* bw_output_state_size = TfLiteIntArrayCreate(2);
- bw_output_state_size->data[0] = n_batch;
- bw_output_state_size->data[1] = n_bw_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state,
- bw_output_state_size));
-
- TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2);
- bw_cell_size->data[0] = n_batch;
- bw_cell_size->data[1] = n_bw_cell;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_cell_state, bw_cell_size));
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_activation_state),
+ n_batch * n_bw_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(bw_cell_state), n_batch * n_bw_cell);
// Create a scratch buffer tensor.
- node->temporaries->data[1] = *(scratch_tensor_index) + 1;
- TfLiteTensor* bw_scratch_buffer = GetTemporary(context, node, /*index=*/1);
+ node->temporaries->data[kBwScratchBuffer] =
+ *(scratch_tensor_index) + kBwScratchBuffer;
+ TfLiteTensor* bw_scratch_buffer =
+ GetTemporary(context, node, kBwScratchBuffer);
bw_scratch_buffer->type = input->type;
bw_scratch_buffer->allocation_type = kTfLiteArenaRw;
- // Mark state tensors as persistent tensors.
- bw_output_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
const TfLiteTensor* bw_input_to_input_weights =
GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor);
+ if (has_aux_input) {
+ TF_LITE_ENSURE_EQ(context, bw_aux_input_to_input_weights->dims->data[0],
+ bw_input_to_input_weights->dims->data[0]);
+ }
const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2);
bw_scratch_buffer_size->data[0] = n_batch;
@@ -465,18 +546,528 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer,
bw_scratch_buffer_size));
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input, aux_input
+ // (if present), activation_state and cell_state tensors.
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
+
+ node->temporaries->data[kFwActivationStateQuantized] =
+ *scratch_tensor_index + kFwActivationStateQuantized;
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ fw_activation_state_quantized->type = kTfLiteUInt8;
+ fw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_activation_state_quantized->dims,
+ fw_activation_state->dims)) {
+ TfLiteIntArray* fw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(fw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_activation_state_quantized,
+ fw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kBwActivationStateQuantized] =
+ *scratch_tensor_index + kBwActivationStateQuantized;
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ bw_activation_state_quantized->type = kTfLiteUInt8;
+ bw_activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_activation_state_quantized->dims,
+ bw_activation_state->dims)) {
+ TfLiteIntArray* bw_activation_state_quantized_size =
+ TfLiteIntArrayCopy(bw_activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_activation_state_quantized,
+ bw_activation_state_quantized_size));
+ }
+ node->temporaries->data[kFwCellStateQuantized] =
+ *scratch_tensor_index + kFwCellStateQuantized;
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ fw_cell_state_quantized->type = kTfLiteUInt8;
+ fw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_cell_state_quantized->dims,
+ fw_cell_state->dims)) {
+ TfLiteIntArray* fw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(fw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, fw_cell_state_quantized,
+ fw_cell_state_quantized_size));
+ }
+ node->temporaries->data[kBwCellStateQuantized] =
+ *scratch_tensor_index + kBwCellStateQuantized;
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ bw_cell_state_quantized->type = kTfLiteUInt8;
+ bw_cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_cell_state_quantized->dims,
+ bw_cell_state->dims)) {
+ TfLiteIntArray* bw_cell_state_quantized_size =
+ TfLiteIntArrayCopy(bw_cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, bw_cell_state_quantized,
+ bw_cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[kProductScalingFactors] =
+ *scratch_tensor_index + kProductScalingFactors;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered cell weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[kRecoveredCellWeights] =
+ *scratch_tensor_index + kRecoveredCellWeights;
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+ recovered_cell_weights->type = kTfLiteFloat32;
+ recovered_cell_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1);
+ recovered_cell_weights_size->data[0] = n_fw_cell;
+ if (!TfLiteIntArrayEqual(recovered_cell_weights->dims,
+ recovered_cell_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_cell_weights,
+ recovered_cell_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ float* aux_input_ptr = nullptr;
+ float* aux_input_to_input_weights_ptr = nullptr;
+ float* aux_input_to_forget_weights_ptr = nullptr;
+ float* aux_input_to_cell_weights_ptr = nullptr;
+ float* aux_input_to_output_weights_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+ aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+ aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+ aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+ }
+
+ // Loop through the sequence.
+ if (forward_sequence) {
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr,
+ input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+ input_to_output_weights->data.f, aux_input_ptr,
+ aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+ aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+ recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
+ }
+ } else {
+ // Loop through the sequence backwards.
+ for (int t = max_time - 1; t >= 0; t--) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr_time = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr,
+ input_to_forget_weights->data.f, input_to_cell_weights->data.f,
+ input_to_output_weights->data.f, aux_input_ptr,
+ aux_input_to_input_weights_ptr, aux_input_to_forget_weights_ptr,
+ aux_input_to_cell_weights_ptr, aux_input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights->data.f,
+ recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ const int max_time = input->dims->data[0];
+ const int n_batch = input->dims->data[1];
+ const int n_input = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_aux_input_ptr =
+ (aux_input_quantized == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ // Auxiliary input and weights.
+ float* aux_input_ptr = nullptr;
+ int8_t* aux_input_to_input_weights_ptr = nullptr;
+ int8_t* aux_input_to_forget_weights_ptr = nullptr;
+ int8_t* aux_input_to_cell_weights_ptr = nullptr;
+ int8_t* aux_input_to_output_weights_ptr = nullptr;
+ float aux_input_to_input_weights_scale = 0.0f;
+ float aux_input_to_forget_weights_scale = 0.0f;
+ float aux_input_to_cell_weights_scale = 0.0f;
+ float aux_input_to_output_weights_scale = 0.0f;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+ aux_input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+ aux_input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+ aux_input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+ aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
+ aux_input_to_forget_weights_scale =
+ aux_input_to_forget_weights->params.scale;
+ aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+ aux_input_to_output_weights_scale =
+ aux_input_to_output_weights->params.scale;
+ }
+ if (forward_sequence) {
+ // Feed the sequence into the LSTM step-by-step.
+ for (int t = 0; t < max_time; t++) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr,
+ forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+ projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr);
+ }
+ } else {
+ // Loop through the sequence backwards.
+ for (int t = max_time - 1; t >= 0; t--) {
+ const float* input_ptr = input->data.f + t * n_batch * n_input;
+ float* output_ptr = output->data.f + t * n_batch * n_output;
+
+ kernel_utils::LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr,
+ forget_gate_bias_ptr, cell_bias_ptr, output_gate_bias_ptr,
+ projection_weights_ptr, projection_weights_scale, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr);
+ }
+ }
+
return kTfLiteOk;
}
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
// Tensors for the forward cell.
const TfLiteTensor* fw_input_to_input_weights =
@@ -518,9 +1109,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* fw_projection_bias =
GetOptionalInputTensor(context, node, kFwProjectionBiasTensor);
- TfLiteTensor* fw_output_state =
- GetOutput(context, node, kFwOutputStateTensor);
- TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor);
+ TfLiteTensor* fw_activation_state =
+ GetVariableInput(context, node, kFwInputActivationStateTensor);
+ TfLiteTensor* fw_cell_state =
+ GetVariableInput(context, node, kFwInputCellStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
// Tensors for the backward cell.
@@ -563,154 +1155,134 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_projection_bias =
GetOptionalInputTensor(context, node, kBwProjectionBiasTensor);
- TfLiteTensor* bw_output_state =
- GetOutput(context, node, kBwOutputStateTensor);
- TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor);
+ // State tensors.
+ TfLiteTensor* bw_activation_state =
+ GetVariableInput(context, node, kBwInputActivationStateTensor);
+ TfLiteTensor* bw_cell_state =
+ GetVariableInput(context, node, kBwInputCellStateTensor);
TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- // n_cell and n_output will be the same size when there is no projection.
- const int n_fw_cell = fw_input_to_output_weights->dims->data[0];
- const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool fw_use_cifg = (fw_input_to_input_weights == nullptr);
- const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
+ // Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
- &context->tensors[node->temporaries->data[0]];
- float* fw_input_gate_scratch = nullptr;
- float* fw_cell_scratch = nullptr;
- float* fw_forget_gate_scratch = nullptr;
- float* fw_output_gate_scratch = nullptr;
- if (fw_use_cifg) {
- fw_cell_scratch = fw_scratch_buffer->data.f;
- fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- } else {
- fw_input_gate_scratch = fw_scratch_buffer->data.f;
- fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch;
- fw_forget_gate_scratch =
- fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch;
- fw_output_gate_scratch =
- fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* fw_input_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f;
- const float* fw_recurrent_to_input_weights_ptr =
- (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f;
- const float* fw_input_gate_bias_ptr =
- (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f;
- const float* fw_cell_to_input_weights_ptr =
- (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f
- : nullptr;
- const float* fw_cell_to_forget_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr;
- const float* fw_cell_to_output_weights_ptr =
- (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr;
- const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr)
- ? nullptr
- : fw_projection_weights->data.f;
- const float* fw_projection_bias_ptr =
- (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f;
-
- // Loop through the sequence.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, fw_input_to_input_weights_ptr,
- fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f,
- fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr,
- fw_recurrent_to_forget_weights->data.f,
- fw_recurrent_to_cell_weights->data.f,
- fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr,
- fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr,
- fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f,
- fw_cell_bias->data.f, fw_output_gate_bias->data.f,
- fw_projection_weights_ptr, fw_projection_bias_ptr, params, n_batch,
- n_fw_cell, n_input, n_fw_output, fw_output_state->data.f,
- fw_cell_state->data.f, fw_input_gate_scratch, fw_forget_gate_scratch,
- fw_cell_scratch, fw_output_gate_scratch, output_ptr_time);
- }
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool bw_use_cifg = (bw_input_to_input_weights == nullptr);
- const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
+ GetTemporary(context, node, kFwScratchBuffer);
TfLiteTensor* bw_scratch_buffer =
- &context->tensors[node->temporaries->data[1]];
- float* bw_input_gate_scratch = nullptr;
- float* bw_cell_scratch = nullptr;
- float* bw_forget_gate_scratch = nullptr;
- float* bw_output_gate_scratch = nullptr;
- if (bw_use_cifg) {
- bw_cell_scratch = bw_scratch_buffer->data.f;
- bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- } else {
- bw_input_gate_scratch = bw_scratch_buffer->data.f;
- bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch;
- bw_forget_gate_scratch =
- bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch;
- bw_output_gate_scratch =
- bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* bw_input_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f;
- const float* bw_recurrent_to_input_weights_ptr =
- (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f;
- const float* bw_input_gate_bias_ptr =
- (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f;
- const float* bw_cell_to_input_weights_ptr =
- (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f
- : nullptr;
- const float* bw_cell_to_forget_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr;
- const float* bw_cell_to_output_weights_ptr =
- (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr;
- const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr)
- ? nullptr
- : bw_projection_weights->data.f;
- const float* bw_projection_bias_ptr =
- (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f;
-
- // Loop through the sequence backwards.
- for (int t = max_time - 1; t >= 0; t--) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, bw_input_to_input_weights_ptr,
- bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f,
- bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr,
- bw_recurrent_to_forget_weights->data.f,
- bw_recurrent_to_cell_weights->data.f,
- bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr,
- bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr,
- bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f,
- bw_cell_bias->data.f, bw_output_gate_bias->data.f,
- bw_projection_weights_ptr, bw_projection_bias_ptr, params, n_batch,
- n_bw_cell, n_input, n_bw_output, bw_output_state->data.f,
- bw_cell_state->data.f, bw_input_gate_scratch, bw_forget_gate_scratch,
- bw_cell_scratch, bw_output_gate_scratch, output_ptr_time);
+ GetTemporary(context, node, kBwScratchBuffer);
+
+ // (Optional) auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* fw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kFwAuxInputToOutputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToInputWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_forget_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToForgetWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_cell_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToCellWeightsTensor);
+ const TfLiteTensor* bw_aux_input_to_output_weights =
+ GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+
+ switch (fw_input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ TfLiteStatus fw_pass_status = EvalFloat(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, fw_input_gate_bias,
+ fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+ fw_projection_weights, fw_projection_bias, params,
+ /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state,
+ fw_cell_state, fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = EvalFloat(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, aux_input, bw_aux_input_to_input_weights,
+ bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
+ bw_aux_input_to_output_weights, bw_input_gate_bias,
+ bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+ bw_projection_weights, bw_projection_bias, params,
+ /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state,
+ bw_cell_state, bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ TfLiteTensor* fw_activation_state_quantized =
+ GetTemporary(context, node, kFwActivationStateQuantized);
+ TfLiteTensor* bw_activation_state_quantized =
+ GetTemporary(context, node, kBwActivationStateQuantized);
+ TfLiteTensor* fw_cell_state_quantized =
+ GetTemporary(context, node, kFwCellStateQuantized);
+ TfLiteTensor* bw_cell_state_quantized =
+ GetTemporary(context, node, kBwCellStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, kProductScalingFactors);
+ TfLiteTensor* recovered_cell_weights =
+ GetTemporary(context, node, kRecoveredCellWeights);
+
+ TfLiteStatus fw_pass_status = EvalHybrid(
+ input, fw_input_to_input_weights, fw_input_to_forget_weights,
+ fw_input_to_cell_weights, fw_input_to_output_weights,
+ fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
+ fw_recurrent_to_cell_weights, fw_recurrent_to_output_weights,
+ fw_cell_to_input_weights, fw_cell_to_forget_weights,
+ fw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, fw_input_gate_bias,
+ fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
+ fw_projection_weights, fw_projection_bias, params,
+ /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ aux_input_quantized, fw_activation_state_quantized,
+ fw_cell_state_quantized, fw_activation_state, fw_cell_state,
+ fw_output);
+ TF_LITE_ENSURE_OK(context, fw_pass_status);
+
+ TfLiteStatus bw_pass_status = EvalHybrid(
+ input, bw_input_to_input_weights, bw_input_to_forget_weights,
+ bw_input_to_cell_weights, bw_input_to_output_weights,
+ bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
+ bw_recurrent_to_cell_weights, bw_recurrent_to_output_weights,
+ bw_cell_to_input_weights, bw_cell_to_forget_weights,
+ bw_cell_to_output_weights, aux_input, fw_aux_input_to_input_weights,
+ fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
+ fw_aux_input_to_output_weights, bw_input_gate_bias,
+ bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
+ bw_projection_weights, bw_projection_bias, params,
+ /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ aux_input_quantized, bw_activation_state_quantized,
+ bw_cell_state_quantized, bw_activation_state, bw_cell_state,
+ bw_output);
+ TF_LITE_ENSURE_OK(context, bw_pass_status);
+ return kTfLiteOk;
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ fw_input_to_output_weights->type);
+ return kTfLiteError;
}
-
- // Backward step.
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index a18e1bce34..74ba8021c2 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -102,10 +102,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_projection_bias_ = AddNullInput();
}
- fw_output_state_ = AddOutput(TensorType_FLOAT32);
- fw_cell_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
-
if (use_cifg) {
bw_input_to_input_weights_ = AddNullInput();
} else {
@@ -161,10 +157,36 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_projection_bias_ = AddNullInput();
}
- bw_output_state_ = AddOutput(TensorType_FLOAT32);
- bw_cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ fw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ fw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_fw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ // Adding the 2 input state tensors.
+ bw_input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_output_ * n_batch_}},
+ /*is_variable=*/true);
+ bw_input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_bw_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
+
bw_output_ = AddOutput(TensorType_FLOAT32);
+ aux_input_ = AddNullInput();
+ fw_aux_input_to_input_weights_ = AddNullInput();
+ fw_aux_input_to_forget_weights_ = AddNullInput();
+ fw_aux_input_to_cell_weights_ = AddNullInput();
+ fw_aux_input_to_output_weights_ = AddNullInput();
+ bw_aux_input_to_input_weights_ = AddNullInput();
+ bw_aux_input_to_forget_weights_ = AddNullInput();
+ bw_aux_input_to_cell_weights_ = AddNullInput();
+ bw_aux_input_to_output_weights_ = AddNullInput();
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
@@ -259,26 +281,6 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(bw_projection_bias_, f);
}
- void ResetFwOutputAndCellStates() {
- const int zero_buffer_size = n_fw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(fw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(fw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetBwOutputAndCellStates() {
- const int zero_buffer_size = n_bw_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(bw_output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- PopulateTensor(bw_cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -340,13 +342,23 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
int bw_projection_weights_;
int bw_projection_bias_;
- int fw_output_;
- int fw_output_state_;
- int fw_cell_state_;
+ int fw_input_activation_state_;
+ int fw_input_cell_state_;
+ int bw_input_activation_state_;
+ int bw_input_cell_state_;
+ int fw_output_;
int bw_output_;
- int bw_output_state_;
- int bw_cell_state_;
+
+ int aux_input_;
+ int fw_aux_input_to_input_weights_;
+ int fw_aux_input_to_forget_weights_;
+ int fw_aux_input_to_cell_weights_;
+ int fw_aux_input_to_output_weights_;
+ int bw_aux_input_to_input_weights_;
+ int bw_aux_input_to_forget_weights_;
+ int bw_aux_input_to_cell_weights_;
+ int bw_aux_input_to_output_weights_;
int n_batch_;
int n_input_;
@@ -417,6 +429,22 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
@@ -474,10 +502,6 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
-0.0332076, 0.123838, 0.309777, -0.17621,
-0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -500,34 +524,161 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
+
+TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
// Check reversed inputs.
static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -592,6 +743,22 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
@@ -642,10 +809,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
-0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
float* batch0_start = lstm_input;
float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
@@ -668,34 +831,153 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end);
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(bw_expected)));
+}
- // Check reversed inputs.
- static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+TEST(LSTMOpTest,
+ BlackBoxTestWithCifgWithPeepholeNoProjectionNoClippingReversed) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
+ /*use_peephole=*/true, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {0, 0}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {0, 0}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
+ {0}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781,
+ 0.04717243, 0.48944736, -0.38535351,
+ -0.17212132});
+
+ lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988,
+ -0.3633365, -0.22755712, 0.28253698, 0.24407166,
+ 0.33826375});
+
+ lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593,
+ -0.09426838, -0.44257352, 0.54939759,
+ 0.01533556, 0.42751634});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToCellWeights(
+ {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711,
+ 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004,
+ 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288,
+ 0.21193194});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827,
+ 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795,
+ -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908,
+ -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835,
+ 0.50248802, 0.26114327, -0.43736315, 0.33149987});
- batch0_start = lstm_input_reversed;
- batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+ lstm.SetCellToForgetWeights(
+ {0.47485286, -0.51955009, -0.24458408, 0.31544167});
+ lstm.SetCellToOutputWeights(
+ {-0.17135078, 0.82760304, 0.85573703, -0.77109635});
+
+ static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.};
+ static float lstm_fw_golden_output[] = {
+ -0.36444446, -0.00352185, 0.12886585, -0.05163646,
+ -0.42312205, -0.01218222, 0.24201041, -0.08124574,
+ -0.358325, -0.04621704, 0.21641694, -0.06471302};
+ static float lstm_bw_golden_output[] = {
+ -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577,
+ 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578};
+
+ float* batch0_start = lstm_input_reversed;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
lstm.SetInput(0, batch0_start, batch0_end);
lstm.Invoke();
- fw_expected.clear();
+ std::vector<float> fw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
- fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
+ float* fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs();
+ float* fw_golden_end = fw_golden_start + lstm.num_fw_outputs();
fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end);
}
EXPECT_THAT(lstm.GetBwOutput(),
ElementsAreArray(ArrayFloatNear(fw_expected)));
- bw_expected.clear();
+ std::vector<float> bw_expected;
for (int s = 0; s < lstm.sequence_length(); s++) {
- bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
- bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
+ float* bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs();
+ float* bw_golden_end = bw_golden_start + lstm.num_bw_outputs();
bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end);
}
EXPECT_THAT(lstm.GetFwOutput(),
@@ -759,6 +1041,22 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
});
lstm.SetInputToInputWeights(
@@ -1343,10 +1641,6 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
0.065133, 0.024321, 0.038473, 0.062438
}};
- // Resetting cell_state and output_state
- lstm.ResetFwOutputAndCellStates();
- lstm.ResetBwOutputAndCellStates();
-
for (int i = 0; i < lstm.sequence_length(); i++) {
float* batch0_start = lstm_input[0] + i * lstm.num_inputs();
float* batch0_end = batch0_start + lstm.num_inputs();
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 517309a226..2f896c5289 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -19,10 +19,11 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -35,34 +36,79 @@ constexpr int kInputTensor = 0;
constexpr int kFwWeightsTensor = 1;
constexpr int kFwRecurrentWeightsTensor = 2;
constexpr int kFwBiasTensor = 3;
-constexpr int kBwWeightsTensor = 4;
-constexpr int kBwRecurrentWeightsTensor = 5;
-constexpr int kBwBiasTensor = 6;
-// State and output tensors.
-constexpr int kFwHiddenStateTensor = 0;
-constexpr int kFwOutputTensor = 1;
-constexpr int kBwHiddenStateTensor = 2;
-constexpr int kBwOutputTensor = 3;
+constexpr int kFwHiddenStateTensor = 4;
+constexpr int kBwWeightsTensor = 5;
+constexpr int kBwRecurrentWeightsTensor = 6;
+constexpr int kBwBiasTensor = 7;
+constexpr int kBwHiddenStateTensor = 8;
+// Auxiliary inputs.
+constexpr int kAuxInputTensor = 9; // Optional.
+constexpr int kFwAuxWeightsTensor = 10; // Optional.
+constexpr int kBwAuxWeightsTensor = 11; // Optional.
+// Output tensors.
+constexpr int kFwOutputTensor = 0;
+constexpr int kBwOutputTensor = 1;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kFwHiddenStateQuantized = 1,
+ kBwHiddenStateQuantized = 2,
+ kScalingFactors = 3,
+ kAuxInputQuantized = 4,
+ kNumTemporaryTensors = 5
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* scratch_tensor_index = new int;
+ context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
+ return scratch_tensor_index;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<int*>(buffer);
+}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 7);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* fw_input_weights =
- &context->tensors[node->inputs->data[kFwWeightsTensor]];
- TfLiteTensor* fw_recurrent_weights =
- &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
- TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
- TfLiteTensor* bw_input_weights =
- &context->tensors[node->inputs->data[kBwWeightsTensor]];
- TfLiteTensor* bw_recurrent_weights =
- &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
- TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* fw_input_weights =
+ GetInput(context, node, kFwWeightsTensor);
+ const TfLiteTensor* fw_recurrent_weights =
+ GetInput(context, node, kFwRecurrentWeightsTensor);
+ const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* fw_hidden_state =
+ GetInput(context, node, kFwHiddenStateTensor);
+ const TfLiteTensor* bw_input_weights =
+ GetInput(context, node, kBwWeightsTensor);
+ const TfLiteTensor* bw_recurrent_weights =
+ GetInput(context, node, kBwRecurrentWeightsTensor);
+ const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+ const TfLiteTensor* bw_hidden_state =
+ GetInput(context, node, kBwHiddenStateTensor);
+
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ const bool aux_inputs_all_or_none =
+ ((aux_input != nullptr) && (fw_aux_input_weights != nullptr) &&
+ (bw_aux_input_weights != nullptr)) ||
+ ((aux_input == nullptr) && (fw_aux_input_weights == nullptr) &&
+ (bw_aux_input_weights == nullptr));
+ TF_LITE_ENSURE(context, aux_inputs_all_or_none);
+ const bool has_aux_input = (aux_input != nullptr);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
@@ -75,32 +121,116 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
fw_bias->dims->data[0]);
TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1],
bw_bias->dims->data[0]);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(fw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, fw_hidden_state->dims->data[1], fw_num_units);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(bw_hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, bw_hidden_state->dims->data[1], bw_num_units);
- TfLiteTensor* fw_output =
- &context->tensors[node->outputs->data[kFwOutputTensor]];
- TfLiteTensor* bw_output =
- &context->tensors[node->outputs->data[kBwOutputTensor]];
+ if (has_aux_input) {
+ // Check that aux_input has the same dimensions (except last) as the input.
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[0], input->dims->data[0]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[1], input->dims->data[1]);
+ // Check that aux_input_weights has the same dimensions (except last) as
+ // the input_weights.
+ TF_LITE_ASSERT_EQ(fw_aux_input_weights->dims->data[0], fw_num_units);
+ TF_LITE_ASSERT_EQ(bw_aux_input_weights->dims->data[0], bw_num_units);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ fw_aux_input_weights->dims->data[1]);
+ TF_LITE_ASSERT_EQ(aux_input->dims->data[2],
+ bw_aux_input_weights->dims->data[1]);
+ }
- // Resize hidden states.
- TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- fw_hidden_state_size_array->data[0] = batch_size;
- fw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* fw_hidden_state =
- &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state,
- fw_hidden_state_size_array));
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
- TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2);
- bw_hidden_state_size_array->data[0] = batch_size;
- bw_hidden_state_size_array->data[1] = fw_num_units;
- TfLiteTensor* bw_hidden_state =
- &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state,
- bw_hidden_state_size_array));
+ const bool is_hybrid_op =
+ (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+ if (is_hybrid_op) {
+ int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+
+ TfLiteIntArrayFree(node->temporaries);
+ if (has_aux_input) {
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+ } else {
+ // No need to create a temporary tensor for the non-existent aux_input.
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors - 1);
+ }
+
+ node->temporaries->data[kInputQuantized] =
+ *scratch_tensor_index + kInputQuantized;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ node->temporaries->data[kFwHiddenStateQuantized] =
+ *scratch_tensor_index + kFwHiddenStateQuantized;
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ fw_hidden_state_quantized->type = kTfLiteUInt8;
+ fw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(fw_hidden_state_quantized->dims,
+ fw_hidden_state->dims)) {
+ TfLiteIntArray* fw_hidden_state_quantized_size =
+ TfLiteIntArrayCopy(fw_hidden_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, fw_hidden_state_quantized,
+ fw_hidden_state_quantized_size));
+ }
+
+ node->temporaries->data[kBwHiddenStateQuantized] =
+ *scratch_tensor_index + kBwHiddenStateQuantized;
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ bw_hidden_state_quantized->type = kTfLiteUInt8;
+ bw_hidden_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(bw_hidden_state_quantized->dims,
+ bw_hidden_state->dims)) {
+ TfLiteIntArray* bw_hidden_state_quantized_size =
+ TfLiteIntArrayCopy(bw_hidden_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_hidden_state_quantized,
+ bw_hidden_state_quantized_size));
+ }
- // Mark hidden states as a persistent tensor.
- fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
- bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent;
+ // Allocate temporary tensors to store scaling factors of quantization.
+ node->temporaries->data[kScalingFactors] =
+ *scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = batch_size;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+
+ if (has_aux_input) {
+ node->temporaries->data[kAuxInputQuantized] =
+ *scratch_tensor_index + kAuxInputQuantized;
+ TfLiteTensor* aux_input_quantized =
+ GetTemporary(context, node, kAuxInputQuantized);
+ aux_input_quantized->type = kTfLiteUInt8;
+ aux_input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(aux_input_quantized->dims, aux_input->dims)) {
+ TfLiteIntArray* aux_input_quantized_size =
+ TfLiteIntArrayCopy(aux_input->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, aux_input_quantized,
+ aux_input_quantized_size));
+ }
+ }
+ }
// Resize outputs.
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
@@ -119,33 +249,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
-
- TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]];
- TfLiteTensor* fw_input_weights =
- &context->tensors[node->inputs->data[kFwWeightsTensor]];
- TfLiteTensor* fw_recurrent_weights =
- &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]];
- TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]];
- TfLiteTensor* fw_hidden_state =
- &context->tensors[node->outputs->data[kFwHiddenStateTensor]];
- TfLiteTensor* fw_output =
- &context->tensors[node->outputs->data[kFwOutputTensor]];
-
- TfLiteTensor* bw_input_weights =
- &context->tensors[node->inputs->data[kBwWeightsTensor]];
- TfLiteTensor* bw_recurrent_weights =
- &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]];
- TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]];
- TfLiteTensor* bw_hidden_state =
- &context->tensors[node->outputs->data[kBwHiddenStateTensor]];
- TfLiteTensor* bw_output =
- &context->tensors[node->outputs->data[kBwOutputTensor]];
-
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
+ const TfLiteTensor* bw_aux_input_weights,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
+ TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
const int fw_num_units = fw_input_weights->dims->data[0];
const float* fw_bias_ptr = fw_bias->data.f;
@@ -157,6 +274,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const float* bw_input_weights_ptr = bw_input_weights->data.f;
const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f;
+ const float* fw_aux_input_weights_ptr = (fw_aux_input_weights != nullptr)
+ ? fw_aux_input_weights->data.f
+ : nullptr;
+ const float* bw_aux_input_weights_ptr = (bw_aux_input_weights != nullptr)
+ ? bw_aux_input_weights->data.f
+ : nullptr;
+
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
@@ -164,12 +288,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
float* output_ptr_batch =
fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
kernel_utils::RnnBatchStep(
- input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr,
- fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1,
+ input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+ fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+ input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
@@ -178,24 +307,208 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
float* output_ptr_batch =
bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
kernel_utils::RnnBatchStep(
- input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr,
- bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1,
+ input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+ bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+ input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
}
return kTfLiteOk;
}
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* fw_input_weights,
+ const TfLiteTensor* fw_recurrent_weights, const TfLiteTensor* fw_bias,
+ const TfLiteTensor* bw_input_weights,
+ const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
+ const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
+ const TfLiteTensor* aux_bw_input_weights,
+ const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
+ TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+ const int batch_size = input->dims->data[0];
+ const int max_time = input->dims->data[1];
+ const int input_size = input->dims->data[2];
+ const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
+
+ const int fw_num_units = fw_input_weights->dims->data[0];
+ const float* fw_bias_ptr = fw_bias->data.f;
+ const int8_t* fw_input_weights_ptr =
+ reinterpret_cast<const int8_t*>(fw_input_weights->data.uint8);
+ float fw_input_weights_scale = fw_input_weights->params.scale;
+ const int8_t* fw_recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(fw_recurrent_weights->data.uint8);
+ float fw_recurrent_weights_scale = fw_recurrent_weights->params.scale;
+
+ const int bw_num_units = bw_input_weights->dims->data[0];
+ const float* bw_bias_ptr = bw_bias->data.f;
+ const int8_t* bw_input_weights_ptr =
+ reinterpret_cast<const int8_t*>(bw_input_weights->data.uint8);
+ float bw_input_weights_scale = bw_input_weights->params.scale;
+ const int8_t* bw_recurrent_weights_ptr =
+ reinterpret_cast<const int8_t*>(bw_recurrent_weights->data.uint8);
+ float bw_recurrent_weights_scale = bw_recurrent_weights->params.scale;
+
+ // Set the auxiliary pointers and scales if needed.
+ int8_t* aux_fw_input_weights_ptr = nullptr;
+ float aux_fw_input_weights_scale = 0.0f;
+ int8_t* aux_bw_input_weights_ptr = nullptr;
+ float aux_bw_input_weights_scale = 0.0f;
+ int8_t* aux_quantized_input_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_fw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_fw_input_weights->data.uint8);
+ aux_fw_input_weights_scale = aux_fw_input_weights->params.scale;
+ aux_bw_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_bw_input_weights->data.uint8);
+ aux_bw_input_weights_scale = aux_bw_input_weights->params.scale;
+ aux_quantized_input_ptr = reinterpret_cast<int8_t*>(aux_input_quantized);
+ }
+
+ // Initialize temporary storage for quantized values.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* fw_quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(fw_hidden_state_quantized->data.uint8);
+ int8_t* bw_quantized_hidden_state_ptr =
+ reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch =
+ fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch =
+ bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params =
+ reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* fw_input_weights =
+ GetInput(context, node, kFwWeightsTensor);
+ const TfLiteTensor* fw_recurrent_weights =
+ GetInput(context, node, kFwRecurrentWeightsTensor);
+ const TfLiteTensor* fw_bias = GetInput(context, node, kFwBiasTensor);
+ const TfLiteTensor* bw_input_weights =
+ GetInput(context, node, kBwWeightsTensor);
+ const TfLiteTensor* bw_recurrent_weights =
+ GetInput(context, node, kBwRecurrentWeightsTensor);
+ const TfLiteTensor* bw_bias = GetInput(context, node, kBwBiasTensor);
+
+ // Get auxiliary inputs.
+ const TfLiteTensor* aux_input =
+ GetOptionalInputTensor(context, node, kAuxInputTensor);
+ const TfLiteTensor* fw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kFwAuxWeightsTensor);
+ const TfLiteTensor* bw_aux_input_weights =
+ GetOptionalInputTensor(context, node, kBwAuxWeightsTensor);
+
+ TfLiteTensor* fw_hidden_state =
+ GetVariableInput(context, node, kFwHiddenStateTensor);
+ TfLiteTensor* bw_hidden_state =
+ GetVariableInput(context, node, kBwHiddenStateTensor);
+
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+
+ switch (fw_input_weights->type) {
+ case kTfLiteFloat32:
+ return EvalFloat(input, fw_input_weights, fw_recurrent_weights, fw_bias,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, fw_hidden_state, fw_output, bw_hidden_state,
+ bw_output);
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, kInputQuantized);
+ TfLiteTensor* fw_hidden_state_quantized =
+ GetTemporary(context, node, kFwHiddenStateQuantized);
+ TfLiteTensor* bw_hidden_state_quantized =
+ GetTemporary(context, node, kBwHiddenStateQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, kScalingFactors);
+ TfLiteTensor* aux_input_quantized =
+ (aux_input != nullptr)
+ ? GetTemporary(context, node, kAuxInputQuantized)
+ : nullptr;
+
+ return EvalHybrid(input, fw_input_weights, fw_recurrent_weights, fw_bias,
+ bw_input_weights, bw_recurrent_weights, bw_bias,
+ aux_input, fw_aux_input_weights, bw_aux_input_weights,
+ params, scaling_factors, input_quantized,
+ aux_input_quantized, fw_hidden_state_quantized,
+ fw_hidden_state, fw_output, bw_hidden_state_quantized,
+ bw_hidden_state, bw_output);
+ }
+ default:
+ context->ReportError(context, "Type not currently supported.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
} // namespace bidirectional_sequence_rnn
TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- bidirectional_sequence_rnn::Prepare,
- bidirectional_sequence_rnn::Eval};
+ static TfLiteRegistration r = {
+ bidirectional_sequence_rnn::Init, bidirectional_sequence_rnn::Free,
+ bidirectional_sequence_rnn::Prepare, bidirectional_sequence_rnn::Eval};
return &r;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 911b108eaa..3e34ba6196 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -664,13 +664,19 @@ class BidirectionalRNNOpModel : public SingleOpModel {
fw_weights_ = AddInput(TensorType_FLOAT32);
fw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
fw_bias_ = AddInput(TensorType_FLOAT32);
- fw_hidden_state_ = AddOutput(TensorType_FLOAT32);
- fw_output_ = AddOutput(TensorType_FLOAT32);
+ fw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
bw_weights_ = AddInput(TensorType_FLOAT32);
bw_recurrent_weights_ = AddInput(TensorType_FLOAT32);
bw_bias_ = AddInput(TensorType_FLOAT32);
- bw_hidden_state_ = AddOutput(TensorType_FLOAT32);
+ bw_hidden_state_ = AddInput(TensorType_FLOAT32, true);
+
+ aux_input_ = AddNullInput();
+ aux_fw_weights_ = AddNullInput();
+ aux_bw_weights_ = AddNullInput();
+
+ fw_output_ = AddOutput(TensorType_FLOAT32);
bw_output_ = AddOutput(TensorType_FLOAT32);
+
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
CreateSequenceRNNOptions(builder_, /*time_major=*/false,
@@ -681,9 +687,14 @@ class BidirectionalRNNOpModel : public SingleOpModel {
{fw_units_, input_size_}, // fw_weights
{fw_units_, fw_units_}, // fw_recurrent_weights
{fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
{bw_units_, input_size_}, // bw_weights
{bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_} // bw_bias
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
@@ -719,19 +730,6 @@ class BidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenStates() {
- const int fw_zero_buffer_size = fw_units_ * batches_;
- std::unique_ptr<float[]> fw_zero_buffer(new float[fw_zero_buffer_size]);
- memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float));
- PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(),
- fw_zero_buffer.get() + fw_zero_buffer_size);
- const int bw_zero_buffer_size = bw_units_ * batches_;
- std::unique_ptr<float[]> bw_zero_buffer(new float[bw_zero_buffer_size]);
- memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float));
- PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(),
- bw_zero_buffer.get() + bw_zero_buffer_size);
- }
-
std::vector<float> GetFwOutput() { return ExtractVector<float>(fw_output_); }
std::vector<float> GetBwOutput() { return ExtractVector<float>(bw_output_); }
@@ -753,6 +751,9 @@ class BidirectionalRNNOpModel : public SingleOpModel {
int bw_bias_;
int bw_hidden_state_;
int bw_output_;
+ int aux_input_;
+ int aux_fw_weights_;
+ int aux_bw_weights_;
int batches_;
int sequence_len_;
@@ -774,7 +775,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
float* batch_end = batch_start + input_sequence_size;
@@ -813,8 +813,6 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
// Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the
// following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1].
for (int i = 0; i < rnn.sequence_len(); i++) {
@@ -880,8 +878,6 @@ TEST(BidirectionalRNNOpTest, EndToEndTest) {
rnn.SetFwRecurrentWeights(recurrent_weights);
rnn.SetBwRecurrentWeights(recurrent_weights);
- rnn.ResetHiddenStates();
-
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
const int output_sequence_size = output_size * rnn.sequence_len();
const int num_examples = 64;
diff --git a/tensorflow/contrib/lite/kernels/cast.cc b/tensorflow/contrib/lite/kernels/cast.cc
index 8dd48af57f..a7972140ac 100644
--- a/tensorflow/contrib/lite/kernels/cast.cc
+++ b/tensorflow/contrib/lite/kernels/cast.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <algorithm>
#include <complex>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index 8b4d778332..4cd96348a2 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc
index 605a20ac3e..25ea556d5a 100644
--- a/tensorflow/contrib/lite/kernels/concatenation.cc
+++ b/tensorflow/contrib/lite/kernels/concatenation.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 50fe5c2e04..ab6bdaecaa 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/kernels/padding.h"
@@ -60,6 +61,8 @@ struct OpData {
// memory buffers.
int im2col_id = kTensorNotAllocated;
int hwcn_weights_id = kTensorNotAllocated;
+ int input_quantized_id = kTensorNotAllocated;
+ int scaling_factors_id = kTensorNotAllocated;
TfLitePaddingValues padding;
// The scaling factor from input to output (aka the 'real multiplier') can
@@ -74,6 +77,8 @@ struct OpData {
// of the allocated temporaries.
int32_t im2col_index;
int32_t hwcn_weights_index;
+ int32_t input_quantized_index;
+ int32_t scaling_factors_index;
bool need_hwcn_weights;
bool have_weights_been_transposed;
bool need_im2col;
@@ -125,6 +130,9 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
TfLiteTensor* filter = &context->tensors[node->inputs->data[1]];
+ const bool is_hybrid =
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8);
+
int filter_width = filter->dims->data[2];
int filter_height = filter->dims->data[1];
@@ -145,8 +153,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
// buffer to store the results.
// This path is only used for float processing, so only create the buffer if
// we're running with that data type.
- data->need_hwcn_weights =
- (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel);
+ data->need_hwcn_weights = (input->type == kTfLiteFloat32 &&
+ data->run_multithreaded_kernel && !is_hybrid);
int temporaries_count = 0;
if (data->need_im2col) {
@@ -164,6 +172,25 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
++temporaries_count;
}
+ if (is_hybrid) {
+ // Allocate tensor to store the on-the-fly quantized inputs.
+ data->input_quantized_index = temporaries_count;
+ if (data->input_quantized_id == kTensorNotAllocated) {
+ TF_LITE_ENSURE_OK(
+ context, context->AddTensors(context, 1, &data->input_quantized_id));
+ }
+ ++temporaries_count;
+
+ // Allocate tensor to store the quantization params computed during
+ // on-the-fly input quantization.
+ data->scaling_factors_index = temporaries_count;
+ if (data->scaling_factors_id == kTensorNotAllocated) {
+ TF_LITE_ENSURE_OK(
+ context, context->AddTensors(context, 1, &data->scaling_factors_id));
+ }
+ ++temporaries_count;
+ }
+
TfLiteIntArrayFree(node->temporaries);
node->temporaries = TfLiteIntArrayCreate(temporaries_count);
@@ -174,10 +201,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
- data->run_multithreaded_kernel = context->recommended_num_threads != 1;
-
- TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
-
bool has_bias = node->inputs->size == 3;
// Check number of inputs/outputs
TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
@@ -193,11 +216,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->dims->data[3], filter->dims->data[3]);
// Check types. (We assume that UINT8 refers to quantized tensors)
- TfLiteType data_type = input->type;
+ TfLiteType input_type = input->type;
TF_LITE_ENSURE(context,
- data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
- TF_LITE_ENSURE_EQ(context, output->type, data_type);
- TF_LITE_ENSURE_EQ(context, filter->type, data_type);
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8);
+ TF_LITE_ENSURE_EQ(context, output->type, input_type);
TfLiteTensor* bias = nullptr;
@@ -207,15 +229,27 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (has_bias) {
bias = &context->tensors[node->inputs->data[2]];
- if (data_type == kTfLiteUInt8) {
+ if (input_type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else {
- TF_LITE_ENSURE_EQ(context, bias->type, data_type);
+ TF_LITE_ENSURE_EQ(context, bias->type, input_type);
}
TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
}
+ const bool is_hybrid =
+ (input->type == kTfLiteFloat32 && filter->type == kTfLiteUInt8);
+
+ data->run_multithreaded_kernel = context->recommended_num_threads != 1;
+ // Hybrid kernels don't support multithreading yet.
+ if (is_hybrid) {
+ data->run_multithreaded_kernel = false;
+ }
+
+ TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
+
+ int channels_in = filter->dims->data[3];
int channels_out = filter->dims->data[0];
int width = input->dims->data[2];
int height = input->dims->data[1];
@@ -250,9 +284,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, has_bias);
- // Note that quantized inference requires that all tensors have their
+ // Note that full fixed-point inference requires that all tensors have their
// parameters set. This is usually done during quantized training.
- if (data_type != kTfLiteFloat32) {
+ if (input_type != kTfLiteFloat32) {
double real_multiplier = 0.0;
TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
context, input, filter, bias, output, &real_multiplier));
@@ -287,7 +321,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* im2col =
&context->tensors[node->temporaries->data[data->im2col_index]];
- im2col->type = data_type;
+ im2col->type = input->type;
+ if (is_hybrid) {
+ im2col->type = kTfLiteUInt8;
+ }
im2col->allocation_type = kTfLiteArenaRw;
auto im2col_status = context->ResizeTensor(context, im2col, im2col_size);
if (im2col_status != kTfLiteOk) return im2col_status;
@@ -307,7 +344,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* hwcn_weights =
&context->tensors[node->temporaries->data[data->hwcn_weights_index]];
- hwcn_weights->type = data_type;
+ hwcn_weights->type = input_type;
hwcn_weights->allocation_type = kTfLiteArenaRwPersistent;
auto hwcn_weights_status =
@@ -319,6 +356,36 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
data->have_weights_been_transposed = false;
}
+ if (is_hybrid) {
+ node->temporaries->data[data->input_quantized_index] =
+ data->input_quantized_id;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, data->input_quantized_index);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ node->temporaries->data[data->scaling_factors_index] =
+ data->scaling_factors_id;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, data->scaling_factors_index);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ // Only one scale factor per batch is typically necessary. See optimized
+ // implementation for why we need to allocate for the height of the inputs
+ // flattened to 2D.
+ scaling_factors_size->data[0] = NumElements(input) / channels_in;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ }
+
return kTfLiteOk;
}
@@ -456,6 +523,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
+void EvalHybrid(TfLiteContext* context, TfLiteNode* node,
+ TfLiteConvParams* params, OpData* data, TfLiteTensor* input,
+ TfLiteTensor* filter, TfLiteTensor* bias, TfLiteTensor* im2col,
+ TfLiteTensor* hwcn_weights, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ const int input_size = NumElements(input) / SizeOfDimension(input, 0);
+ const int batch_size = SizeOfDimension(input, 0);
+
+ const TfLiteTensor* input_quantized =
+ GetTemporary(context, node, data->input_quantized_index);
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ float* scaling_factors_ptr =
+ GetTemporary(context, node, data->scaling_factors_index)->data.f;
+
+ // Per-batch input quantization for higher accuracy.
+ for (int b = 0; b < batch_size; ++b) {
+ float unused_min, unused_max;
+ const int offset = b * input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ input->data.f + offset, input_size, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= filter->params.scale;
+ }
+
+ int8_t* im2col_ptr = nullptr;
+ if (im2col != nullptr) {
+ im2col_ptr = reinterpret_cast<int8_t*>(im2col->data.uint8);
+ }
+ int8_t* filter_ptr = reinterpret_cast<int8_t*>(filter->data.uint8);
+
+ switch (kernel_type) {
+ case kReference:
+ case kGenericOptimized:
+ case kMultithreadOptimized:
+ case kCblasOptimized:
+ // There is only one implementation for hybrid kernel. Note
+ // this does not make use of gemmlowp nor supports multithreading.
+ optimized_ops::HybridConv(
+ quantized_input_ptr_batch, GetTensorDims(input), filter_ptr,
+ GetTensorDims(filter), GetTensorData<float>(bias),
+ GetTensorDims(bias), params->stride_width, params->stride_height,
+ data->padding.width, data->padding.height, scaling_factors_ptr,
+ output_activation_min, output_activation_max,
+ GetTensorData<float>(output), GetTensorDims(output), im2col_ptr,
+ GetTensorDims(im2col));
+ break;
+ }
+}
+
+template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
@@ -484,7 +605,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// separate ops to avoid dispatch overhead here.
switch (input->type) { // Already know in/outtypes are same.
case kTfLiteFloat32:
- if (data->run_multithreaded_kernel) {
+ if (filter->type == kTfLiteUInt8) {
+ EvalHybrid<kernel_type>(context, node, params, data, input, filter,
+ bias, im2col, hwcn_weights, output);
+ } else if (data->run_multithreaded_kernel) {
EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
im2col, hwcn_weights, output);
} else {
diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc
index 98152043c9..411615aa62 100644
--- a/tensorflow/contrib/lite/kernels/conv_test.cc
+++ b/tensorflow/contrib/lite/kernels/conv_test.cc
@@ -142,6 +142,104 @@ TEST_P(ConvolutionOpTest, SimpleTestFloat32) {
}));
}
+// This test's output is equivalent to the SimpleTestFloat32
+// because we break each input into two channels, each with half of the value,
+// while keeping the filters for each channel equivalent.
+//
+// 2 * (A/2) * B = A * B, where the left side is this new test.
+TEST_P(ConvolutionOpTest, SimpleTestFloat32WithChannels) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {3, 2, 2, 2}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+ m.SetFilter({
+ 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter
+ -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter
+ -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ }));
+}
+
+TEST_P(ConvolutionOpTest, PointwiseFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {1, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ // First batch
+ 1.5, 1.5, 1.5, 1.5, // row = 1
+ 3., 3., 3., 3., // row = 2
+ // Second batch
+ 1.5, 3., 4.5, 6., // row = 1
+ 1.5, 3., 4.5, 6., // row = 2
+ }));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterFloat32) {
+ ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_FLOAT32, {2, 1, 1, 2}},
+ {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ }));
+}
+
TEST_P(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) {
ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}},
{TensorType_FLOAT32, {1, 2, 2, 1}},
@@ -624,6 +722,192 @@ TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithDilation) {
ElementsAreArray({5, 5, 5, 5, 5, 5, 5, 5, 5}));
}
+class HybridConvolutionOpModel : public BaseConvolutionOpModel {
+ public:
+ using BaseConvolutionOpModel::BaseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(filter_, f);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ PopulateTensor(bias_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST_P(ConvolutionOpTest, SimpleTestHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_UINT8, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 1, // row = 1
+ 2, 2, 2, 2, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ -1, 1, -1, 1, // second 2x2 filter
+ -1, -1, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ // Example: we get 17.1577 instead of 17.
+ //
+ // Second batch:
+ // 1 2 3 4 -> 32 64 95 127 with scale factor 127/4.
+ // 1 2 3 4 32 64 95 127
+ //
+ // First filter:
+ // 1 2 -> 32 64 with scale factor of 127/4.
+ // 3 4 95 127
+ //
+ // The left half of the input gives us 16288. Multiply by (4/127)^2 for
+ // dequantization and adding 1 for the bias gives us the result. and adding
+ // the bias gives us the result.
+ //
+ // The optimized kernel converts the input into this matrix via Im2Col
+ //
+ // 1 1 2 2
+ // 1 1 2 2
+ // 1 2 1 2
+ // 3 4 3 4
+ //
+ // and multiplies it with the filter directly.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 0.16)));
+}
+
+// This test's output is equivalent to the SimpleTestHybrid
+// because we break each input into two channels, each with half of the value,
+// while keeping the filters for each channel equivalent.
+//
+// 2 * (A/2) * B = A * B, where the left side is this new test.
+TEST_P(ConvolutionOpTest, SimpleTestHybridWithChannels) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {3, 2, 2, 2}}, {TensorType_FLOAT32, {}});
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+ m.SetFilter({
+ 1, 1, 2, 2, 3, 3, 4, 4, // first 2x2 filter
+ -1, -1, 1, 1, -1, -1, 1, 1, // second 2x2 filter
+ -1, -1, -1, -1, 1, 1, 1, 1 // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 0.16)));
+}
+
+TEST_P(ConvolutionOpTest, PointwiseHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {1, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ // Example: we get 3.03156 instead of 3.
+ //
+ // Second batch:
+ // 0.5 0.5 1 1 1.5 1.5 2 2 -> 32 32 64 64 95 95 127 127 with scale factor
+ // 127/2. We care about the two 64's.
+ //
+ // Filter:
+ // 64 127 with scale factor of 127/2.
+ //
+ // (64 * 64 + 64 * 127) * (2/127)^2 gives us the expected result.
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 1.5, 1.5, 1.5, // first batch, row = 1
+ 3., 3., 3., 3., // first batch, row = 2
+ 1.5, 3., 4.5, 6., // second batch, row = 1
+ 1.5, 3., 4.5, 6., // second batch, row = 2
+ },
+ 0.0316)));
+}
+
+// TODO(alanchiao): this passes locally, but fails on continuous build system.
+// Re-enable when root cause found.
+TEST_P(ConvolutionOpTest, DISABLED_PointwiseMultifilterHybrid) {
+ HybridConvolutionOpModel m(
+ GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 2}},
+ {TensorType_UINT8, {2, 1, 1, 2}}, {TensorType_FLOAT32, {}}, 1, 1);
+
+ m.SetInput({
+ // First batch
+ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, // row = 1
+ 1, 1, 1, 1, 1, 1, 1, 1, // row = 2
+ // Second batch
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2, // row = 1
+ 0.5, 0.5, 1, 1, 1.5, 1.5, 2, 2 // row = 2
+ });
+
+ m.SetFilter({
+ 1, 2, // first filter
+ 2, 3, // second filter
+ });
+ m.SetBias({0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 1.5, 2.5, 3., 5., 3.,
+ 5., 3., 5., 3., 5., 1.5, 2.5, 3., 5., 4.5, 7.5,
+ 6., 10., 1.5, 2.5, 3., 5., 4.5, 7.5, 6., 10.,
+ },
+ 0.0474)));
+}
+
INSTANTIATE_TEST_CASE_P(
ConvolutionOpTest, ConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
index 21518156b8..347515f289 100644
--- a/tensorflow/contrib/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/contrib/lite/kernels/depthwise_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc
index 2b0f04489a..3a08f48b00 100644
--- a/tensorflow/contrib/lite/kernels/dequantize.cc
+++ b/tensorflow/contrib/lite/kernels/dequantize.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess.cc b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
index d7bde0ff79..d2906632d7 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess.cc
@@ -15,9 +15,9 @@ limitations under the License.
#include <string.h>
#include <numeric>
#include <vector>
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
index 4e0f8484a3..94c91a6bd6 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc
index d7420ddd8e..7945c095b1 100644
--- a/tensorflow/contrib/lite/kernels/div.cc
+++ b/tensorflow/contrib/lite/kernels/div.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index ec77856b10..feb1543f7b 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -15,10 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_EIGEN_SUPPORT_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace EigenForTFLite {
-class ThreadPoolDevice;
+struct ThreadPoolDevice;
}
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/elementwise.cc b/tensorflow/contrib/lite/kernels/elementwise.cc
index e19779ea59..04995d70dd 100644
--- a/tensorflow/contrib/lite/kernels/elementwise.cc
+++ b/tensorflow/contrib/lite/kernels/elementwise.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cmath>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup.cc b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
index b2dff87e62..fe33f98eb0 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup.cc
@@ -37,8 +37,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
index d3be36993c..aa75b03990 100644
--- a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
+++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse.cc
@@ -65,8 +65,8 @@ limitations under the License.
#include <algorithm>
#include <cmath>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
index ce03cdfe26..673e7be90a 100644
--- a/tensorflow/contrib/lite/kernels/exp.cc
+++ b/tensorflow/contrib/lite/kernels/exp.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims.cc b/tensorflow/contrib/lite/kernels/expand_dims.cc
index ed33012864..fa1140b19c 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims.cc
@@ -15,8 +15,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/expand_dims_test.cc b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
index 50dc860e5a..a3bc1813db 100644
--- a/tensorflow/contrib/lite/kernels/expand_dims_test.cc
+++ b/tensorflow/contrib/lite/kernels/expand_dims_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/fake_quant.cc b/tensorflow/contrib/lite/kernels/fake_quant.cc
index 0ef1a50b30..f9bc3747cb 100644
--- a/tensorflow/contrib/lite/kernels/fake_quant.cc
+++ b/tensorflow/contrib/lite/kernels/fake_quant.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
index 697b777693..59ff77f35b 100644
--- a/tensorflow/contrib/lite/kernels/floor.cc
+++ b/tensorflow/contrib/lite/kernels/floor.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -41,8 +41,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
- GetTensorData<float>(output), GetTensorDims(output));
+ optimized_ops::Floor(GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+
return kTfLiteOk;
}
} // namespace floor
diff --git a/tensorflow/contrib/lite/kernels/floor_div.cc b/tensorflow/contrib/lite/kernels/floor_div.cc
new file mode 100644
index 0000000000..5d62cd2755
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_div.cc
@@ -0,0 +1,146 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace floor_div {
+namespace {
+
+// Input/output tensor index.
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+// Op data for floor_div op.
+struct OpData {
+ bool requires_broadcast;
+};
+
+template <typename T>
+T FloorDiv(T input1, T input2) {
+ return std::floor(std::divides<double>()(static_cast<double>(input1),
+ static_cast<double>(input2)));
+}
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->requires_broadcast = false;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ // Reinterprete the opaque data provided by user.
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+
+ const TfLiteType type = input1->type;
+ if (type != kTfLiteInt32) {
+ context->ReportError(context, "Currently floor_div only supports int32.");
+ return kTfLiteError;
+ }
+ output->type = type;
+
+ data->requires_broadcast = !HaveSameShapes(input1, input2);
+
+ TfLiteIntArray* output_size = nullptr;
+ if (data->requires_broadcast) {
+ TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+ context, input1, input2, &output_size));
+ } else {
+ output_size = TfLiteIntArrayCopy(input1->dims);
+ }
+
+ return context->ResizeTensor(context, output, output_size);
+}
+
+template <typename T>
+TfLiteStatus EvalImpl(TfLiteContext* context, bool requires_broadcast,
+ const TfLiteTensor* input1, const TfLiteTensor* input2,
+ TfLiteTensor* output) {
+ const T* denominator_data = GetTensorData<T>(input2);
+
+ // Validate the denominator.
+ for (int i = 0; i < NumElements(input2); ++i) {
+ if (std::equal_to<T>()(denominator_data[i], 0)) {
+ context->ReportError(context, "Division by 0");
+ return kTfLiteError;
+ }
+ }
+ if (requires_broadcast) {
+ reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), denominator_data, GetTensorShape(output),
+ GetTensorData<T>(output), FloorDiv<T>);
+ } else {
+ reference_ops::BinaryFunction<T, T, T>(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output), FloorDiv<T>);
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ OpData* data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input1->type) {
+ case kTfLiteInt32: {
+ return EvalImpl<int32_t>(context, data->requires_broadcast, input1,
+ input2, output);
+ }
+ default: {
+ context->ReportError(context, "Currently floor_div only supports int32.");
+ return kTfLiteError;
+ }
+ }
+}
+
+} // namespace
+} // namespace floor_div
+
+TfLiteRegistration* Register_FLOOR_DIV() {
+ // Init, Free, Prepare, Eval are satisfying the Interface required by
+ // TfLiteRegistration.
+ static TfLiteRegistration r = {floor_div::Init, floor_div::Free,
+ floor_div::Prepare, floor_div::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/floor_div_test.cc b/tensorflow/contrib/lite/kernels/floor_div_test.cc
new file mode 100644
index 0000000000..eea69b61ac
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/floor_div_test.cc
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class FloorDivModel : public SingleOpModel {
+ public:
+ FloorDivModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output) {
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_FLOOR_DIV, BuiltinOptions_FloorDivOptions,
+ CreateFloorDivOptions(builder_).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<T> GetOutput() { return ExtractVector<T>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+TEST(PowOpModel, Simple) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, 9, 11, 3});
+ model.PopulateTensor<int32_t>(model.input2(), {2, 2, 3, 4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(5, 4, 3, 0));
+}
+
+TEST(PowOpModel, NegativeValue) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, -9, -11, 7});
+ model.PopulateTensor<int32_t>(model.input2(), {2, 2, -3, -4});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(5, -5, 3, -2));
+}
+
+TEST(PowOpModel, BroadcastFloorDiv) {
+ FloorDivModel<int32_t> model({TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1}}, {TensorType_INT32, {}});
+ model.PopulateTensor<int32_t>(model.input1(), {10, -9, -11, 7});
+ model.PopulateTensor<int32_t>(model.input2(), {-3});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutputShape(), ElementsAre(1, 2, 2, 1));
+ EXPECT_THAT(model.GetOutput(), ElementsAre(-4, 3, 3, -3));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index eaf5a67d67..7a71fcc219 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc
index 2b2a9e6620..badd2de11a 100644
--- a/tensorflow/contrib/lite/kernels/gather.cc
+++ b/tensorflow/contrib/lite/kernels/gather.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc
index 1d4292955c..1b48884e09 100644
--- a/tensorflow/contrib/lite/kernels/gather_test.cc
+++ b/tensorflow/contrib/lite/kernels/gather_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index 37af772c68..43cd2b3055 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_
#include "public/gemmlowp.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
namespace gemm_support {
diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
index f37c66acb3..c0b3c3c0c5 100644
--- a/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
+++ b/tensorflow/contrib/lite/kernels/hashtable_lookup.cc
@@ -39,8 +39,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index 96798c900e..a6fd4ac2dd 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -160,9 +160,10 @@ cc_library(
":types",
":reference_base",
":round",
+ ":tensor_utils",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -191,12 +192,13 @@ cc_library(
deps = [
":quantization_util",
":strided_slice_logic",
+ ":tensor_utils",
":types",
":legacy_reference_base",
":round",
"//third_party/eigen3",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -218,13 +220,15 @@ cc_library(
"optimized/eigen_spatial_convolutions.h",
"optimized/eigen_tensor_reduced_instantiations_oss.h",
"optimized/multithreaded_conv.h",
+ # FIXME(petewarden) - This should be removed, since it's a header from the
+ # :tensor dependency below.
"tensor.h",
],
deps = [
":optimized_base",
+ ":tensor",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//third_party/eigen3",
],
)
@@ -234,7 +238,7 @@ cc_test(
srcs = ["tensor_test.cc"],
tags = ["no_oss"],
deps = [
- ":reference",
+ ":tensor",
"@com_google_googletest//:gtest",
],
)
@@ -294,7 +298,7 @@ cc_library(
":strided_slice_logic",
":types",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -324,7 +328,7 @@ cc_library(
":strided_slice_logic",
":types",
"@gemmlowp",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
] + select({
":haswell": tflite_deps_intel,
":ios_x86_64": tflite_deps_intel,
@@ -339,11 +343,27 @@ cc_library(
)
cc_library(
+ name = "tensor",
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
+ deps = [
+ ":types",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ ],
+)
+
+# Deprecated version of :tensor, kept for backwards compatibility.
+cc_library(
name = "reference",
- hdrs = ["tensor.h"],
+ hdrs = [
+ "tensor.h",
+ "tensor_ctypes.h",
+ ],
deps = [
":types",
- "//tensorflow/contrib/lite:context",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -357,7 +377,7 @@ cc_library(
],
deps = [
":round",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
],
@@ -382,7 +402,7 @@ cc_library(
":cpu_check",
":round",
":types",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:activation_functor",
"//tensorflow/contrib/lite/kernels:op_macros",
"@arm_neon_2_x86_sse",
@@ -396,7 +416,7 @@ cc_library(
hdrs = ["kernel_utils.h"],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
],
)
@@ -439,7 +459,7 @@ cc_library(
copts = NEON_FLAGS_IF_APPLICABLE,
deps = [
"//tensorflow/contrib/lite/kernels:activation_functor",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"@arm_neon_2_x86_sse",
"@gemmlowp",
] + select({
@@ -515,7 +535,7 @@ cc_test(
],
deps = [
":tensor_utils",
- "//tensorflow/contrib/lite:builtin_op_data",
+ "//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:test_util",
"@com_google_googletest//:gtest_main",
],
diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h
index eb4d0108bd..e67fee11b8 100644
--- a/tensorflow/contrib/lite/kernels/internal/common.h
+++ b/tensorflow/contrib/lite/kernels/internal/common.h
@@ -45,7 +45,7 @@ limitations under the License.
#endif
#endif
-#include "public/gemmlowp.h"
+#include "fixedpoint/fixedpoint.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 200f2f1515..56e9367878 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -14,8 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
-#include <algorithm>
-
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
namespace tflite {
@@ -26,6 +24,21 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
int input_size, int num_units, int batch_size,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr, recurrent_weights_ptr,
+ bias_ptr, input_size, /*aux_input_size=*/0, num_units,
+ batch_size, activation, hidden_state_ptr_batch,
+ output_ptr_batch);
+}
+
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -33,6 +46,12 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size,
output_ptr_batch, /*result_stride=*/1);
+ // Output += aux_input * aux_input_weights (if they are not empty).
+ if (aux_input_size > 0) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size, aux_input_ptr_batch,
+ batch_size, output_ptr_batch, /*result_stride=*/1);
+ }
// Output += recurrent_weights * hidden_state
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch,
@@ -54,6 +73,28 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
int8_t* quantized_hidden_state_ptr_batch,
float* scaling_factors, float* hidden_state_ptr_batch,
float* output_ptr_batch) {
+ RnnBatchStep(input_ptr_batch, input_weights_ptr, input_weights_scale,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_weights_ptr=*/nullptr,
+ /*aux_input_weights_scale=*/0.0f, recurrent_weights_ptr,
+ recurrent_weights_scale, bias_ptr, input_size,
+ /*aux_input_size=*/0, num_units, batch_size, activation,
+ quantized_input_ptr_batch,
+ /*aux_quantized_input_ptr_batch=*/nullptr,
+ quantized_hidden_state_ptr_batch, scaling_factors,
+ hidden_state_ptr_batch, output_ptr_batch);
+}
+
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch) {
// Output = bias
tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size,
output_ptr_batch);
@@ -80,6 +121,26 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
scaling_factors, batch_size, output_ptr_batch, /*result_stride=*/1);
}
+ if (aux_input_ptr_batch &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch,
+ batch_size * aux_input_size)) {
+ float unused_min, unused_max;
+ for (int b = 0; b < batch_size; ++b) {
+ const int offset = b * aux_input_size;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, aux_input_size,
+ aux_quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ scaling_factors[b] *= aux_input_weights_scale;
+ }
+
+ // Output += aux_input * aux_input_weights
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_weights_ptr, num_units, aux_input_size,
+ aux_quantized_input_ptr_batch, scaling_factors, batch_size,
+ output_ptr_batch, /*result_stride=*/1);
+ }
+
// Save quantization and matmul computation for all zero input.
if (!tensor_utils::IsZeroVector(hidden_state_ptr_batch,
batch_size * num_units)) {
@@ -127,6 +188,47 @@ void LstmStep(
float* cell_state_ptr, float* input_gate_scratch,
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch) {
+ LstmStepWithAuxInput(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_to_input_weights_ptr=*/nullptr,
+ /*aux_input_to_forget_weights_ptr=*/nullptr,
+ /*aux_input_to_cell_weights_ptr=*/nullptr,
+ /*aux_input_to_output_weights_ptr=*/nullptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
+ n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
+}
+
+void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch) {
// Since we have already checked that weights are all there or none, we can
// check the existense of only one to the get the condition.
const bool use_cifg = (input_to_input_weights_ptr == nullptr);
@@ -160,6 +262,26 @@ void LstmStep(
input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
output_gate_scratch, /*result_stride=*/1);
+ // If auxiliary input is available then compute aux_input_weight * aux_input
+ if (aux_input_ptr_batch != nullptr) {
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
+ }
+
// For each batch and cell: compute recurrent_weight * output_state.
if (!use_cifg) {
tensor_utils::MatrixBatchVectorMultiplyAccumulate(
@@ -286,227 +408,364 @@ void LstmStep(
int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
- &unused_min, &unused_max, &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
+ LstmStepWithAuxInput(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ /*aux_input_ptr_batch=*/nullptr,
+ /*aux_input_to_input_weights_ptr=*/nullptr,
+ /*aux_input_to_input_weights_scale=*/0.0f,
+ /*aux_input_to_forget_weights_ptr=*/nullptr,
+ /*aux_input_to_forget_weights_scale=*/0.0f,
+ /*aux_input_to_cell_weights_ptr=*/nullptr,
+ /*aux_input_to_cell_weights_scale=*/0.0f,
+ /*aux_input_to_output_weights_ptr=*/nullptr,
+ /*aux_input_to_output_weights_scale=*/0.0f,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, params, n_batch, n_cell, n_input,
+ /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
+ cell_scratch, output_gate_scratch, scaling_factors,
+ product_scaling_factors, recovered_cell_weights,
+ quantized_input_ptr_batch,
+ /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ output_ptr_batch);
}
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset,
- &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
+ void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr,
+ float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr,
+ float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_cell_weights,
+ int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch,
+ int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
+ float* output_state_ptr, float* cell_state_ptr,
+ float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we
+ // can check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
+ n_batch, input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
+ n_batch, output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input,
+ quantized_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch,
+ /*result_stride=*/1);
}
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Save quantization and matmul computation for all zero input.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
+ if (aux_input_ptr_batch != nullptr &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, n_input,
+ quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
+ // Save quantization and matmul computation for all zero input.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+ cell_state_ptr, n_batch * n_cell,
+ cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell,
+ cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
}
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell,
+ output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch,
+ n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell,
+ quantized_cell_state_ptr, product_scaling_factors, n_batch,
+ output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
}
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
- product_scaling_factors, n_batch, output_ptr_batch,
- /*result_stride=*/1);
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
}
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index 2a11b37a60..b5558cce55 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
namespace kernel_utils {
@@ -35,6 +35,15 @@ void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
TfLiteFusedActivation activation,
float* hidden_state_ptr_batch, float* output_ptr_batch);
+// Same as above but includes an auxiliary input with the corresponding weights.
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr,
+ const float* aux_input_ptr_batch,
+ const float* aux_input_weights_ptr,
+ const float* recurrent_weights_ptr, const float* bias_ptr,
+ int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs a quantized RNN batch inference step. Same as above, but for
// quantization purposes, we also pass in quantized_hidden_state_ptr_batch and
// quantized_input_ptr_batch pointers for temporary storage of the quantized
@@ -56,6 +65,17 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
float* scaling_factors, float* hidden_state_ptr_batch,
float* output_ptr_batch);
+void RnnBatchStep(
+ const float* input_ptr_batch, const int8_t* input_weights_ptr,
+ float input_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_weights_ptr, float aux_input_weights_scale,
+ const int8_t* recurrent_weights_ptr, float recurrent_weights_scale,
+ const float* bias_ptr, int input_size, int aux_input_size, int num_units,
+ int batch_size, TfLiteFusedActivation activation,
+ int8_t* quantized_input_ptr_batch, int8_t* aux_quantized_input_ptr_batch,
+ int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
+ float* hidden_state_ptr_batch, float* output_ptr_batch);
+
// Performs an LSTM batch inference step for input specified by input_ptr_batch.
// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
// biases (*_bias_ptr), and buffers (*_scratch), along with additional
@@ -66,8 +86,7 @@ void RnnBatchStep(const float* input_ptr_batch, const int8_t* input_weights_ptr,
// - n_input: the input size,
// - n_output: the output size.
//
-// The pointers to the cell and output state and the output are updated. Unless
-// projection is specified output and output state contain the same data.
+// The pointers to the cell and output state and the output are updated.
//
// The pointers with the suffix "_batch" point to data aligned in batch_major
// order, and each step processes batch_size many inputs from input_ptr_batch,
@@ -92,6 +111,31 @@ void LstmStep(
float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
float* output_ptr_batch);
+// Same as above but includes an auxiliary input with the corresponding weights.
+void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch);
+
// Same as above but with quantized weight matrices. In detail:
// Input of size 'n_batch * n_input':
// input_ptr_batch
@@ -175,6 +219,47 @@ void LstmStep(
int8_t* quantized_cell_state_ptr, float* output_state_ptr,
float* cell_state_ptr, float* output_ptr_batch);
+void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* scaling_factors, float* product_scaling_factors,
+ float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch);
+
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
index df4d871466..b6151c40b3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/legacy_optimized_ops.h
@@ -27,8 +27,33 @@ namespace tflite {
namespace optimized_ops {
// Unoptimized reference ops:
+using reference_ops::ArgMax;
using reference_ops::Relu1;
using reference_ops::Relu6;
+using reference_ops::SpaceToBatchND;
+
+template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float, but reserved in signature for future
+ // activations.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
@@ -296,13 +321,17 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int output_shift, int32 output_activation_min,
int32 output_activation_max, uint8* output_data,
const Dims<4>& output_dims) {
- BroadcastMul4DSlow(
- input1_data, input1_dims, input1_offset, input2_data, input2_dims,
- input2_offset, output_offset, output_multiplier,
- // This legacy version switches the sign of the output shift.
- kReverseShift * output_shift,
- // (Break to highlight preceding line.)
- output_activation_min, output_activation_max, output_data, output_dims);
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = kReverseShift * output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
}
// legacy, for compatibility with old checked-in code
@@ -621,6 +650,294 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
+ output_activation_max, output_data, output_dims);
+}
+
+inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32 output_activation_min, int32 output_activation_max,
+ int32* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
+void Mul(const int32* input1_data, const Dims<4>& input1_dims,
+ const int32* input2_data, const Dims<4>& input2_dims,
+ int32* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ MulNoActivation(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No parameters needed.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.output_offset = output_offset;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// For compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+inline void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ float float_activation_min;
+ float float_activation_max;
+ GetActivationMinMax(Ac, &float_activation_min, &float_activation_max);
+ SetActivationParams(float_activation_min, float_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4>.
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4> version.
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy Dims<4> version.
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy Dims<4>
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// Legacy Dims<4>
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
+ output_data, output_dims, /*align_corners=*/false);
+}
+
+// Legacy Dims<4>.
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature, function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
} // namespace optimized_ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_LEGACY_OPTIMIZED_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
index 921aae1303..5fb31889fe 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h
@@ -26,7 +26,7 @@ limitations under the License.
#include <tuple>
#include <type_traits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
index 420bc68b43..27418178fd 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/common.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
@@ -236,6 +236,35 @@ void NeonVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ // If v_size is not divisible by kWeightsPerNeonLane, we cannot use the main
+ // vectorized loop, and we need to process sequentially. postamble_start shows
+ // the start index where this should happen.
+ const int postamble_start =
+ v_size - (v_size & (kFloatWeightsPerNeonLane - 1));
+
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < postamble_start; v += kFloatWeightsPerNeonLane) {
+ // Load from memory to vectors.
+ float32x4_t batch_vector_f32x4 = vld1q_f32(batch_vector + v);
+ float32x4_t vector_f32x4 = vld1q_f32(vector + v);
+ // Multiply.
+ float32x4_t result_f32x4 = vmulq_f32(batch_vector_f32x4, vector_f32x4);
+ // Store.
+ vst1q_f32(result + v, result_f32x4);
+ }
+ // Postamble loop
+ for (int v = postamble_start; v < v_size; v++) {
+ result[v] = vector[v] * batch_vector[v];
+ }
+ // Update the pointers.
+ result += v_size;
+ batch_vector += v_size;
+ }
+}
+
void NeonVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
index 63c89d1eee..630a6bbf29 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h"
@@ -52,6 +52,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ NEON_OR_PORTABLE(VectorBatchVectorCwiseProduct, vector, v_size, batch_vector,
+ n_batch, result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
@@ -72,6 +79,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
n_batch, result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -131,6 +143,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index 51a9aa5a42..2c8e8f90e3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
@@ -42,6 +43,14 @@ namespace optimized_ops {
// Unoptimized reference ops:
using reference_ops::ArgMax;
using reference_ops::ArgMinMax;
+using reference_ops::Broadcast4DSlowGreater;
+using reference_ops::Broadcast4DSlowGreaterEqual;
+using reference_ops::Broadcast4DSlowGreaterEqualWithScaling;
+using reference_ops::Broadcast4DSlowGreaterWithScaling;
+using reference_ops::Broadcast4DSlowLess;
+using reference_ops::Broadcast4DSlowLessEqual;
+using reference_ops::Broadcast4DSlowLessEqualWithScaling;
+using reference_ops::Broadcast4DSlowLessWithScaling;
using reference_ops::BroadcastAdd4DSlow;
using reference_ops::BroadcastGreater;
using reference_ops::BroadcastGreaterEqual;
@@ -57,8 +66,12 @@ using reference_ops::FakeQuant;
using reference_ops::Gather;
using reference_ops::Greater;
using reference_ops::GreaterEqual;
+using reference_ops::GreaterEqualWithScaling;
+using reference_ops::GreaterWithScaling;
using reference_ops::Less;
using reference_ops::LessEqual;
+using reference_ops::LessEqualWithScaling;
+using reference_ops::LessWithScaling;
using reference_ops::Mean;
using reference_ops::RankOneSelect;
using reference_ops::Relu1;
@@ -66,6 +79,7 @@ using reference_ops::Relu6;
using reference_ops::ReluX;
using reference_ops::Select;
using reference_ops::SpaceToBatchND;
+using reference_ops::Split;
using reference_ops::StridedSlice;
using reference_ops::Transpose;
@@ -319,6 +333,7 @@ inline void AddBiasAndEvalActivationFunction(const float* bias_data,
#endif
}
+// Note: This to be converted to RuntimeShapes along with Conv.
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac>
void AddBiasAndEvalActivationFunction(const float* bias_data,
@@ -1934,6 +1949,85 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims,
output_activation_max);
}
+inline void HybridConv(const int8_t* input_data, const Dims<4>& input_dims,
+ const int8_t* filter_data, const Dims<4>& filter_dims,
+ const float* bias_data, const Dims<4>& bias_dims,
+ int stride_width, int stride_height, int pad_width,
+ int pad_height, float* scaling_factors_ptr,
+ float output_activation_min, float output_activation_max,
+ float* output_data, const Dims<4>& output_dims,
+ int8_t* im2col_data, const Dims<4>& im2col_dims) {
+ const int batch_size = input_dims.sizes[3];
+ const int filter_width = ArraySize(filter_dims, 1);
+ const int filter_height = ArraySize(filter_dims, 2);
+
+ const int8_t* gemm_input_data = nullptr;
+ int num_input;
+ const bool need_im2col = stride_width != 1 || stride_height != 1 ||
+ filter_width != 1 || filter_height != 1;
+
+ if (need_im2col) {
+ TFLITE_DCHECK(im2col_data);
+ // symmetric quantization assumes zero point of 0.
+ const int input_zero_point = 0;
+ Im2col(input_data, input_dims, stride_width, stride_height, pad_width,
+ pad_height, filter_height, filter_width, input_zero_point,
+ im2col_data, im2col_dims);
+ gemm_input_data = im2col_data;
+ num_input = im2col_dims.sizes[0] * im2col_dims.sizes[1] *
+ im2col_dims.sizes[2] * im2col_dims.sizes[3];
+ } else {
+ TFLITE_DCHECK(!im2col_data);
+ gemm_input_data = input_data;
+ num_input = input_dims.sizes[0] * input_dims.sizes[1] *
+ input_dims.sizes[2] * input_dims.sizes[3];
+ }
+
+ // Flatten 4D matrices into 2D matrices for matrix multiplication.
+
+ // Flatten so that each filter has its own row.
+ const int filter_rows = filter_dims.sizes[3];
+ const int filter_cols =
+ filter_dims.sizes[0] * filter_dims.sizes[1] * filter_dims.sizes[2];
+
+ // In MatrixBatchVectorMultiplyAccumulate, each output value is the
+ // dot product of one row of the first matrix with one row of the second
+ // matrix. Therefore, the number of cols in each matrix are equivalent.
+ //
+ // After Im2Col, each input patch becomes a row.
+ const int gemm_input_cols = filter_cols;
+ const int gemm_input_rows = num_input / gemm_input_cols;
+
+ const int output_cols = output_dims.sizes[0];
+ const int output_rows =
+ output_dims.sizes[1] * output_dims.sizes[2] * output_dims.sizes[3];
+ TFLITE_DCHECK_EQ(output_cols, filter_rows);
+ TFLITE_DCHECK_EQ(output_rows, gemm_input_rows);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[0], output_cols);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[1], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[2], 1);
+ TFLITE_DCHECK_EQ(bias_dims.sizes[3], 1);
+
+ // MatrixBatchVectorMultiplyAccumulate assumes that each row of the second
+ // input matrix has its own scale factor. This code duplicates the scale
+ // factors for each row in the same batch.
+ const int rows_per_batch = gemm_input_rows / batch_size;
+ for (int i = gemm_input_rows - 1; i >= 0; --i) {
+ scaling_factors_ptr[i] = scaling_factors_ptr[i / rows_per_batch];
+ }
+
+ tensor_utils::ZeroVector(output_data, output_rows * output_cols);
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ filter_data, filter_rows, filter_cols, gemm_input_data,
+ scaling_factors_ptr, /*n_batch=*/gemm_input_rows, output_data,
+ /*result_stride=*/1);
+
+ AddBiasAndEvalActivationFunction(bias_data, bias_dims, output_data,
+ output_dims, output_activation_min,
+ output_activation_max);
+}
+
template <FusedActivationFunctionType Ac>
void Conv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
@@ -2142,38 +2236,6 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
im2col_data, im2col_dims, gemm_context);
}
-template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("DepthToSpace");
-
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
-
- const int output_depth = ArraySize(output_dims, 0);
- const int batch_size = ArraySize(output_dims, 3);
-
- // Number of continuous values that we can copy in one interation.
- const int stride = block_size * output_depth;
-
- for (int batch = 0; batch < batch_size; ++batch) {
- for (int in_h = 0; in_h < input_height; ++in_h) {
- const T* input_ptr = input_data + Offset(input_dims, 0, 0, in_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
- const T* src = input_ptr;
- for (int in_w = 0; in_w < input_width; ++in_w) {
- memcpy(output_data, src, stride * sizeof(T));
- output_data += stride;
- src += input_depth;
- }
- input_ptr += stride;
- }
- }
- }
-}
-
// legacy, for compatibility with old checked-in code
template <FusedActivationFunctionType Ac, typename T>
void Im2col(const T* input_data, const Dims<4>& input_dims, int stride,
@@ -2249,25 +2311,75 @@ void ConvAsGemm(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("DepthToSpace");
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+
+ const int output_depth = output_shape.Dims(3);
+ const int batch_size = output_shape.Dims(0);
+
+ // Number of continuous values that we can copy in one interation.
+ const int stride = op_params.block_size * output_depth;
+
+ for (int batch = 0; batch < batch_size; ++batch) {
+ for (int in_h = 0; in_h < input_height; ++in_h) {
+ const T* input_ptr = input_data + Offset(input_shape, batch, in_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
+ const T* src = input_ptr;
+ for (int in_w = 0; in_w < input_width; ++in_w) {
+ memcpy(output_data, src, stride * sizeof(T));
+ output_data += stride;
+ src += input_depth;
+ }
+ input_ptr += stride;
+ }
+ }
+ }
+}
+
+template <typename T>
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
gemmlowp::ScopedProfilingLabel label("SpaceToDepth");
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
- const int input_depth = ArraySize(input_dims, 0);
- const int batch_size = ArraySize(input_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+
+ const int input_depth = input_shape.Dims(3);
+ const int batch_size = input_shape.Dims(0);
// Number of continuous values that we can copy in one interation.
- const int stride = block_size * input_depth;
+ const int stride = op_params.block_size * input_depth;
for (int batch = 0; batch < batch_size; ++batch) {
for (int out_h = 0; out_h < output_height; ++out_h) {
- T* output_ptr = output_data + Offset(output_dims, 0, 0, out_h, batch);
- for (int offset_h = 0; offset_h < block_size; ++offset_h) {
+ T* output_ptr = output_data + Offset(output_shape, batch, out_h, 0, 0);
+ for (int offset_h = 0; offset_h < op_params.block_size; ++offset_h) {
T* dst = output_ptr;
for (int out_w = 0; out_w < output_width; ++out_w) {
memcpy(dst, input_data, stride * sizeof(T));
@@ -2280,53 +2392,6 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void NonGlobalBatchNormalization(
- const float* input_data, const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims, const float* multiplier_data,
- const Dims<4>& multiplier_dims, const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("NonGlobalBatchNormalization");
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int inner_size = MatchingFlatSizeSkipDim(
- input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
-
- for (int b = 0; b < batches; ++b) {
- for (int i = 0; i < inner_size; ++i) {
- *output_data = ActivationFunction<Ac>(
- (*input_data - mean_data[i]) * multiplier_data[i] + offset_data[i]);
- ++output_data;
- ++input_data;
- }
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void GlobalBatchNormalization(const float* input_data,
- const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims,
- const float* multiplier_data,
- const Dims<4>& multiplier_dims,
- const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("GlobalBatchNormalization");
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth =
- MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
- offset_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- *output_data = ActivationFunction<Ac>(
- (*input_data - mean_data[c]) * multiplier_data[c] + offset_data[c]);
- ++output_data;
- ++input_data;
- }
- }
-}
-
inline void Relu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Relu (not fused)");
@@ -2336,11 +2401,12 @@ inline void Relu(const RuntimeShape& input_shape, const float* input_data,
output = input.cwiseMax(0.0f);
}
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Normalization");
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -2409,16 +2475,18 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data,
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
- int32 input_zero_point, uint8* output_data,
- const RuntimeShape& output_shape) {
+ const uint8* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("L2Normalization/8bit");
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32 input_zero_point = op_params.input_zero_point;
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
@@ -2725,17 +2793,16 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
}
}
-inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float output_activation_min, float output_activation_max,
- float* output_data, const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const float* input1_data,
+ const RuntimeShape& input2_shape, const float* input2_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul");
- TFLITE_DCHECK(IsPackedWithoutStrides(input1_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(input2_dims));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ const float output_activation_min = params.float_activation_min;
+ const float output_activation_max = params.float_activation_max;
int i = 0;
- const int size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int size = MatchingFlatSize(input1_shape, input2_shape, output_shape);
#ifdef USE_NEON
const auto activation_min = vdupq_n_f32(output_activation_min);
const auto activation_max = vdupq_n_f32(output_activation_max);
@@ -2786,25 +2853,16 @@ inline void Mul(const float* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
-inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32 output_activation_min, int32 output_activation_max,
- int32* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mul/int32");
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int32* input1_data,
+ const RuntimeShape& input2_shape, const int32* input2_data,
+ const RuntimeShape& output_shape, int32* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mul/int32/activation");
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
@@ -2812,22 +2870,24 @@ inline void Mul(const int32* input1_data, const Dims<4>& input1_dims,
}
}
-template <FusedActivationFunctionType Ac>
-void Mul(const int32* input1_data, const Dims<4>& input1_dims,
- const int32* input2_data, const Dims<4>& input2_dims,
- int32* output_data, const Dims<4>& output_dims) {
+inline void MulNoActivation(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape,
+ const int32* input1_data,
+ const RuntimeShape& input2_shape,
+ const int32* input2_data,
+ const RuntimeShape& output_shape,
+ int32* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/int32");
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
- auto input1_map = MapAsVector(input1_data, input1_dims);
- auto input2_map = MapAsVector(input2_data, input2_dims);
- auto output_map = MapAsVector(output_data, output_dims);
- if (AreSameDims(input1_dims, input2_dims)) {
+ auto input1_map = MapAsVector(input1_data, input1_shape);
+ auto input2_map = MapAsVector(input2_data, input2_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
+ if (input1_shape == input2_shape) {
output_map.array() = input1_map.array() * input2_map.array();
- } else if (FlatSize(input2_dims) == 1) {
+ } else if (input2_shape.FlatSize() == 1) {
auto scalar = input2_data[0];
output_map.array() = input1_map.array() * scalar;
- } else if (FlatSize(input1_dims) == 1) {
+ } else if (input1_shape.FlatSize() == 1) {
auto scalar = input1_data[0];
output_map.array() = scalar * input2_map.array();
} else {
@@ -2836,14 +2896,16 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("Mul/Int16");
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mul/Int16/NoActivation");
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2855,17 +2917,20 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
// This is a copy of the reference implementation. We do not currently have a
// properly optimized version.
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 output_offset = params.output_offset;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -2883,64 +2948,6 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
-// dimensionality if the runtime code does a single loop over one dimension
-// that handles broadcasting as the base case. The code generator would then
-// generate max(D1, D2) nested for loops.
-// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
-// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
-// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
-// reference_ops.h.
-template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest stride,
- // typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for the
- // best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
- output_activation_min, output_activation_max);
- }
- }
- }
- }
-}
-
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
@@ -3169,15 +3176,28 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
// reference_ops.h.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastDiv4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -3190,14 +3210,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -3205,6 +3225,21 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
// TODO(aselle): This is not actually optimized yet.
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
@@ -4034,29 +4069,28 @@ inline void L2Pool(const PoolParams& params, const RuntimeShape& input_shape,
}
}
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
+inline void LocalResponseNormalization(
+ const tflite::LocalResponseNormalizationParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
gemmlowp::ScopedProfilingLabel label("LocalResponseNormalization");
- MatchingFlatSize(input_dims, output_dims);
+ MatchingFlatSize(input_shape, output_shape);
- const auto data_in = MapAsMatrixWithFirstDimAsRows(input_data, input_dims);
- auto data_out = MapAsMatrixWithFirstDimAsRows(output_data, output_dims);
+ const auto data_in = MapAsMatrixWithLastDimAsRows(input_data, input_shape);
+ auto data_out = MapAsMatrixWithLastDimAsRows(output_data, output_shape);
// Carry out local response normalization, vector by vector.
// Since the data are stored column major, making row-wise operation
// probably not memory efficient anyway, we do an explicit for loop over
// the columns.
- const int double_range = range * 2;
+ const int double_range = op_params.range * 2;
Eigen::VectorXf padded_square(data_in.rows() + double_range);
padded_square.setZero();
for (int r = 0; r < data_in.cols(); ++r) {
// Do local response normalization for data_in(:, r)
// first, compute the square and store them in buffer for repeated use
- padded_square.block(range, 0, data_in.rows(), 1) =
- data_in.col(r).cwiseProduct(data_in.col(r)) * alpha;
+ padded_square.block(op_params.range, 0, data_in.rows(), 1) =
+ data_in.col(r).cwiseProduct(data_in.col(r)) * op_params.alpha;
// Then, compute the scale and writes them to data_out
float accumulated_scale = 0;
for (int i = 0; i < double_range; ++i) {
@@ -4064,18 +4098,18 @@ inline void LocalResponseNormalization(const float* input_data,
}
for (int i = 0; i < data_in.rows(); ++i) {
accumulated_scale += padded_square(i + double_range);
- data_out(i, r) = bias + accumulated_scale;
+ data_out(i, r) = op_params.bias + accumulated_scale;
accumulated_scale -= padded_square(i);
}
}
// In a few cases, the pow computation could benefit from speedups.
- if (beta == 1) {
+ if (op_params.beta == 1) {
data_out.array() = data_in.array() * data_out.array().inverse();
- } else if (beta == 0.5) {
+ } else if (op_params.beta == 0.5) {
data_out.array() = data_in.array() * data_out.array().sqrt().inverse();
} else {
- data_out.array() = data_in.array() * data_out.array().pow(-beta);
+ data_out.array() = data_in.array() * data_out.array().pow(-op_params.beta);
}
}
@@ -5012,11 +5046,11 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
template <typename SrcT, typename DstT>
-inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
- DstT* output_data, const Dims<4>& output_dims) {
+inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
+ const RuntimeShape& output_shape, DstT* output_data) {
gemmlowp::ScopedProfilingLabel label("Cast");
- auto input_map = MapAsVector(input_data, input_dims);
- auto output_map = MapAsVector(output_data, output_dims);
+ auto input_map = MapAsVector(input_data, input_shape);
+ auto output_map = MapAsVector(output_data, output_shape);
output_map.array() = input_map.array().template cast<DstT>();
}
@@ -5028,13 +5062,6 @@ inline void Floor(const RuntimeShape& input_shape, const float* input_data,
output_map.array() = Eigen::floor(input_map.array());
}
-// Legacy Dims<4> version.
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data);
-}
-
#ifdef USE_NEON
inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
float scale, float* output_ptr) {
@@ -5134,12 +5161,14 @@ inline void ResizeBilinearKernel(const float* input_ptr, int32 depth,
inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
int32 x, int32 y, int32 depth, int32 batch,
+ const RuntimeShape& input_shape,
const float* input_data,
- const Dims<4>& input_dims,
- float* output_data,
- const Dims<4>& output_dims) {
- const int32 input_width = ArraySize(input_dims, 1);
- const int32 output_width = ArraySize(output_dims, 1);
+ const RuntimeShape& output_shape,
+ float* output_data) {
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+ const int32 input_width = input_shape.Dims(2);
+ const int32 output_width = output_shape.Dims(2);
const int32 input_x_offset = (x1 - x0) * depth;
const int32 input_y_offset = (y1 - y0) * depth * input_width;
@@ -5147,7 +5176,6 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
const int32 output_y_offset = depth * output_width;
#ifdef USE_NEON
- TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
TFLITE_DCHECK(x1 >= x0);
TFLITE_DCHECK(y1 >= y0);
@@ -5157,7 +5185,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
const float* input_ptr = nullptr;
float32x4x2_t x0y0;
- input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ input_ptr = &input_data[Offset(input_shape, batch, y0, x0, ic)];
x0y0.val[0] = vld1q_f32(input_ptr);
x0y0.val[1] = vld1q_f32(input_ptr + 4);
@@ -5177,7 +5205,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
x1y1.val[1] = vld1q_f32(input_ptr + 4);
// Top left corner.
- float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
vst1q_f32(output_ptr, x0y0.val[0]);
vst1q_f32(output_ptr + 4, x0y0.val[1]);
@@ -5216,14 +5244,15 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
// Handle 4 input channels at a time.
for (; ic <= depth - 4; ic += 4) {
- const float* input_ptr = &input_data[Offset(input_dims, ic, x0, y0, batch)];
+ const float* input_ptr =
+ &input_data[Offset(input_shape, batch, y0, x0, ic)];
float32x4_t x0y0 = vld1q_f32(input_ptr);
float32x4_t x1y0 = vld1q_f32(input_ptr + input_x_offset);
float32x4_t x0y1 = vld1q_f32(input_ptr + input_y_offset);
float32x4_t x1y1 = vld1q_f32(input_ptr + input_x_offset + input_y_offset);
// Top left corner.
- float* output_ptr = &output_data[Offset(output_dims, ic, x, y, batch)];
+ float* output_ptr = &output_data[Offset(output_shape, batch, y, x, ic)];
vst1q_f32(output_ptr, x0y0);
// Top right corner.
@@ -5247,7 +5276,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
// Handle one input channel at a time.
for (; ic < depth; ic++) {
- const int32 input_offset = Offset(input_dims, ic, x0, y0, batch);
+ const int32 input_offset = Offset(input_shape, batch, y0, x0, ic);
float x0y0 = input_data[input_offset];
float x1y0 = input_data[input_offset + input_x_offset];
@@ -5255,7 +5284,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
// Top left corner.
- const int32 output_offset = Offset(output_dims, ic, x, y, batch);
+ const int32 output_offset = Offset(output_shape, batch, y, x, ic);
output_data[output_offset] = x0y0;
// Top right corner.
@@ -5271,7 +5300,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
}
#else
for (int ch = 0; ch < depth; ch++) {
- const int32 input_offset = Offset(input_dims, ch, x0, y0, batch);
+ const int32 input_offset = Offset(input_shape, batch, y0, x0, ch);
float x0y0 = input_data[input_offset];
float x1y0 = input_data[input_offset + input_x_offset];
@@ -5279,7 +5308,7 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
float x1y1 = input_data[input_offset + input_x_offset + input_y_offset];
// Top left corner.
- const int32 output_offset = Offset(output_dims, ch, x, y, batch);
+ const int32 output_offset = Offset(output_shape, batch, y, x, ch);
output_data[output_offset] = x0y0;
// Top right corner.
@@ -5296,31 +5325,30 @@ inline void ResizeBilinearKernel2x2(int32 x0, int32 x1, int32 y0, int32 y1,
#endif
}
-inline void ResizeBilinear2x2(const float* input_data,
- const Dims<4>& input_dims, float* output_data,
- const Dims<4>& output_dims, int32 batches,
- int32 input_height, int32 input_width,
- int32 depth, int32 output_height,
- int32 output_width) {
+inline void ResizeBilinear2x2(int32 batches, int32 input_height,
+ int32 input_width, int32 depth,
+ int32 output_height, int32 output_width,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
for (int b = 0; b < batches; b++) {
for (int y0 = 0, y = 0; y <= output_height - 2; y += 2, y0++) {
for (int x0 = 0, x = 0; x <= output_width - 2; x += 2, x0++) {
int32 x1 = std::min(x0 + 1, input_width - 1);
int32 y1 = std::min(y0 + 1, input_height - 1);
- ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_data,
- input_dims, output_data, output_dims);
+ ResizeBilinearKernel2x2(x0, x1, y0, y1, x, y, depth, b, input_shape,
+ input_data, output_shape, output_data);
}
}
}
}
-inline void ResizeBilinearGeneric(const float* input_data,
- const Dims<4>& input_dims, float* output_data,
- const Dims<4>& output_dims, int32 batches,
- int32 input_height, int32 input_width,
- int32 depth, int32 output_height,
- int32 output_width, float height_scale,
- float width_scale) {
+inline void ResizeBilinearGeneric(
+ int32 batches, int32 input_height, int32 input_width, int32 depth,
+ int32 output_height, int32 output_width, float height_scale,
+ float width_scale, const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
memset(output_data, 0,
batches * output_height * output_width * depth * sizeof(float));
@@ -5337,22 +5365,22 @@ inline void ResizeBilinearGeneric(const float* input_data,
float* output_ptr = &output_data[output_offset];
// Run kernel on the 4 corners of the bilinear resize algorithm.
- int32 input_offset = Offset(input_dims, 0, x0, y0, b);
+ int32 input_offset = Offset(input_shape, b, y0, x0, 0);
float scale = (1 - (input_y - y0)) * (1 - (input_x - x0));
const float* input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x1, y0, b);
+ input_offset = Offset(input_shape, b, y0, x1, 0);
scale = (1 - (input_y - y0)) * (input_x - x0);
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x0, y1, b);
+ input_offset = Offset(input_shape, b, y1, x0, 0);
scale = (input_y - y0) * (1 - (input_x - x0));
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
- input_offset = Offset(input_dims, 0, x1, y1, b);
+ input_offset = Offset(input_shape, b, y1, x1, 0);
scale = (input_y - y0) * (input_x - x0);
input_ptr = &input_data[input_offset];
ResizeBilinearKernel(input_ptr, depth, scale, output_ptr);
@@ -5365,10 +5393,10 @@ inline void ResizeBilinearGeneric(const float* input_data,
template <typename T>
inline void ResizeBilinearGenericSmallChannel(
- const T* input_data, const Dims<4>& input_dims, T* output_data,
- const Dims<4>& output_dims, int32 batches, int32 input_height,
- int32 input_width, int32 depth, int32 output_height, int32 output_width,
- float height_scale, float width_scale) {
+ int32 batches, int32 input_height, int32 input_width, int32 depth,
+ int32 output_height, int32 output_width, float height_scale,
+ float width_scale, const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& output_shape, T* output_data) {
memset(output_data, 0,
batches * output_height * output_width * depth * sizeof(T));
@@ -5383,9 +5411,10 @@ inline void ResizeBilinearGenericSmallChannel(
int32 x0 = static_cast<int32>(input_x);
int32 x1 = std::min(x0 + 1, input_width - 1);
- int32 input_offset[4] = {
- Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b),
- Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)};
+ int32 input_offset[4] = {Offset(input_shape, b, y0, x0, 0),
+ Offset(input_shape, b, y0, x1, 0),
+ Offset(input_shape, b, y1, x0, 0),
+ Offset(input_shape, b, y1, x1, 0)};
float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
(1 - (input_y - y0)) * (input_x - x0),
(input_y - y0) * (1 - (input_x - x0)),
@@ -5403,97 +5432,93 @@ inline void ResizeBilinearGenericSmallChannel(
}
}
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const float* input_data,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims, bool align_corners) {
+ const RuntimeShape& unextended_output_shape,
+ float* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
// Specialize for 2x2 upsample.
- if (!align_corners && output_height == 2 * input_height &&
+ if (!op_params.align_corners && output_height == 2 * input_height &&
output_width == 2 * input_width) {
- ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches,
- input_height, input_width, depth, output_height,
- output_width);
+ ResizeBilinear2x2(batches, input_height, input_width, depth, output_height,
+ output_width, input_shape, input_data, output_shape,
+ output_data);
} else {
float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
- if (align_corners && output_height > 1) {
+ if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
- if (align_corners && output_width > 1) {
+ if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
- ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims,
- batches, input_height, input_width, depth,
+ ResizeBilinearGeneric(batches, input_height, input_width, depth,
output_height, output_width, height_scale,
- width_scale);
+ width_scale, input_shape, input_data, output_shape,
+ output_data);
}
}
// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
// or int16 arithmetic.
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const uint8* input_data,
+ const RuntimeShape& output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims, bool align_corners) {
+ const RuntimeShape& unextended_output_shape,
+ uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.FlatSize(), 2);
+ int32 output_height = output_size_data[0];
+ int32 output_width = output_size_data[1];
float height_scale =
- (align_corners && output_height > 1)
+ (op_params.align_corners && output_height > 1)
? (static_cast<float>(input_height - 1) / (output_height - 1))
: (static_cast<float>(input_height) / output_height);
float width_scale =
- (align_corners && output_width > 1)
+ (op_params.align_corners && output_width > 1)
? (static_cast<float>(input_width - 1) / (output_width - 1))
: (static_cast<float>(input_width) / output_width);
ResizeBilinearGenericSmallChannel<uint8>(
- input_data, input_dims, output_data, output_dims, batches, input_height,
- input_width, depth, output_height, output_width, height_scale,
- width_scale);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
-}
-
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
- output_data, output_dims, /*align_corners=*/false);
+ batches, input_height, input_width, depth, output_height, output_width,
+ height_scale, width_scale, input_shape, input_data, output_shape,
+ output_data);
}
// Helper methods for BatchToSpaceND.
@@ -5518,20 +5543,29 @@ inline void GetIndexRange(int spatial_index_dim, int block_shape_dim,
}
template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
gemmlowp::ScopedProfilingLabel label("BatchToSpaceND");
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
+
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -5566,8 +5600,9 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
spatial_offset % block_shape_width - crops_left;
TFLITE_DCHECK_GE(out_w, 0);
TFLITE_DCHECK_LT(out_w, output_width);
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -5731,49 +5766,6 @@ inline void Pad(const tflite::PadParams& op_params,
output_data);
}
-// Legacy signature, function covered both Pad and PadV2.
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
- tflite::PadParams op_params;
- op_params.left_padding_count = 4;
- op_params.right_padding_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.left_padding[i] = left_paddings[3 - i];
- op_params.right_padding[i] = right_paddings[3 - i];
- }
- const T pad_value_copy = pad_value;
-
- Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
- DimsToShape(output_dims), output_data);
-}
-
-// Old Pad that calls legacy PadV2.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
-}
-
-// Old Pad that only padded with 0.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- const T pad_value = static_cast<T>(0);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, pad_value);
-}
-
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -5818,22 +5810,6 @@ inline void Slice(const tflite::SliceParams& op_params,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- tflite::SliceParams op_params;
- op_params.begin_count = 4;
- op_params.size_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.begin[i] = begin[3 - i];
- op_params.size[i] = size[3 - i];
- }
-
- Slice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
T* output_data) {
@@ -5856,22 +5832,6 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Minimum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Maximum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
void TransposeIm2col(const T* input_data, const Dims<4>& input_dims,
const Dims<4>& filter_dims, int stride_width,
int stride_height, int pad_width, int pad_height,
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 010b40b901..f87760a6c3 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
@@ -86,6 +86,14 @@ void NeonBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+void NeonVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -109,6 +117,10 @@ void PortableClipVector(const float* vector, int v_size, float abs_limit,
void NeonClipVector(const float* vector, int v_size, float abs_limit,
float* result);
+// Add another vector for each batch in the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
@@ -164,6 +176,10 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
void NeonReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
index f882f9910e..544ef16ce1 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc
@@ -23,6 +23,32 @@ limitations under the License.
namespace tflite {
+namespace {
+// These constants are used to manipulate the binary representation of doubles.
+// Double-precision binary64 floating point format is:
+// Bit | 63 | 62-52 | 51-0 |
+// | Sign | Exponent | Fraction |
+// To avoid 64-bit integers as much as possible, I break this into high and
+// low 32-bit chunks. High is:
+// Bit | 31 | 30-20 | 19-0 |
+// | Sign | Exponent | High Fraction |
+// Low is:
+// Bit | 31-0 |
+// | Low Fraction |
+// We then access the components through logical bit-wise operations to
+// extract the parts needed, with the positions and masks derived from the
+// layout shown above.
+constexpr uint64_t kSignMask = 0x8000000000000000LL;
+constexpr uint64_t kExponentMask = 0x7ff0000000000000LL;
+constexpr int32_t kExponentShift = 52;
+constexpr int32_t kExponentBias = 1023;
+constexpr uint32_t kExponentIsBadNum = 0x7ff;
+constexpr uint64_t kFractionMask = 0x000fffffffc00000LL;
+constexpr uint32_t kFractionShift = 22;
+constexpr uint32_t kFractionRoundingMask = 0x003fffff;
+constexpr uint32_t kFractionRoundingThreshold = 0x00200000;
+} // namespace
+
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift) {
if (double_multiplier == 0.) {
@@ -30,8 +56,16 @@ void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
*shift = 0;
return;
}
+#ifdef TFLITE_EMULATE_FLOAT
+ // If we're trying to avoid the use of floating-point instructions (for
+ // example on microcontrollers) then use an alternative implementation
+ // that only requires integer and bitwise operations. To enable this, you
+ // need to set the define during the build process for your platform.
+ int64_t q_fixed = IntegerFrExp(double_multiplier, shift);
+#else // TFLITE_EMULATE_FLOAT
const double q = std::frexp(double_multiplier, shift);
auto q_fixed = static_cast<int64_t>(TfLiteRound(q * (1ll << 31)));
+#endif // TFLITE_EMULATE_FLOAT
TFLITE_CHECK(q_fixed <= (1ll << 31));
if (q_fixed == (1ll << 31)) {
q_fixed /= 2;
@@ -60,6 +94,163 @@ void QuantizeMultiplierSmallerThanOneExp(double double_multiplier,
*left_shift = shift;
}
+int64_t IntegerFrExp(double input, int* shift) {
+ // Make sure our assumptions about the double layout hold.
+ TFLITE_CHECK_EQ(8, sizeof(double));
+
+ // We want to access the bits of the input double value directly, which is
+ // tricky to do safely, so use a union to handle the casting.
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } cast_union;
+ cast_union.double_value = input;
+ const uint64_t u = cast_union.double_as_uint;
+
+ // If the bitfield is all zeros apart from the sign bit, this is a normalized
+ // zero value, so return standard values for this special case.
+ if ((u & ~kSignMask) == 0) {
+ *shift = 0;
+ return 0;
+ }
+
+ // Deal with NaNs and Infs, which are always indicated with a fixed pattern in
+ // the exponent, and distinguished by whether the fractions are zero or
+ // non-zero.
+ const uint32_t exponent_part = ((u & kExponentMask) >> kExponentShift);
+ if (exponent_part == kExponentIsBadNum) {
+ *shift = std::numeric_limits<int>::max();
+ if (u & kFractionMask) {
+ // NaN, so just return zero (with the exponent set to INT_MAX).
+ return 0;
+ } else {
+ // Infinity, so return +/- INT_MAX.
+ if (u & kSignMask) {
+ return std::numeric_limits<int64_t>::min();
+ } else {
+ return std::numeric_limits<int64_t>::max();
+ }
+ }
+ }
+
+ // The shift is fairly easy to extract from the high bits of the double value,
+ // just by masking it out and applying a bias. The std::frexp() implementation
+ // always returns values between 0.5 and 1.0 though, whereas the exponent
+ // assumes 1.0 to 2.0 is the standard range, so I add on one to match that
+ // interface.
+ *shift = (exponent_part - kExponentBias) + 1;
+
+ // There's an implicit high bit in the double format definition, so make sure
+ // we include that at the top, and then reconstruct the rest of the fractional
+ // value from the remaining fragments.
+ int64_t fraction = 0x40000000 + ((u & kFractionMask) >> kFractionShift);
+
+ // We're cutting off some bits at the bottom, so to exactly match the standard
+ // frexp implementation here we'll apply rounding by adding one to the least
+ // significant bit of the result if the discarded portion is over half of the
+ // maximum.
+ if ((u & kFractionRoundingMask) > kFractionRoundingThreshold) {
+ fraction += 1;
+ }
+ // Negate the fraction if the sign bit was set.
+ if (u & kSignMask) {
+ fraction *= -1;
+ }
+
+ return fraction;
+}
+
+double DoubleFromFractionAndShift(int64_t fraction, int shift) {
+ union {
+ double double_value;
+ uint64_t double_as_uint;
+ } result;
+
+ // Detect NaNs and infinities.
+ if (shift == std::numeric_limits<int>::max()) {
+ if (fraction == 0) {
+ return NAN;
+ } else if (fraction > 0) {
+ return INFINITY;
+ } else {
+ return -INFINITY;
+ }
+ }
+
+ // Return a normalized zero for a zero fraction.
+ if (fraction == 0) {
+ result.double_as_uint = 0;
+ return result.double_value;
+ }
+
+ bool is_negative = (fraction < 0);
+ int64_t encoded_fraction = is_negative ? -fraction : fraction;
+ int64_t encoded_shift = (shift - 1);
+ while (encoded_fraction < 0x40000000) {
+ encoded_fraction *= 2;
+ encoded_shift -= 1;
+ }
+ while (encoded_fraction > 0x80000000) {
+ encoded_fraction /= 2;
+ encoded_shift += 1;
+ }
+ encoded_fraction -= 0x40000000;
+ if (encoded_shift < -1022) {
+ encoded_shift = -1023;
+ } else if (encoded_shift > 1022) {
+ encoded_shift = 1023;
+ }
+ encoded_shift += kExponentBias;
+ uint64_t encoded_sign = is_negative ? kSignMask : 0;
+ result.double_as_uint = encoded_sign | (encoded_shift << kExponentShift) |
+ (encoded_fraction << kFractionShift);
+ return result.double_value;
+}
+
+double IntegerDoubleMultiply(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return NAN;
+ }
+ const int result_shift = a_shift + b_shift + 1;
+ const int64_t result_fraction = (a_fraction * b_fraction) >> 32;
+ return DoubleFromFractionAndShift(result_fraction, result_shift);
+}
+
+int IntegerDoubleCompare(double a, double b) {
+ int a_shift;
+ const int64_t a_fraction = IntegerFrExp(a, &a_shift);
+ int b_shift;
+ const int64_t b_fraction = IntegerFrExp(b, &b_shift);
+
+ // Detect NaNs and infinities.
+ if (a_shift == std::numeric_limits<int>::max() ||
+ (b_shift == std::numeric_limits<int>::max())) {
+ return 1;
+ }
+
+ if ((a_fraction == 0) && (b_fraction < 0)) {
+ return 1;
+ } else if ((a_fraction < 0) && (b_fraction == 0)) {
+ return -1;
+ } else if (a_shift < b_shift) {
+ return -1;
+ } else if (a_shift > b_shift) {
+ return 1;
+ } else if (a_fraction < b_fraction) {
+ return -1;
+ } else if (a_fraction > b_fraction) {
+ return 1;
+ } else {
+ return 0;
+ }
+}
+
void PreprocessSoftmaxScaling(double beta, double input_scale,
int input_integer_bits,
int32_t* quantized_multiplier, int* left_shift) {
@@ -72,8 +263,20 @@ void PreprocessSoftmaxScaling(double beta, double input_scale,
// result is double equivalent of Q0.31 (actually with more precision). Thus
// this generates a Q(input_integer_bits).(31-input_integer_bits)
// representation.
+#ifdef TFLITE_EMULATE_FLOAT
+ const double input_beta = IntegerDoubleMultiply(beta, input_scale);
+ int shift;
+ int64_t fraction = IntegerFrExp(input_beta, &shift);
+ shift += (31 - input_integer_bits);
+ double input_beta_real_multiplier =
+ DoubleFromFractionAndShift(fraction, shift);
+ if (IntegerDoubleCompare(input_beta_real_multiplier, (1ll << 31) - 1.0) > 0) {
+ input_beta_real_multiplier = (1ll << 31) - 1.0;
+ }
+#else // TFLITE_EMULATE_FLOAT
const double input_beta_real_multiplier = std::min(
beta * input_scale * (1 << (31 - input_integer_bits)), (1ll << 31) - 1.0);
+#endif // TFLITE_EMULATE_FLOAT
QuantizeMultiplierGreaterThanOne(input_beta_real_multiplier,
quantized_multiplier, left_shift);
@@ -97,6 +300,12 @@ void PreprocessLogSoftmaxScalingExp(double beta, double input_scale,
}
int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
+#ifdef TFLITE_EMULATE_FLOAT
+ int64_t result = (1 << input_integer_bits) - 1;
+ result <<= (31 - input_integer_bits);
+ result >>= input_left_shift;
+ return result;
+#else // TFLITE_EMULATE_FLOAT
const double max_input_rescaled = 1.0 * ((1 << input_integer_bits) - 1) *
(1ll << (31 - input_integer_bits)) /
(1ll << input_left_shift);
@@ -104,6 +313,7 @@ int CalculateInputRadius(int input_integer_bits, int input_left_shift) {
// After scaling the difference, the result would be at the maximum. Thus we
// must ensure that our value has lower magnitude.
return static_cast<int>(std::floor(max_input_rescaled));
+#endif // TFLITE_EMULATE_FLOAT
}
void NudgeQuantizationRange(const float min, const float max,
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
index 9ee4a47fbb..d74a1bac97 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h
@@ -195,6 +195,44 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier,
void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* shift);
+// Splits a double input value into a returned fraction, and a shift value from
+// the exponent, using only bitwise and integer operations to support
+// microcontrollers and other environments without floating-point support.
+//
+// This is designed to be a replacement for how std::frexp() is used within the
+// QuantizeMultiplier() function, and so has a different signature than the
+// standard version, returning a 64-bit integer rather than a double. This
+// result has a maximum value of 1<<31, with the fraction expressed as a
+// proportion of that maximum.
+//
+// std::frexp() returns NaNs and infinities unmodified, but since we're
+// returning integers that can't represent those values, instead we return
+// a shift of std::numeric_limits<int>::max() for all bad numbers, with an int64
+// result of 0 for NaNs, std:numeric_limits<int64_t>::max() for +INFINITY, and
+// std::numeric_limits<int64_t>::min() for -INFINITY. Denormalized inputs will
+// result in return values that end up truncating some bits at the end,
+// reflecting the loss of precision inherent in denormalization.
+int64_t IntegerFrExp(double input, int* shift);
+
+// Converts an integer fraction in the format produced by IntegerFrExp (where
+// 0x40000000 is 1.0) and an exponent shift (between -1022 and +1022) into an
+// IEEE binary64 double format result. The implementation uses only integer and
+// bitwise operators, so no floating point hardware support or emulation is
+// needed. This is here so quantized operations can run non-time-critical
+// preparation calculations on microcontrollers and other platforms without
+// float support.
+double DoubleFromFractionAndShift(int64_t fraction, int shift);
+
+// Performs a multiplication of two numbers in double format, using only integer
+// and bitwise instructions. This is aimed at supporting housekeeping functions
+// for quantized operations on microcontrollers without floating-point hardware.
+double IntegerDoubleMultiply(double a, double b);
+
+// Returns -1 if a is less than b, 0 if a and b are equal, and +1 if a is
+// greater than b. It is implemented using only integer and logical instructions
+// so that it can be easily run on microcontrollers for quantized operations.
+int IntegerDoubleCompare(double a, double b);
+
// This first creates a multiplier in a double equivalent of
// Q(input_integer_bits).(31-input_integer_bits) representation, with extra
// precision in the double's fractional bits. It then splits the result into
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 00fc3e91dc..14281f25c6 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -191,6 +191,139 @@ TEST(QuantizationUtilTest, ChooseQuantizationParamsZeroPointOnMaxBoundary) {
EXPECT_EQ(qp.zero_point, 255);
}
+TEST(QuantizationUtilTest, IntegerFrExp) {
+ int shift;
+ int64_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(0, result);
+ EXPECT_EQ(0, shift);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(0x40000000, result, 1);
+ EXPECT_EQ(-1, shift);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(-(1 << 30), result, 1);
+ EXPECT_EQ(1, shift);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(2071147315, result, 1);
+ EXPECT_EQ(7, shift);
+
+ result = IntegerFrExp(NAN, &shift);
+ EXPECT_NEAR(0, result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::max(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+
+ result = IntegerFrExp(-INFINITY, &shift);
+ EXPECT_NEAR(std::numeric_limits<int64_t>::min(), result, 1);
+ EXPECT_EQ(0x7fffffff, shift);
+}
+
+TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
+ int shift;
+ int32_t result = IntegerFrExp(0.0, &shift);
+ EXPECT_EQ(result, 0);
+ EXPECT_EQ(shift, 0);
+
+ int double_shift;
+ double double_result = std::frexp(0.0, &double_shift);
+ EXPECT_EQ(double_result, 0);
+ EXPECT_EQ(double_shift, 0);
+
+ result = IntegerFrExp(1.0, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(1.0, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(0.25, &shift);
+ EXPECT_NEAR(result, 0x40000000, 1);
+ EXPECT_EQ(shift, -1);
+ double_result = std::frexp(0.25, &double_shift);
+ EXPECT_NEAR(double_result, 0.5, 1e-5);
+ EXPECT_EQ(double_shift, -1);
+
+ result = IntegerFrExp(-1.0, &shift);
+ EXPECT_NEAR(result, -(1 << 30), 1);
+ EXPECT_EQ(shift, 1);
+ double_result = std::frexp(-1.0, &double_shift);
+ EXPECT_NEAR(double_result, -0.5, 1e-5);
+ EXPECT_EQ(double_shift, 1);
+
+ result = IntegerFrExp(123.45, &shift);
+ EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000);
+ EXPECT_EQ(shift, 7);
+ double_result = std::frexp(123.45, &double_shift);
+ EXPECT_NEAR(double_result, 0.964453, 1e-5);
+ EXPECT_EQ(double_shift, 7);
+}
+
+TEST(QuantizationUtilTest, DoubleFromFractionAndShift) {
+ double result = DoubleFromFractionAndShift(0, 0);
+ EXPECT_EQ(0, result);
+
+ result = DoubleFromFractionAndShift(0x40000000, 1);
+ EXPECT_NEAR(1.0, result, 1e-5);
+
+ result = DoubleFromFractionAndShift(0x40000000, 2);
+ EXPECT_NEAR(2.0, result, 1e-5);
+
+ int shift;
+ int64_t fraction = IntegerFrExp(3.0, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(3.0, result, 1e-5);
+
+ fraction = IntegerFrExp(123.45, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(123.45, result, 1e-5);
+
+ fraction = IntegerFrExp(-23.232323, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_NEAR(-23.232323, result, 1e-5);
+
+ fraction = IntegerFrExp(NAN, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_TRUE(std::isnan(result));
+
+ fraction = IntegerFrExp(INFINITY, &shift);
+ result = DoubleFromFractionAndShift(fraction, shift);
+ EXPECT_FALSE(std::isfinite(result));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleMultiply) {
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(1.0, 1.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(1.0, 2.0), 1e-5);
+ EXPECT_NEAR(2.0, IntegerDoubleMultiply(2.0, 1.0), 1e-5);
+ EXPECT_NEAR(4.0, IntegerDoubleMultiply(2.0, 2.0), 1e-5);
+ EXPECT_NEAR(0.5, IntegerDoubleMultiply(1.0, 0.5), 1e-5);
+ EXPECT_NEAR(0.25, IntegerDoubleMultiply(0.5, 0.5), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(1.0, -1.0), 1e-5);
+ EXPECT_NEAR(-1.0, IntegerDoubleMultiply(-1.0, 1.0), 1e-5);
+ EXPECT_NEAR(1.0, IntegerDoubleMultiply(-1.0, -1.0), 1e-5);
+ EXPECT_NEAR(15000000.0, IntegerDoubleMultiply(3000.0, 5000.0), 1e-5);
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(NAN, 5000.0)));
+ EXPECT_TRUE(std::isnan(IntegerDoubleMultiply(3000.0, NAN)));
+}
+
+TEST(QuantizationUtilTest, IntegerDoubleCompare) {
+ EXPECT_EQ(-1, IntegerDoubleCompare(0.0, 1.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(1.0, 0.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(1.0, 1.0));
+ EXPECT_EQ(0, IntegerDoubleCompare(0.0, 0.0));
+ EXPECT_EQ(-1, IntegerDoubleCompare(-10.0, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(123.45, 10.0));
+ EXPECT_EQ(1, IntegerDoubleCompare(NAN, INFINITY));
+ EXPECT_EQ(1, IntegerDoubleCompare(INFINITY, NAN));
+}
+
#ifdef GTEST_HAS_DEATH_TEST
TEST(QuantizationUtilTest, ChooseQuantizationParamsInvalidRange) {
EXPECT_DEATH(ChooseQuantizationParams<uint8>(10.0, -30.0), "");
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index 71ae74f34c..683ccdc74d 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -27,6 +27,28 @@ namespace tflite {
namespace reference_ops {
template <FusedActivationFunctionType Ac>
+void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
+ float* output_data, const RuntimeShape& output_shape) {
+ static_assert(Ac == FusedActivationFunctionType::kNone, "");
+ tflite::L2NormalizationParams op_params;
+ // No params need to be set for float.
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+inline void L2Normalization(const uint8* input_data,
+ const RuntimeShape& input_shape,
+ int32 input_zero_point, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::L2NormalizationParams op_params;
+ op_params.input_zero_point = input_zero_point;
+
+ L2Normalization(op_params, input_shape, input_data, output_shape,
+ output_data);
+}
+
+template <FusedActivationFunctionType Ac>
void L2Normalization(const float* input_data, const Dims<4>& input_dims,
float* output_data, const Dims<4>& output_dims) {
L2Normalization<Ac>(input_data, DimsToShape(input_dims), output_data,
@@ -58,6 +80,15 @@ inline void Relu6(const float* input_data, const Dims<4>& input_dims,
output_data);
}
+inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
+ const RuntimeShape& input_shape, uint8* output_data,
+ const RuntimeShape& output_shape) {
+ tflite::ActivationParams params;
+ params.quantized_activation_max = max_value;
+ params.quantized_activation_min = min_value;
+ ReluX(params, input_shape, input_data, output_shape, output_data);
+}
+
template <FusedActivationFunctionType Ac>
inline void Add(int left_shift, const uint8* input1_data,
const Dims<4>& input1_dims, int32 input1_offset,
@@ -311,6 +342,30 @@ inline void AveragePool(const float* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims), output_data);
}
+// Legacy.
+// Transitional version that will be moved shortly to legacy_reference_ops, as
+// part of RuntimeShape revisions.
+inline void BroadcastMul4DSlow(const uint8* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ const uint8* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 output_offset, int32 output_multiplier,
+ int output_shift, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+ op_params.input1_offset = input1_offset;
+ op_params.input2_offset = input2_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = output_multiplier;
+ op_params.output_shift = output_shift;
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims,
int32 input1_offset, const uint8* input2_data,
const Dims<4>& input2_dims, int32 input2_offset,
@@ -624,6 +679,377 @@ inline void Tanh(const int16* input_data, const Dims<4>& input_dims,
DimsToShape(output_dims));
}
+template <typename T>
+inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DepthToSpaceParams op_params;
+ op_params.block_size = block_size;
+
+ DepthToSpace(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
+ int block_size, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToDepthParams op_params;
+ op_params.block_size = block_size;
+
+ SpaceToDepth(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac>
+void Mul(const float* input1_data, const Dims<4>& input1_dims,
+ const float* input2_data, const Dims<4>& input2_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ float output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+template <FusedActivationFunctionType Ac, typename T>
+void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ T output_activation_min, output_activation_max;
+ GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
+
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastMul4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int16* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ // No params in this version.
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
+ const int16* input2_data, const Dims<4>& input2_dims,
+ int32 output_offset, int32 output_activation_min,
+ int32 output_activation_max, uint8* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ op_params.quantized_activation_min = output_activation_min;
+ op_params.quantized_activation_max = output_activation_max;
+ op_params.output_offset = output_offset;
+
+ Mul(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void LocalResponseNormalization(const float* input_data,
+ const Dims<4>& input_dims, int range,
+ float bias, float alpha, float beta,
+ float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::LocalResponseNormalizationParams op_params;
+ op_params.range = range;
+ op_params.bias = bias;
+ op_params.alpha = alpha;
+ op_params.beta = beta;
+
+ LocalResponseNormalization(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename SrcT, typename DstT>
+void Cast(const SrcT* input_data, const Dims<4>& input_dims, DstT* output_data,
+ const Dims<4>& output_dims) {
+ Cast(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+inline void Floor(const float* input_data, const Dims<4>& input_dims,
+ float* output_data, const Dims<4>& output_dims) {
+ Floor(DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T>
+inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, T* output_data,
+ const Dims<4>& output_dims, bool align_corners) {
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = align_corners;
+ ResizeBilinear(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_size_dims), output_size_data,
+ DimsToShape(output_dims), output_data);
+}
+
+// legacy, for compatibility with old checked-in code
+inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, float* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<float>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
+ const int32* output_size_data,
+ const Dims<4>& output_size_dims, uint8* output_data,
+ const Dims<4>& output_dims) {
+ ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
+ output_size_dims, output_data, output_dims,
+ /*align_corners=*/false);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims,
+ const int32_t pad_value) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = pad_value;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* paddings_data,
+ const Dims<4>& paddings_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::SpaceToBatchParams op_params;
+ op_params.output_offset = 0;
+
+ SpaceToBatchND(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(paddings_dims), paddings_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
+ const int32* block_shape_data,
+ const Dims<4>& block_shape_dims,
+ const int32* crops_data, const Dims<4>& crops_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BatchToSpaceND(DimsToShape(input_dims), input_data,
+ DimsToShape(block_shape_dims), block_shape_data,
+ DimsToShape(crops_dims), crops_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Legacy signature, function covered both Pad and PadV2.
+template <typename T>
+inline void PadV2(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const T pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+ tflite::PadParams op_params;
+ op_params.left_padding_count = 4;
+ op_params.right_padding_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.left_padding[i] = left_paddings[3 - i];
+ op_params.right_padding[i] = right_paddings[3 - i];
+ }
+ // SetFloatOrInt(pad_value, &op_params.pad_value);
+ const T pad_value_copy = pad_value;
+
+ Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
+ DimsToShape(output_dims), output_data);
+}
+
+// Old Pad that calls legacy PadV2.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims, const int32_t pad_value) {
+ const T converted_pad_value = static_cast<T>(pad_value);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, converted_pad_value);
+}
+
+// Old Pad that only padded with 0.
+template <typename T>
+inline void Pad(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& left_paddings,
+ const std::vector<int>& right_paddings, T* output_data,
+ const Dims<4>& output_dims) {
+ const T pad_value = static_cast<T>(0);
+ PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
+ output_dims, pad_value);
+}
+
+template <typename T>
+void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Minimum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, T* output_data,
+ const Dims<4>& output_dims) {
+ Maximum(DimsToShape(input1_dims), input1_data, input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T, typename Op>
+void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims,
+ Op op) {
+ MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, op);
+}
+
+template <typename T1, typename T2, typename T3>
+void ArgMax(const T3* axis, const T1* input_data,
+ const tflite::Dims<4>& input_dims, T2* output_data,
+ const tflite::Dims<4>& output_dims) {
+ ArgMinMax(DimsToShape(input_dims), input_data, axis, DimsToShape(output_dims),
+ output_data, std::greater<T1>());
+}
+
+template <typename T1, typename T2, typename T3, typename Cmp>
+void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
+ T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
+ ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data, cmp);
+}
+
+template <typename T>
+inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T* output_data, const Dims<4>& output_dims) {
+ BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
+ const bool* input2_data, const Dims<4>& input2_dims,
+ bool* output_data, const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims), output_data, func);
+}
+
+inline void BroadcastLogical(const bool* input1_data,
+ const Dims<4>& input1_dims,
+ const bool* input2_data,
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims,
+ const std::function<bool(bool, bool)>& func) {
+ BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BroadcastBinaryFunction(const T1* input1_data,
+ const Dims<4>& input1_dims,
+ const T2* input2_data,
+ const Dims<4>& input2_dims, R* output_data,
+ const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BroadcastBinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+// R: Result type. T1: Input 1 type. T2: Input 2 type.
+template <typename R, typename T1, typename T2>
+inline void BinaryFunction(const T1* input1_data, const Dims<4>& input1_dims,
+ const T2* input2_data, const Dims<4>& input2_dims,
+ R* output_data, const Dims<4>& output_dims,
+ R (*func)(T1, T2)) {
+ BinaryFunction(DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data, func);
+}
+
+template <typename T>
+inline void Slice(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& begin, const std::vector<int>& size,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::SliceParams op_params;
+ op_params.begin_count = 4;
+ op_params.size_count = 4;
+ for (int i = 0; i < 4; ++i) {
+ op_params.begin[i] = begin[3 - i];
+ op_params.size[i] = size[3 - i];
+ }
+
+ Slice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
} // namespace reference_ops
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_LEGACY_REFERENCE_OPS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
index aa93e857d7..77e60adc18 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include <string.h>
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/round.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
@@ -151,6 +151,16 @@ void PortableVectorVectorCwiseProductAccumulate(const float* vector1,
}
}
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int v = 0; v < v_size; v++) {
+ *result++ = vector[v] * *batch_vector++;
+ }
+ }
+}
+
void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
int v_size,
const float* batch_vector,
@@ -163,6 +173,16 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
}
}
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ for (int b = 0; b < n_batch; b++) {
+ for (int i = 0; i < v_size; ++i) {
+ batch_vector[i] += vector[i];
+ }
+ batch_vector += v_size;
+ }
+}
+
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector) {
for (int b = 0; b < n_batch; b++) {
@@ -233,5 +253,31 @@ void PortableReductionSumVector(const float* input_vector, float* output_vector,
}
}
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon) {
+ for (int batch = 0; batch < n_batch; ++batch) {
+ float sum = 0.0f;
+ float sum_sq = 0.0f;
+ for (int i = 0; i < v_size; ++i) {
+ sum += input_vector[i];
+ sum_sq += input_vector[i] * input_vector[i];
+ }
+ const float mean = sum / v_size;
+ float stddev_inv = 0.0f;
+ const float variance = sum_sq / v_size - mean * mean;
+ if (variance == 0) {
+ stddev_inv = 1.0f / sqrt(normalization_epsilon);
+ } else {
+ stddev_inv = 1.0f / sqrt(variance);
+ }
+ for (int i = 0; i < v_size; ++i) {
+ output_vector[i] = (input_vector[i] - mean) * stddev_inv;
+ }
+ input_vector += v_size;
+ output_vector += v_size;
+ }
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
index a375aaffa6..714b1164ee 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h
@@ -17,7 +17,7 @@ limitations under the License.
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
@@ -69,6 +69,11 @@ void PortableBatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void PortableVectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector,
+ int n_batch, float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -82,6 +87,10 @@ void PortableVectorBatchVectorCwiseProductAccumulate(const float* vector,
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
int n_batch, float* batch_vector);
+// Add another vector for each batch in the batch vector.
+void PortableVectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Apply sigmoid to elements of a vector.
void PortableApplySigmoidToVector(const float* vector, int v_size,
float* result);
@@ -120,6 +129,12 @@ void PortableVectorShiftLeft(float* vector, int v_size, float shift_value);
void PortableReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+// Layer norm for each batch.
+// normalization_epsilon is added to avoid divergence.
+void PortableMeanStddevNormalization(const float* input_vector,
+ float* output_vector, int v_size,
+ int n_batch, float normalization_epsilon);
+
float Clip(float f, float abs_limit) { return PortableClip(f, abs_limit); }
bool IsZeroVector(const float* vector, int v_size) {
@@ -161,6 +176,13 @@ void VectorVectorCwiseProductAccumulate(const float* vector1,
PortableVectorVectorCwiseProductAccumulate(vector1, vector2, v_size, result);
}
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result) {
+ PortableVectorBatchVectorCwiseProduct(vector, v_size, batch_vector, n_batch,
+ result);
+}
+
void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result) {
@@ -181,6 +203,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
result, result_stride);
}
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector) {
+ PortableVectorBatchVectorAdd(vector, v_size, n_batch, batch_vector);
+}
+
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector) {
PortableVectorBatchVectorAssign(vector, v_size, n_batch, batch_vector);
@@ -228,6 +255,13 @@ void ReductionSumVector(const float* input_vector, float* output_vector,
reduction_size);
}
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon) {
+ PortableMeanStddevNormalization(input_vector, output_vector, v_size, n_batch,
+ normalization_epsilon);
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index b241ecbcf5..0abacf85e1 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -110,6 +110,11 @@ inline RuntimeShape DimsToShape(const tflite::Dims<4>& dims) {
{dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
}
+inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
+ shape->BuildFrom(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
@@ -407,18 +412,29 @@ void Conv(const uint8* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
+inline void DepthToSpace(const tflite::DepthToSpaceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width * block_size, output_width);
TFLITE_DCHECK_EQ(input_height * block_size, output_height);
@@ -437,9 +453,9 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
const int in_h = out_h / block_size;
const int in_b = out_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -449,18 +465,29 @@ inline void DepthToSpace(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
- int block_size, T* output_data,
- const Dims<4>& output_dims) {
- const int input_depth = ArraySize(input_dims, 0);
- const int input_width = ArraySize(input_dims, 1);
- const int input_height = ArraySize(input_dims, 2);
- const int input_batch = ArraySize(input_dims, 3);
+inline void SpaceToDepth(const tflite::SpaceToDepthParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int input_depth = input_shape.Dims(3);
+ const int input_width = input_shape.Dims(2);
+ const int input_height = input_shape.Dims(1);
+ const int input_batch = input_shape.Dims(0);
- const int output_depth = ArraySize(output_dims, 0);
- const int output_width = ArraySize(output_dims, 1);
- const int output_height = ArraySize(output_dims, 2);
- const int output_batch = ArraySize(output_dims, 3);
+ const int output_depth = output_shape.Dims(3);
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch = output_shape.Dims(0);
+
+ const int32 block_size = op_params.block_size;
TFLITE_DCHECK_EQ(input_width, output_width * block_size);
TFLITE_DCHECK_EQ(input_height, output_height * block_size);
@@ -478,9 +505,9 @@ inline void SpaceToDepth(const T* input_data, const Dims<4>& input_dims,
const int out_h = in_h / block_size;
const int out_b = in_b;
+ const int input_index = Offset(input_shape, in_b, in_h, in_w, in_d);
const int output_index =
- Offset(output_dims, out_d, out_w, out_h, out_b);
- const int input_index = Offset(input_dims, in_d, in_w, in_h, in_b);
+ Offset(output_shape, out_b, out_h, out_w, out_d);
output_data[output_index] = input_data[input_index];
}
@@ -803,49 +830,6 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
output_activation_max, output_data, output_dims, gemm_context);
}
-template <FusedActivationFunctionType Ac>
-void NonGlobalBatchNormalization(
- const float* input_data, const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims, const float* multiplier_data,
- const Dims<4>& multiplier_dims, const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- const int batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- const int inner_size = MatchingFlatSizeSkipDim(
- input_dims, 3, mean_dims, multiplier_dims, offset_dims, output_dims);
-
- for (int b = 0; b < batches; ++b) {
- for (int i = 0; i < inner_size; ++i) {
- output_data[b * inner_size + i] = ActivationFunction<Ac>(
- (input_data[b * inner_size + i] - mean_data[i]) * multiplier_data[i] +
- offset_data[i]);
- }
- }
-}
-
-template <FusedActivationFunctionType Ac>
-void GlobalBatchNormalization(const float* input_data,
- const Dims<4>& input_dims, const float* mean_data,
- const Dims<4>& mean_dims,
- const float* multiplier_data,
- const Dims<4>& multiplier_dims,
- const float* offset_data,
- const Dims<4>& offset_dims, float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth =
- MatchingArraySize(input_dims, 0, mean_dims, 0, multiplier_dims, 0,
- offset_dims, 0, output_dims, 0);
-
- for (int i = 0; i < outer_size; ++i) {
- for (int c = 0; c < depth; ++c) {
- output_data[depth * i + c] = ActivationFunction<Ac>(
- (input_data[depth * i + c] - mean_data[c]) * multiplier_data[c] +
- offset_data[c]);
- }
- }
-}
-
inline void Relu(const RuntimeShape& input_shape, const float* input_data,
const RuntimeShape& output_shape, float* output_data) {
const int flat_size = MatchingFlatSize(input_shape, output_shape);
@@ -883,11 +867,13 @@ inline void Relu6(const RuntimeShape& input_shape, const float* input_data,
}
}
-inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
- const RuntimeShape& input_shape, uint8* output_data,
- const RuntimeShape& output_shape) {
+inline void ReluX(const tflite::ActivationParams& params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Quantized ReluX (not fused)");
const int flat_size = MatchingFlatSize(input_shape, output_shape);
+ const uint8 max_value = params.quantized_activation_max;
+ const uint8 min_value = params.quantized_activation_min;
for (int i = 0; i < flat_size; ++i) {
const uint8 val = input_data[i];
const uint8 clamped =
@@ -896,10 +882,11 @@ inline void ReluX(uint8 min_value, uint8 max_value, const uint8* input_data,
}
}
-template <FusedActivationFunctionType Ac>
-void L2Normalization(const float* input_data, const RuntimeShape& input_shape,
- float* output_data, const RuntimeShape& output_shape) {
- static_assert(Ac == FusedActivationFunctionType::kNone, "");
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
+ const RuntimeShape& input_shape,
+ const float* input_data,
+ const RuntimeShape& output_shape,
+ float* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
@@ -966,15 +953,17 @@ inline void GetInvSqrtQuantizedMultiplierExp(int32 input,
*output_shift *= kReverseShift;
}
-inline void L2Normalization(const uint8* input_data,
+inline void L2Normalization(const tflite::L2NormalizationParams& op_params,
const RuntimeShape& input_shape,
- int32 input_zero_point, uint8* output_data,
- const RuntimeShape& output_shape) {
+ const uint8* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
const int trailing_dim = input_shape.DimensionsCount() - 1;
const int depth =
MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
const int outer_size =
MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int32 input_zero_point = op_params.input_zero_point;
for (int i = 0; i < outer_size; ++i) {
int32 square_l2_norm = 0;
for (int c = 0; c < depth; c++) {
@@ -1320,11 +1309,16 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
}
template <typename T>
-inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] * input2_data[i], output_activation_min,
@@ -1332,52 +1326,57 @@ inline void Mul(const T* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac>
-void Mul(const float* input1_data, const Dims<4>& input1_dims,
- const float* input2_data, const Dims<4>& input2_dims,
- float* output_data, const Dims<4>& output_dims) {
- float output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- Mul(input1_data, input1_dims, input2_data, input2_dims, output_activation_min,
- output_activation_max, output_data, output_dims);
-}
-
// TODO(jiawen): We can implement BroadcastMul on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
+// TODO(benoitjacob): BroadcastMul is intentionally duplicated from
+// reference_ops.h. Once an optimized version is implemented and NdArrayDesc<T>
+// is no longer referenced in this file, move NdArrayDesc<T> from types.h to
+// reference_ops.h.
template <typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul");
+void BroadcastMul4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastMul4DSlow");
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest
- // stride, typically 1 element).
+ // trailing dimension changing most rapidly (channels has the smallest stride,
+ // typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for
- // the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ // nesting loops such that the innermost loop has the smallest stride for the
+ // best cache behavior.
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] *
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] *
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -1385,19 +1384,6 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
}
}
-// legacy, for compatibility with old checked-in code
-template <FusedActivationFunctionType Ac, typename T>
-void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- T output_activation_min, output_activation_max;
- GetActivationMinMax(Ac, &output_activation_min, &output_activation_max);
-
- BroadcastMul(input1_data, input1_dims, input2_data, input2_dims,
- output_activation_min, output_activation_max, output_data,
- output_dims);
-}
-
// Element-wise mul that can often be used for inner loop of broadcast Mul as
// well as the non-broadcast Mul.
inline void MulElementwise(int size, const ArithmeticParams& params,
@@ -1526,62 +1512,14 @@ inline void BroadcastMul4DSlow(const ArithmeticParams& params,
}
}
-// Transitional version that will be moved shortly to legacy_reference_ops, as
-// part of RuntimeShape revisions.
-inline void BroadcastMul4DSlow(const uint8* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- const uint8* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 output_offset, int32 output_multiplier,
- int output_shift, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastMul/8bit");
-
- NdArrayDesc<4> desc1;
- NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
-
- // In Tensorflow, the dimensions are canonically named (batch_number, row,
- // col, channel), with extents (batches, height, width, depth), with the
- // trailing dimension changing most rapidly (channels has the smallest
- // stride, typically 1 element).
- //
- // In generated C code, we store arrays with the dimensions reversed. The
- // first dimension has smallest stride.
- //
- // We name our variables by their Tensorflow convention, but generate C code
- // nesting loops such that the innermost loop has the smallest stride for
- // the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
- const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
- const int32 unclamped_result =
- output_offset +
- MultiplyByQuantizedMultiplierSmallerThanOneExp(
- input1_val * input2_val, output_multiplier, output_shift);
- const int32 clamped_output =
- std::min(output_activation_max,
- std::max(output_activation_min, unclamped_result));
- output_data[Offset(output_dims, c, x, y, b)] =
- static_cast<uint8>(clamped_output);
- }
- }
- }
- }
-}
-
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int16* output_data, const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, int16* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16");
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -1593,15 +1531,18 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
}
}
-inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
- const int16* input2_data, const Dims<4>& input2_dims,
- int32 output_offset, int32 output_activation_min,
- int32 output_activation_max, uint8* output_data,
- const Dims<4>& output_dims) {
+inline void Mul(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const int16* input1_data,
+ const RuntimeShape& input2_shape, const int16* input2_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
gemmlowp::ScopedProfilingLabel label("Mul/Int16Uint8");
+ int32 output_offset = params.output_offset;
+ int32 output_activation_min = params.quantized_activation_min;
+ int32 output_activation_max = params.quantized_activation_max;
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int flat_size = MatchingFlatSize(output_dims, input1_dims, input2_dims);
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
// F0 uses 0 integer bits, range [-1, 1].
@@ -1624,15 +1565,27 @@ inline void Mul(const int16* input1_data, const Dims<4>& input1_dims,
// that handles broadcasting as the base case. The code generator would then
// generate max(D1, D2) nested for loops.
template <typename T>
-void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- gemmlowp::ScopedProfilingLabel label("BroadcastDiv");
+void BroadcastDiv4DSlow(const ArithmeticParams& params,
+ const RuntimeShape& unextended_input1_shape,
+ const T* input1_data,
+ const RuntimeShape& unextended_input2_shape,
+ const T* input2_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
@@ -1645,14 +1598,14 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for
// the best cache behavior.
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
ActivationFunctionWithMinMax(
- input1_data[SubscriptToIndex(desc1, c, x, y, b)] /
- input2_data[SubscriptToIndex(desc2, c, x, y, b)],
+ input1_data[SubscriptToIndex(desc1, b, y, x, c)] /
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)],
output_activation_min, output_activation_max);
}
}
@@ -1660,12 +1613,32 @@ void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <typename T>
-inline void Div(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T output_activation_min, T output_activation_max,
- T* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(input1_dims, input2_dims, output_dims);
+void BroadcastDiv(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ BroadcastDiv4DSlow(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename T>
+inline void Div(const ArithmeticParams& params,
+ const RuntimeShape& input1_shape, const T* input1_data,
+ const RuntimeShape& input2_shape, const T* input2_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ T output_activation_min;
+ T output_activation_max;
+ GetActivationParams(params, &output_activation_min, &output_activation_max);
+
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] / input2_data[i], output_activation_min,
@@ -1673,6 +1646,21 @@ inline void Div(const T* input1_data, const Dims<4>& input1_dims,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Div(const T* input1_data, const Dims<4>& input1_dims,
+ const T* input2_data, const Dims<4>& input2_dims,
+ T output_activation_min, T output_activation_max,
+ T* output_data, const Dims<4>& output_dims) {
+ tflite::ArithmeticParams op_params;
+ SetActivationParams(output_activation_min, output_activation_max, &op_params);
+
+ Div(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
inline void SubNonBroadcast(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
@@ -1968,32 +1956,43 @@ inline void SubWithActivation(const ArithmeticParams& params,
}
}
-template <FusedActivationFunctionType Ac, typename Scalar>
-void Concatenation(int concat_dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- int concat_size = 0;
+template <typename Scalar>
+inline void Concatenation(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const Scalar* const* input_data,
+ const RuntimeShape& output_shape,
+ Scalar* output_data) {
+ int axis = params.axis;
+ int inputs_count = params.inputs_count;
+ const int concat_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
- for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- // For now we don't have a model with a Concatenation with fused activation.
- TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
- int outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
+ }
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= output_shape.Dims(i);
}
+
Scalar* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
memcpy(output_ptr, input_data[i] + k * copy_size,
copy_size * sizeof(Scalar));
output_ptr += copy_size;
@@ -2001,61 +2000,78 @@ void Concatenation(int concat_dim, const Scalar* const* input_data,
}
}
-template <typename Scalar>
-void Pack(int dim, const Scalar* const* input_data,
- const Dims<4>* const* input_dims, int inputs_count,
- Scalar* output_data, const Dims<4>& output_dims) {
- TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
- int outer_size = 1;
- for (int i = dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
- }
- Scalar* output_ptr = output_data;
- const int copy_size = FlatSize(**input_dims) / outer_size;
- for (int k = 0; k < outer_size; k++) {
- for (int i = 0; i < inputs_count; ++i) {
- memcpy(output_ptr, input_data[i] + k * copy_size,
- copy_size * sizeof(Scalar));
- output_ptr += copy_size;
- }
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <FusedActivationFunctionType Ac, typename Scalar>
+inline void Concatenation(int concat_dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ // For now we don't have a model with a Concatenation with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
+
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
}
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.inputs_count = inputs_count;
+
+ Concatenation(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
}
// TODO(prabhumk): This is the same as the optimized implementation.
// TODO(prabhumk): The quantized implementation of concatentation isn't fully
// quantized as it takes scale as a floating point value. This should be fixed
// when optimizng this routine further.
-inline void Concatenation(int concat_dim, const uint8* const* input_data,
- const Dims<4>* const* input_dims,
- const int32* input_zeropoint,
- const float* input_scale, int inputs_count,
- uint8* output_data, const Dims<4>& output_dims,
- const int32 output_zeropoint,
- const float output_scale) {
+
+// template <>
+inline void ConcatenationWithScaling(const ConcatenationParams& params,
+ const RuntimeShape* const* input_shapes,
+ const uint8* const* input_data,
+ const RuntimeShape& output_shape,
+ uint8* output_data) {
+ int axis = params.axis;
+ const int32* input_zeropoint = params.input_zeropoint;
+ const float* input_scale = params.input_scale;
+ int inputs_count = params.inputs_count;
+ const int32 output_zeropoint = params.output_zeropoint;
+ const float output_scale = params.output_scale;
+
// The arguments input_zeropoint and input_scale are expected to be an array
// that have the quantization parameters for all the inputs to the concat
// operator.
TFLITE_DCHECK_GT(inputs_count, 1);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
int64_t concat_size = 0;
for (int i = 0; i < inputs_count; i++) {
+ TFLITE_DCHECK_EQ(input_shapes[i]->DimensionsCount(), 4);
for (int j = 0; j < 4; j++) {
- if (j != concat_dim) {
- MatchingArraySize(*input_dims[i], j, output_dims, j);
+ if (j != axis) {
+ MatchingDim(*input_shapes[i], j, output_shape, j);
}
}
- concat_size += ArraySize(*input_dims[i], concat_dim);
+ concat_size += input_shapes[i]->Dims(axis);
}
- TFLITE_DCHECK_EQ(concat_size, ArraySize(output_dims, concat_dim));
+ TFLITE_DCHECK_EQ(concat_size, output_shape.Dims(axis));
int64_t outer_size = 1;
- for (int i = concat_dim + 1; i < 4; i++) {
- outer_size *= output_dims.sizes[i];
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= output_shape.Dims(i);
+ }
+ // For all input arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < 4; ++i) {
+ base_inner_size *= output_shape.Dims(i);
}
const float inverse_output_scale = 1.f / output_scale;
uint8* output_ptr = output_data;
for (int k = 0; k < outer_size; k++) {
for (int i = 0; i < inputs_count; ++i) {
- const int copy_size =
- input_dims[i]->sizes[concat_dim] * input_dims[i]->strides[concat_dim];
+ const int copy_size = input_shapes[i]->Dims(axis) * base_inner_size;
const uint8* input_ptr = input_data[i] + k * copy_size;
if (input_zeropoint[i] == output_zeropoint &&
input_scale[i] == output_scale) {
@@ -2076,6 +2092,72 @@ inline void Concatenation(int concat_dim, const uint8* const* input_data,
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Concatenation(int concat_dim, const uint8* const* input_data,
+ const Dims<4>* const* input_dims,
+ const int32* input_zeropoint,
+ const float* input_scale, int inputs_count,
+ uint8* output_data, const Dims<4>& output_dims,
+ const int32 output_zeropoint,
+ const float output_scale) {
+ std::vector<RuntimeShape> input_shapes(inputs_count);
+ std::vector<const RuntimeShape*> input_shapes_indirect(inputs_count);
+ for (int i = 0; i < inputs_count; ++i) {
+ ShapeFromDims(*input_dims[i], &input_shapes[i]);
+ input_shapes_indirect[i] = &input_shapes[i];
+ }
+ tflite::ConcatenationParams op_params;
+ op_params.axis = 3 - concat_dim;
+ op_params.input_zeropoint = input_zeropoint;
+ op_params.input_scale = input_scale;
+ op_params.inputs_count = inputs_count;
+ op_params.output_zeropoint = output_zeropoint;
+ op_params.output_scale = output_scale;
+
+ ConcatenationWithScaling(op_params, input_shapes_indirect.data(), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+template <typename Scalar>
+void Pack(int dim, const Scalar* const* input_data,
+ const Dims<4>* const* input_dims, int inputs_count,
+ Scalar* output_data, const Dims<4>& output_dims) {
+ TFLITE_DCHECK(IsPackedWithoutStrides(output_dims));
+ int outer_size = 1;
+ for (int i = dim + 1; i < 4; i++) {
+ outer_size *= output_dims.sizes[i];
+ }
+ Scalar* output_ptr = output_data;
+ const int copy_size = FlatSize(**input_dims) / outer_size;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < inputs_count; ++i) {
+ memcpy(output_ptr, input_data[i] + k * copy_size,
+ copy_size * sizeof(Scalar));
+ output_ptr += copy_size;
+ }
+ }
+}
+
+template <typename Scalar>
+void Unpack(int axis, const Scalar* input_data, const Dims<4>& input_dims,
+ int dimensions, int outputs_count, Scalar* const* output_datas,
+ const Dims<4>& output_dims) {
+ int outer_size = 1;
+ for (int i = dimensions - axis; i < 4; i++) {
+ outer_size *= input_dims.sizes[i];
+ }
+
+ const int copy_size = FlatSize(input_dims) / outer_size / outputs_count;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ Scalar* output_ptr = output_datas[i] + copy_size * k;
+ int loc = k * outputs_count * copy_size + i * copy_size;
+ memcpy(output_ptr, input_data + loc, copy_size * sizeof(Scalar));
+ }
+ }
+}
+
template <typename Scalar>
void Pack(int dim, const Scalar* const* input_data,
const Dims<4>* const* input_dims, const int32* input_zeropoint,
@@ -2442,32 +2524,69 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims,
}
template <typename Scalar>
+void Split(const SplitParams& params, const RuntimeShape& input_shape,
+ const Scalar* input_data, const RuntimeShape* const* output_shapes,
+ Scalar* const* output_data) {
+ const int concat_dimensions = input_shape.DimensionsCount();
+ int axis = params.axis < 0 ? params.axis + concat_dimensions : params.axis;
+ int outputs_count = params.num_split;
+ TFLITE_DCHECK_LT(axis, concat_dimensions);
+
+ int64_t concat_size = 0;
+ for (int i = 0; i < outputs_count; i++) {
+ TFLITE_DCHECK_EQ(output_shapes[i]->DimensionsCount(), concat_dimensions);
+ for (int j = 0; j < concat_dimensions; j++) {
+ if (j != axis) {
+ MatchingDim(*output_shapes[i], j, input_shape, j);
+ }
+ }
+ concat_size += output_shapes[i]->Dims(axis);
+ }
+ TFLITE_DCHECK_EQ(concat_size, input_shape.Dims(axis));
+ int64_t outer_size = 1;
+ for (int i = 0; i < axis; ++i) {
+ outer_size *= input_shape.Dims(i);
+ }
+ // For all output arrays,
+ // FlatSize() = outer_size * Dims(axis) * base_inner_size;
+ int64_t base_inner_size = 1;
+ for (int i = axis + 1; i < concat_dimensions; ++i) {
+ base_inner_size *= input_shape.Dims(i);
+ }
+
+ const Scalar* input_ptr = input_data;
+ for (int k = 0; k < outer_size; k++) {
+ for (int i = 0; i < outputs_count; ++i) {
+ const int copy_size = output_shapes[i]->Dims(axis) * base_inner_size;
+ memcpy(output_data[i] + k * copy_size, input_ptr,
+ copy_size * sizeof(Scalar));
+ input_ptr += copy_size;
+ }
+ }
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int axis, int outputs_count, Scalar* const* output_data,
const Dims<4>* const* output_dims) {
- const int batches = ArraySize(*output_dims[0], 3);
- const int height = ArraySize(*output_dims[0], 2);
- const int width = ArraySize(*output_dims[0], 1);
- const int depth = ArraySize(*output_dims[0], 0);
-
- const int slice_size = ArraySize(*output_dims[0], axis);
-
+ std::vector<RuntimeShape> output_shapes(outputs_count);
+ std::vector<const RuntimeShape*> output_shapes_indirect(outputs_count);
for (int i = 0; i < outputs_count; ++i) {
- int offset = i * slice_size * input_dims.strides[axis];
- for (int b = 0; b < batches; ++b) {
- for (int y = 0; y < height; ++y) {
- for (int x = 0; x < width; ++x) {
- for (int c = 0; c < depth; ++c) {
- auto out = Offset(*output_dims[i], c, x, y, b);
- auto in = Offset(input_dims, c, x, y, b);
- output_data[i][out] = input_data[offset + in];
- }
- }
- }
- }
+ ShapeFromDims(*output_dims[i], &output_shapes[i]);
+ output_shapes_indirect[i] = &output_shapes[i];
}
+ tflite::SplitParams op_params;
+ op_params.axis = 3 - axis;
+ op_params.num_split = outputs_count;
+
+ Split(op_params, DimsToShape(input_dims), input_data,
+ output_shapes_indirect.data(), output_data);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
template <FusedActivationFunctionType Ac, typename Scalar>
void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
int outputs_count, Scalar* const* output_data,
@@ -2478,9 +2597,8 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims,
/* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2);
/* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1);
}
- // for now we dont have a model with a TensorFlowSplit
- // with fused activation function.
- TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone);
+ // For now we don't have a model with a Split with fused activation.
+ TFLITE_DCHECK_EQ(Ac, FusedActivationFunctionType::kNone);
TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count,
output_data, output_dims);
@@ -2758,24 +2876,27 @@ inline void MaxPool(const PoolParams& params, const RuntimeShape& input_shape,
}
}
-inline void LocalResponseNormalization(const float* input_data,
- const Dims<4>& input_dims, int range,
- float bias, float alpha, float beta,
- float* output_data,
- const Dims<4>& output_dims) {
- const int outer_size = MatchingFlatSizeSkipDim(input_dims, 0, output_dims);
- const int depth = MatchingArraySize(input_dims, 0, output_dims, 0);
+inline void LocalResponseNormalization(
+ const tflite::LocalResponseNormalizationParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int trailing_dim = input_shape.DimensionsCount() - 1;
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
+ const int depth =
+ MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
for (int i = 0; i < outer_size; ++i) {
for (int c = 0; c < depth; ++c) {
- const int begin_input_c = std::max(0, c - range);
- const int end_input_c = std::min(depth, c + range);
+ const int begin_input_c = std::max(0, c - op_params.range);
+ const int end_input_c = std::min(depth, c + op_params.range);
float accum = 0.f;
for (int input_c = begin_input_c; input_c < end_input_c; ++input_c) {
const float input_val = input_data[i * depth + input_c];
accum += input_val * input_val;
}
- const float multiplier = std::pow(bias + alpha * accum, -beta);
+ const float multiplier =
+ std::pow(op_params.bias + op_params.alpha * accum, -op_params.beta);
output_data[i * depth + c] = input_data[i * depth + c] * multiplier;
}
}
@@ -3277,10 +3398,12 @@ inline void Tanh(const int16* input_data, const RuntimeShape& input_shape,
}
}
-inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
- int32 zero_point, double scale, float* output_data,
- const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Dequantize(const tflite::DequantizationParams& op_params,
+ const RuntimeShape& input_shape, const uint8* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ int32 zero_point = op_params.zero_point;
+ double scale = op_params.scale;
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int32 val = input_data[i];
@@ -3289,9 +3412,25 @@ inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
}
}
-inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
- float rmin, float rmax, int num_bits, float* output_data,
- const Dims<4>& output_dims) {
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims,
+ int32 zero_point, double scale, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::DequantizationParams op_params;
+ op_params.zero_point = zero_point;
+ op_params.scale = scale;
+
+ Dequantize(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
+inline void FakeQuant(const tflite::FakeQuantParams& op_params,
+ const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ float rmin = op_params.minmax.min;
+ float rmax = op_params.minmax.max;
+ int num_bits = op_params.num_bits;
// 0 should always be a representable value. Let's assume that the initial
// min,max range contains 0.
TFLITE_DCHECK_LE(rmin, 0.0f);
@@ -3304,15 +3443,29 @@ inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
float nudged_min, nudged_max, nudged_scale;
NudgeQuantizationRange(rmin, rmax, quant_min, quant_max, &nudged_min,
&nudged_max, &nudged_scale);
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
FakeQuantizeArray(nudged_scale, nudged_min, nudged_max, input_data,
output_data, flat_size);
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+inline void FakeQuant(const float* input_data, const Dims<4>& input_dims,
+ float rmin, float rmax, int num_bits, float* output_data,
+ const Dims<4>& output_dims) {
+ tflite::FakeQuantParams op_params;
+ op_params.num_bits = num_bits;
+ op_params.minmax.min = rmin;
+ op_params.minmax.max = rmax;
+
+ FakeQuant(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename SrcT, typename DstT>
-inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
- DstT* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Cast(const RuntimeShape& input_shape, const SrcT* input_data,
+ const RuntimeShape& output_shape, DstT* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int offset = i;
@@ -3320,9 +3473,9 @@ inline void Cast(const SrcT* input_data, const Dims<4>& input_dims,
}
}
-inline void Floor(const float* input_data, const Dims<4>& input_dims,
- float* output_data, const Dims<4>& output_dims) {
- const int flat_size = MatchingFlatSize(output_dims, input_dims);
+inline void Floor(const RuntimeShape& input_shape, const float* input_data,
+ const RuntimeShape& output_shape, float* output_data) {
+ const int flat_size = MatchingFlatSize(input_shape, output_shape);
for (int i = 0; i < flat_size; i++) {
int offset = i;
@@ -3331,45 +3484,90 @@ inline void Floor(const float* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void Gather(const T* input_data, const Dims<4>& input_dims,
- int input_rank, const int32* coords_data,
- const Dims<4>& coords_dims, T* output_data,
- const Dims<4>& output_dims) {
- TFLITE_DCHECK(coords_dims.sizes[0] == output_dims.sizes[input_rank - 1]);
- int stride = input_dims.strides[input_rank - 1];
+inline void Gather(const tflite::GatherParams& op_params,
+ const RuntimeShape& input_shape, const T* input_data,
+ const RuntimeShape& coords_shape, const int32* coords_data,
+ const RuntimeShape& output_shape, T* output_data) {
+ // Enable these checks when moving legacy ops to legacy_reference_ops.
+ //
+ // TFLITE_DCHECK_EQ(coords_shape.DimensionsCount(), 1);
+ const int input_rank = op_params.input_rank;
+ const int gather_dimensions = output_shape.DimensionsCount();
+ TFLITE_DCHECK_LE(input_shape.DimensionsCount(), gather_dimensions);
+ const int axis = gather_dimensions - input_rank;
+ TFLITE_DCHECK_LT(axis, gather_dimensions);
+ TFLITE_DCHECK_GE(axis, 0);
+ const int coords_count = coords_shape.FlatSize();
+ TFLITE_DCHECK_EQ(coords_count, output_shape.Dims(axis));
+
+ int64_t stride = 1;
+ for (int i = axis + 1; i < gather_dimensions; ++i) {
+ stride *= input_shape.Dims(i);
+ }
T* out = output_data;
- for (int i = 0; i < coords_dims.sizes[0]; i++) {
+ for (int i = 0; i < coords_count; ++i) {
TFLITE_DCHECK_GE(coords_data[i], 0);
- TFLITE_DCHECK_LT(coords_data[i], input_dims.sizes[input_rank - 1]);
+ TFLITE_DCHECK_LT(coords_data[i], input_shape.Dims(axis));
const T* in = input_data + coords_data[i] * stride;
memcpy(out, in, sizeof(T) * stride);
out += stride;
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4> version.
+// When moving legacy ops to legacy_reference_ops, replace content with looser
+// implementation.
+template <typename T>
+inline void Gather(const T* input_data, const Dims<4>& input_dims,
+ int input_rank, const int32* coords_data,
+ const Dims<4>& coords_dims, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::GatherParams op_params;
+ op_params.input_rank = input_rank;
+
+ Gather(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(coords_dims), coords_data, DimsToShape(output_dims),
+ output_data);
+}
+
template <typename T>
-inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
+inline void ResizeBilinear(const tflite::ResizeBilinearParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_size_shape,
const int32* output_size_data,
- const Dims<4>& output_size_dims, T* output_data,
- const Dims<4>& output_dims, bool align_corners) {
- int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
- int32 input_height = ArraySize(input_dims, 2);
- int32 input_width = ArraySize(input_dims, 1);
- int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
-
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
- TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
- int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
- int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_size_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_size_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_size_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ int32 batches = MatchingDim(input_shape, 0, output_shape, 0);
+ int32 input_height = input_shape.Dims(1);
+ int32 input_width = input_shape.Dims(2);
+ int32 depth = MatchingDim(input_shape, 3, output_shape, 3);
+
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(0), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(1), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(2), 1);
+ TFLITE_DCHECK_EQ(output_size_shape.Dims(3), 2);
+ int32 output_height = output_size_data[Offset(output_size_shape, 0, 0, 0, 0)];
+ int32 output_width = output_size_data[Offset(output_size_shape, 0, 0, 0, 1)];
+
float height_scale = static_cast<float>(input_height) / output_height;
float width_scale = static_cast<float>(input_width) / output_width;
- if (align_corners && output_height > 1) {
+ if (op_params.align_corners && output_height > 1) {
height_scale = static_cast<float>(input_height - 1) / (output_height - 1);
}
- if (align_corners && output_width > 1) {
+ if (op_params.align_corners && output_width > 1) {
width_scale = static_cast<float>(input_width - 1) / (output_width - 1);
}
@@ -3384,80 +3582,72 @@ inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
int32 x1 = std::min(x0 + 1, input_width - 1);
for (int c = 0; c < depth; ++c) {
T interpolation =
- static_cast<T>(input_data[Offset(input_dims, c, x0, y0, b)] *
+ static_cast<T>(input_data[Offset(input_shape, b, y0, x0, c)] *
(1 - (input_y - y0)) * (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x0, y1, b)] *
+ input_data[Offset(input_shape, b, y1, x0, c)] *
(input_y - y0) * (1 - (input_x - x0)) +
- input_data[Offset(input_dims, c, x1, y0, b)] *
+ input_data[Offset(input_shape, b, y0, x1, c)] *
(1 - (input_y - y0)) * (input_x - x0) +
- input_data[Offset(input_dims, c, x1, y1, b)] *
+ input_data[Offset(input_shape, b, y1, x1, c)] *
(input_y - y0) * (input_x - x0));
- output_data[Offset(output_dims, c, x, y, b)] = interpolation;
+ output_data[Offset(output_shape, b, y, x, c)] = interpolation;
}
}
}
}
}
-// legacy, for compatibility with old checked-in code
-inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, float* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear<float>(input_data, input_dims, output_size_data,
- output_size_dims, output_data, output_dims,
- /*align_corners=*/false);
-}
+template <typename T>
+inline void SpaceToBatchND(
+ const SpaceToBatchParams& params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* paddings_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
-inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
- const int32* output_size_data,
- const Dims<4>& output_size_dims, uint8* output_data,
- const Dims<4>& output_dims) {
- ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
- output_size_dims, output_data, output_dims,
- /*align_corners=*/false);
-}
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
-template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims,
- const int32_t pad_value) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
const int block_shape_height = block_shape_data[0];
const int block_shape_width = block_shape_data[1];
const int padding_top = paddings_data[0];
const int padding_left = paddings_data[2];
+ // For uint8 quantized, the correct padding "zero value" is the output offset.
+ const int32_t pad_value = params.output_offset;
+
for (int out_b = 0; out_b < output_batch_size; ++out_b) {
int input_batch = out_b % input_batch_size;
int shift_w = (out_b / input_batch_size) % block_shape_width;
int shift_h = (out_b / input_batch_size) / block_shape_width;
for (int out_h = 0; out_h < output_height; ++out_h) {
for (int out_w = 0; out_w < output_width; ++out_w) {
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b);
+ T* out = output_data + Offset(output_shape, out_b, out_h, out_w, 0);
if (out_h * block_shape_height + shift_h < padding_top ||
out_h * block_shape_height + shift_h >=
padding_top + input_height ||
out_w * block_shape_width + shift_w < padding_left ||
out_w * block_shape_width + shift_w >= padding_left + input_width) {
+ // This may not execute correctly when pad_value != 0 and T != uint8.
memset(out, pad_value, depth * sizeof(T));
} else {
const T* in =
- input_data +
- Offset(input_dims, 0,
- (out_w * block_shape_width + shift_w) - padding_left,
+ input1_data +
+ Offset(input1_shape, input_batch,
(out_h * block_shape_height + shift_h) - padding_top,
- input_batch);
+ (out_w * block_shape_width + shift_w) - padding_left, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -3466,29 +3656,27 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
}
template <typename T>
-inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* paddings_data,
- const Dims<4>& paddings_dims, T* output_data,
- const Dims<4>& output_dims) {
- SpaceToBatchND(input_data, input_dims, block_shape_data, block_shape_dims,
- paddings_data, paddings_dims, output_data, output_dims, 0);
-}
+inline void BatchToSpaceND(
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const int32* block_shape_data,
+ const RuntimeShape& unextended_input3_shape, const int32* crops_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input1_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input1_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_width = output_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_batch_size = output_shape.Dims(0);
+
+ const int depth = input1_shape.Dims(3);
+ const int input_width = input1_shape.Dims(2);
+ const int input_height = input1_shape.Dims(1);
+ const int input_batch_size = input1_shape.Dims(0);
-template <typename T>
-inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
- const int32* block_shape_data,
- const Dims<4>& block_shape_dims,
- const int32* crops_data, const Dims<4>& crops_dims,
- T* output_data, const Dims<4>& output_dims) {
- const int output_batch_size = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int input_batch_size = ArraySize(input_dims, 3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
- const int depth = ArraySize(input_dims, 0);
const int block_shape_width = block_shape_data[1];
const int block_shape_height = block_shape_data[0];
const int crops_top = crops_data[0];
@@ -3510,8 +3698,9 @@ inline void BatchToSpaceND(const T* input_data, const Dims<4>& input_dims,
if (out_w < 0 || out_w >= output_width) {
continue;
}
- T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_batch);
- const T* in = input_data + Offset(input_dims, 0, in_w, in_h, in_batch);
+ T* out = output_data + Offset(output_shape, out_batch, out_h, out_w, 0);
+ const T* in =
+ input1_data + Offset(input1_shape, in_batch, in_h, in_w, 0);
memcpy(out, in, depth * sizeof(T));
}
}
@@ -3617,103 +3806,111 @@ inline void Pad(const tflite::PadParams& op_params,
output_data);
}
-// Legacy signature, function covered both Pad and PadV2.
-template <typename T>
-inline void PadV2(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const T pad_value) {
- TFLITE_DCHECK_EQ(left_paddings.size(), 4);
- TFLITE_DCHECK_EQ(right_paddings.size(), 4);
- tflite::PadParams op_params;
- op_params.left_padding_count = 4;
- op_params.right_padding_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.left_padding[i] = left_paddings[3 - i];
- op_params.right_padding[i] = right_paddings[3 - i];
- }
- // SetFloatOrInt(pad_value, &op_params.pad_value);
- const T pad_value_copy = pad_value;
-
- Pad(op_params, DimsToShape(input_dims), input_data, &pad_value_copy,
- DimsToShape(output_dims), output_data);
-}
-
-// Old Pad that calls legacy PadV2.
template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims, const int32_t pad_value) {
- const T converted_pad_value = static_cast<T>(pad_value);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, converted_pad_value);
-}
-
-// Old Pad that only padded with 0.
-template <typename T>
-inline void Pad(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& left_paddings,
- const std::vector<int>& right_paddings, T* output_data,
- const Dims<4>& output_dims) {
- const T pad_value = static_cast<T>(0);
- PadV2<T>(input_data, input_dims, left_paddings, right_paddings, output_data,
- output_dims, pad_value);
-}
-
-template <typename T>
-inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
- int begin_mask, int end_mask, int shrink_axis_mask,
- const std::vector<int>& start_indices,
- const std::vector<int>& stop_indices,
- const std::vector<int>& strides, T* output_data,
- const Dims<4>& output_dims) {
- // Note that the axis orders are reversed for runtime ops, so the indices,
- // strides and masks must be as well too.
- TFLITE_DCHECK_EQ(start_indices.size(), 4);
- TFLITE_DCHECK_EQ(stop_indices.size(), 4);
- TFLITE_DCHECK_EQ(strides.size(), 4);
- const int start_b = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 3);
+inline void StridedSlice(const tflite::StridedSliceParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape,
+ T* output_data) {
+ // Note that the output_shape is not used herein.
+ tflite::StridedSliceParams params_copy = op_params;
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ // Reverse and pad to 4 dimensions because that is what the runtime code
+ // requires (ie. all shapes must be 4D and are given backwards).
+ strided_slice::StridedSlicePadIndices(&params_copy, 4);
+
+ const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
const int stop_b =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 3, start_b);
- const int start_h = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 2);
+ strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
+ const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
const int stop_h =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 2, start_h);
- const int start_w = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 1);
+ strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
+ const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
const int stop_w =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 1, start_w);
- const int start_d = strided_slice::StartForAxis(begin_mask, start_indices,
- strides, input_dims.sizes, 0);
+ strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
+ const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
const int stop_d =
- strided_slice::StopForAxis(end_mask, shrink_axis_mask, stop_indices,
- strides, input_dims.sizes, 0, start_d);
+ strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
T* out_ptr = output_data;
for (int in_b = start_b;
- !strided_slice::LoopCondition(in_b, stop_b, strides[3]);
- in_b += strides[3]) {
+ !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
+ in_b += params_copy.strides[0]) {
for (int in_h = start_h;
- !strided_slice::LoopCondition(in_h, stop_h, strides[2]);
- in_h += strides[2]) {
+ !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
+ in_h += params_copy.strides[1]) {
for (int in_w = start_w;
- !strided_slice::LoopCondition(in_w, stop_w, strides[1]);
- in_w += strides[1]) {
- for (int in_d = start_d;
- !strided_slice::LoopCondition(in_d, stop_d, strides[0]);
- in_d += strides[0]) {
- *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)];
+ !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
+ in_w += params_copy.strides[2]) {
+ for (int in_d = start_d; !strided_slice::LoopCondition(
+ in_d, stop_d, params_copy.strides[3]);
+ in_d += params_copy.strides[3]) {
+ *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
}
}
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+inline uint32 LegacyReverseBits32(uint32 n) {
+ n = ((n >> 1) & 0x55555555) | ((n & 0x55555555) << 1);
+ n = ((n >> 2) & 0x33333333) | ((n & 0x33333333) << 2);
+ n = ((n >> 4) & 0x0F0F0F0F) | ((n & 0x0F0F0F0F) << 4);
+ return (((n & 0xFF) << 24) | ((n & 0xFF00) << 8) | ((n & 0xFF0000) >> 8) |
+ ((n & 0xFF000000) >> 24));
+}
+
+inline void StridedSliceReverseIndices(tflite::StridedSliceParams* p) {
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ std::reverse(p->start_indices, p->start_indices + p->start_indices_count);
+ std::reverse(p->stop_indices, p->stop_indices + p->stop_indices_count);
+ std::reverse(p->strides, p->strides + p->strides_count);
+
+ p->begin_mask = LegacyReverseBits32(static_cast<uint32>(p->begin_mask)) >>
+ (32 - p->start_indices_count);
+ p->ellipsis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->ellipsis_mask)) >>
+ (32 - p->start_indices_count);
+ p->end_mask = LegacyReverseBits32(static_cast<uint32>(p->end_mask)) >>
+ (32 - p->start_indices_count);
+ p->new_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->new_axis_mask)) >>
+ (32 - p->start_indices_count);
+ p->shrink_axis_mask =
+ LegacyReverseBits32(static_cast<uint32>(p->shrink_axis_mask)) >>
+ (32 - p->start_indices_count);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T>
+inline void StridedSlice(const T* input_data, const Dims<4>& input_dims,
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices,
+ const std::vector<int>& stop_indices,
+ const std::vector<int>& strides, T* output_data,
+ const Dims<4>& output_dims) {
+ TFLITE_DCHECK_EQ(start_indices.size(), 4);
+ auto op_params = strided_slice::BuildStridedSliceParams(
+ begin_mask, end_mask, shrink_axis_mask, start_indices, stop_indices,
+ strides);
+ StridedSliceReverseIndices(&op_params);
+
+ StridedSlice(op_params, DimsToShape(input_dims), input_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T>
inline void Slice(const tflite::SliceParams& op_params,
const RuntimeShape& input_shape, const T* input_data,
@@ -3755,22 +3952,6 @@ inline void Slice(const tflite::SliceParams& op_params,
}
template <typename T>
-inline void Slice(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& begin, const std::vector<int>& size,
- T* output_data, const Dims<4>& output_dims) {
- tflite::SliceParams op_params;
- op_params.begin_count = 4;
- op_params.size_count = 4;
- for (int i = 0; i < 4; ++i) {
- op_params.begin[i] = begin[3 - i];
- op_params.size[i] = size[3 - i];
- }
-
- Slice(op_params, DimsToShape(input_dims), input_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
inline void Exp(const T* input_data, const size_t num_elements,
T* output_data) {
for (size_t idx = 0; idx < num_elements; ++idx) {
@@ -3894,74 +4075,6 @@ inline bool ReduceGeneric(const T* input_data, const int* input_dims,
temp_index, reducer, output_data);
}
-// Computes the sum of elements across dimensions given in axis.
-template <typename T>
-inline bool Sum(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int num_axis_dimensions, bool keep_dims,
- int* temp_index, int* resolved_axis) {
- T init_value = static_cast<T>(0);
-
- auto reducer = [](const T current, const T in) -> T { return current + in; };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the max of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceMax(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = std::numeric_limits<T>::lowest();
-
- auto reducer = [](const T current, const T in) -> T {
- return (in > current) ? in : current;
- };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the min of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceMin(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = std::numeric_limits<T>::max();
-
- auto reducer = [](const T current, const T in) -> T {
- return (in < current) ? in : current;
- };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
-// Computes the prod of elements across dimensions given in axis.
-template <typename T>
-inline bool ReduceProd(const T* input_data, const int* input_dims,
- const int input_num_dims, T* output_data,
- const int* output_dims, const int output_num_dims,
- const int* axis, const int64_t num_axis_dimensions,
- bool keep_dims, int* temp_index, int* resolved_axis) {
- T init_value = static_cast<T>(1);
-
- auto reducer = [](const T current, const T in) -> T { return in * current; };
- return ReduceGeneric<T>(input_data, input_dims, input_num_dims, output_data,
- output_dims, output_num_dims, axis,
- num_axis_dimensions, keep_dims, temp_index,
- resolved_axis, init_value, reducer);
-}
-
// Computes the mean of elements across dimensions given in axis.
// It does so in two stages, first calculates the sum of elements along the axis
// then divides it by the number of element in axis.
@@ -4020,22 +4133,32 @@ inline bool Mean(const T* input_data, const int* input_dims,
}
template <typename T>
-inline void Mean(const T* input_data, const Dims<4>& input_dims,
- const std::vector<int>& reduction_indices, T* output_data,
- const Dims<4>& output_dims) {
- const int output_batch = ArraySize(output_dims, 3);
- const int output_height = ArraySize(output_dims, 2);
- const int output_width = ArraySize(output_dims, 1);
- const int output_depth = ArraySize(output_dims, 0);
+inline void Mean(const tflite::MeanParams& op_params,
+ const RuntimeShape& unextended_input_shape,
+ const T* input_data,
+ const RuntimeShape& unextended_output_shape, T* output_data) {
+ gemmlowp::ScopedProfilingLabel label("Mean");
+
+ TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape input_shape =
+ RuntimeShape::ExtendedShape(4, unextended_input_shape);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
+ const int output_batch = output_shape.Dims(0);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ const int output_depth = output_shape.Dims(3);
- const int input_height = ArraySize(input_dims, 2);
- const int input_width = ArraySize(input_dims, 1);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
// The current implementation only supports simultaneous reduction over
// width and height.
- TFLITE_DCHECK_EQ(reduction_indices.size(), 2);
- TFLITE_DCHECK((reduction_indices[0] == 1 && reduction_indices[1] == 2) ||
- (reduction_indices[0] == 2 && reduction_indices[1] == 1));
+ TFLITE_DCHECK_EQ(op_params.axis_count, 2);
+ TFLITE_DCHECK((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
+ (op_params.axis[0] == 2 && op_params.axis[1] == 1));
TFLITE_DCHECK_EQ(output_height, 1);
TFLITE_DCHECK_EQ(output_width, 1);
@@ -4044,15 +4167,95 @@ inline void Mean(const T* input_data, const Dims<4>& input_dims,
float value = 0;
for (int in_h = 0; in_h < input_height; ++in_h) {
for (int in_w = 0; in_w < input_width; ++in_w) {
- value += input_data[Offset(input_dims, out_d, in_w, in_h, out_b)];
+ value += input_data[Offset(input_shape, out_b, in_h, in_w, out_d)];
}
}
- output_data[Offset(output_dims, out_d, 0, 0, out_b)] =
+ output_data[Offset(output_shape, out_b, 0, 0, out_d)] =
value / (input_width * input_height);
}
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy Dims<4>.
+template <typename T>
+inline void Mean(const T* input_data, const Dims<4>& input_dims,
+ const std::vector<int>& reduction_indices, T* output_data,
+ const Dims<4>& output_dims) {
+ tflite::MeanParams op_params;
+ op_params.axis_count = reduction_indices.size();
+ for (int i = 0; i < op_params.axis_count; ++i) {
+ op_params.axis[i] = reduction_indices[op_params.axis_count - 1 - i];
+ }
+
+ Mean(op_params, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
+ output_data);
+}
+
+// Computes the mean of elements across dimensions given in axis.
+// It does so in two stages, first calculates the sum of elements along the axis
+// then divides it by the number of element in axis for quantized values.
+template <typename T, typename U>
+inline bool Mean(const T* input_data, int32 input_zero_point, float input_scale,
+ const int* input_dims, const int input_num_dims,
+ T* output_data, int32 output_zero_point, float output_scale,
+ const int* output_dims, const int output_num_dims,
+ const int* axis, const int num_axis_dimensions, bool keep_dims,
+ int* temp_index, int* resolved_axis, U* temp_sum) {
+ // Reset output data.
+ size_t num_outputs = 1;
+ for (int idx = 0; idx < output_num_dims; ++idx) {
+ size_t current = static_cast<size_t>(output_dims[idx]);
+ // Overflow prevention.
+ if (num_outputs > std::numeric_limits<size_t>::max() / current) {
+ return false;
+ }
+ num_outputs *= current;
+ }
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ output_data[idx] = T();
+ temp_sum[idx] = U();
+ }
+
+ // Resolve axis.
+ int num_resolved_axis = 0;
+ if (!ResolveAxis(input_num_dims, axis, num_axis_dimensions, resolved_axis,
+ &num_resolved_axis)) {
+ return false;
+ }
+
+ if (!ReduceSumImpl<T, U>(input_data, input_dims, output_dims, input_num_dims,
+ output_num_dims, resolved_axis, num_resolved_axis,
+ temp_index, temp_sum)) {
+ return false;
+ }
+
+ // Calculate mean by dividing output_data by num of aggregated element.
+ U num_elements_in_axis = 1;
+ for (int idx = 0; idx < num_resolved_axis; ++idx) {
+ size_t current = static_cast<size_t>(input_dims[resolved_axis[idx]]);
+ // Overflow prevention.
+ if (current > (std::numeric_limits<U>::max() / num_elements_in_axis)) {
+ return false;
+ }
+ num_elements_in_axis *= current;
+ }
+
+ if (num_elements_in_axis > 0) {
+ const float scale = input_scale / output_scale;
+ const float bias = -input_zero_point * scale;
+ for (size_t idx = 0; idx < num_outputs; ++idx) {
+ float float_mean = static_cast<float>(temp_sum[idx]) /
+ static_cast<float>(num_elements_in_axis);
+
+ // Convert to float value.
+ output_data[idx] =
+ static_cast<T>(round(float_mean * scale + bias)) + output_zero_point;
+ }
+ }
+ return true;
+}
+
template <typename T>
void Minimum(const RuntimeShape& input1_shape, const T* input1_data,
const T* input2_data, const RuntimeShape& output_shape,
@@ -4077,33 +4280,23 @@ void Maximum(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-template <typename T>
-void TensorFlowMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Minimum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
-template <typename T>
-void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, T* output_data,
- const Dims<4>& output_dims) {
- Maximum(DimsToShape(input1_dims), input1_data, input2_data,
- DimsToShape(output_dims), output_data);
-}
-
template <typename T, typename Op>
-void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
+void MaximumMinimumBroadcast4DSlow(const RuntimeShape& unextended_input1_shape,
const T* input1_data,
- const RuntimeShape& input2_shape,
+ const RuntimeShape& unextended_input2_shape,
const T* input2_data,
- const RuntimeShape& output_shape,
+ const RuntimeShape& unextended_output_shape,
T* output_data, Op op) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4121,19 +4314,9 @@ void MaximumMinimumBroadcast4DSlow(const RuntimeShape& input1_shape,
}
}
-template <typename T, typename Op>
-void TensorFlowMaximumMinimum(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims,
- Op op) {
- MaximumMinimumBroadcast4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, op);
-}
-
template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
- const T1* input_data, const RuntimeShape& output_shape,
+void ArgMinMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
T2* output_data, const Cmp& cmp) {
// The current ArgMax implemention can only determine the index of the maximum
// value in the last dimension. So the axis argument is ignored.
@@ -4142,17 +4325,19 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
// 1). For the sake of simplicity, the output dimensions are equal to the
// input dimensions here. We enforce the constraint that the last dimension
// must always be 1.
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.Dims(3), 1);
- const int outer_size = MatchingFlatSizeSkipDim(input_shape, 3, output_shape);
- const int depth = input_shape.Dims(3);
+ const int trailing_dim = output_shape.DimensionsCount() - 1;
+ TFLITE_DCHECK_EQ(input1_shape.DimensionsCount(),
+ output_shape.DimensionsCount());
+ TFLITE_DCHECK_EQ(output_shape.Dims(trailing_dim), 1);
+ const int outer_size =
+ MatchingFlatSizeSkipDim(input1_shape, trailing_dim, output_shape);
+ const int depth = input1_shape.Dims(trailing_dim);
for (int i = 0; i < outer_size; ++i) {
- auto min_max_value = input_data[i * depth];
+ auto min_max_value = input1_data[i * depth];
int min_max_index = 0;
for (int d = 1; d < depth; ++d) {
- const auto& curr_value = input_data[i * depth + d];
+ const auto& curr_value = input1_data[i * depth + d];
if (cmp(curr_value, min_max_value)) {
min_max_value = curr_value;
min_max_index = d;
@@ -4162,21 +4347,11 @@ void ArgMinMax(const T3* axis, const RuntimeShape& input_shape,
}
}
-// Legacy Dims<4> version.
-template <typename T1, typename T2, typename T3, typename Cmp>
-void ArgMinMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
- T2* output_data, const Dims<4>& output_dims, const Cmp& cmp) {
- ArgMinMax(axis, DimsToShape(input_dims), input_data, DimsToShape(output_dims),
- output_data, cmp);
-}
-
-// Legacy.
-// TODO(renjieliu): Remove this one.
template <typename T1, typename T2, typename T3>
-void ArgMax(const T3* axis, const T1* input_data,
- const tflite::Dims<4>& input_dims, T2* output_data,
- const tflite::Dims<4>& output_dims) {
- ArgMinMax(axis, input_data, input_dims, output_data, output_dims,
+void ArgMax(const RuntimeShape& input1_shape, const T1* input1_data,
+ const T3* input2_data, const RuntimeShape& output_shape,
+ T2* output_data) {
+ ArgMinMax(input1_shape, input1_data, input2_data, output_shape, output_data,
std::greater<T1>());
}
@@ -4303,9 +4478,10 @@ template <typename T>
using ComparisonFn = bool (*)(T, T);
template <typename T, ComparisonFn<T> F>
-inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
- const RuntimeShape& input2_shape, const T* input2_data,
- const RuntimeShape& output_shape, bool* output_data) {
+inline void ComparisonImpl(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
const int64_t flatsize =
MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
@@ -4313,26 +4489,45 @@ inline void Comparison(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-// Legacy Dims<4> version.
+template <ComparisonFn<float> F>
+inline void Comparison(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape, bool* output_data) {
+ ComparisonImpl<float, F>(op_params, input1_shape, input1_data, input2_shape,
+ input2_data, output_shape, output_data);
+}
+
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
template <typename T, ComparisonFn<T> F>
inline void Comparison(const T* input1_data, const Dims<4>& input1_dims,
const T* input2_data, const Dims<4>& input2_dims,
bool* output_data, const Dims<4>& output_dims) {
- Comparison<T, F>(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
+ ComparisonParams op_params;
+ // No parameters needed.
+ ComparisonImpl<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
}
template <typename T, ComparisonFn<int32> F>
-inline void Comparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
- const T* input2_data, const Dims<4>& input2_dims,
- int32 input2_offset, int32 input2_multiplier,
- int input2_shift, bool* output_data,
- const Dims<4>& output_dims) {
+inline void ComparisonWithScaling(
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape,
+ const T* input1_data, const RuntimeShape& input2_shape,
+ const T* input2_data, const RuntimeShape& output_shape, bool* output_data) {
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
const int64_t flatsize =
- MatchingFlatSize(input1_dims, input2_dims, output_dims);
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
for (int64_t i = 0; i < flatsize; ++i) {
const int32 input1_val = input1_offset + input1_data[i];
const int32 input2_val = input2_offset + input2_data[i];
@@ -4340,68 +4535,140 @@ inline void Comparison(int left_shift, const T* input1_data,
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
+ shifted_input2_val, input2_multiplier, input2_shift);
output_data[i] = F(scaled_input1_val, scaled_input2_val);
}
}
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
+inline void Comparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data, const Dims<4>& input2_dims,
+ int32 input2_offset, int32 input2_multiplier,
+ int input2_shift, bool* output_data,
+ const Dims<4>& output_dims) {
+ tflite::ComparisonParams op_params;
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ ComparisonWithScaling<T, F>(op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data,
+ DimsToShape(output_dims), output_data);
+}
+
template <typename T, ComparisonFn<T> F>
-inline void BroadcastComparison(const T* input1_data,
- const Dims<4>& input1_dims,
- const T* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims) {
+inline void BroadcastComparison4DSlowImpl(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlow");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
- output_data[Offset(output_dims, c, x, y, b)] =
- F(input1_data[SubscriptToIndex(desc1, c, x, y, b)],
- input2_data[SubscriptToIndex(desc2, c, x, y, b)]);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
+ output_data[Offset(output_shape, b, y, x, c)] =
+ F(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
+ input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}
+template <ComparisonFn<float> F>
+inline void BroadcastComparison4DSlow(const ComparisonParams& op_params,
+ const RuntimeShape& input1_shape,
+ const float* input1_data,
+ const RuntimeShape& input2_shape,
+ const float* input2_data,
+ const RuntimeShape& output_shape,
+ bool* output_data) {
+ BroadcastComparison4DSlowImpl<float, F>(op_params, input1_shape, input1_data,
+ input2_shape, input2_data,
+ output_shape, output_data);
+}
-template <typename T, ComparisonFn<int32> F>
-inline void BroadcastComparison(int left_shift, const T* input1_data,
- const Dims<4>& input1_dims, int32 input1_offset,
- int32 input1_multiplier, int input1_shift,
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<T> F>
+inline void BroadcastComparison(const T* input1_data,
+ const Dims<4>& input1_dims,
const T* input2_data,
- const Dims<4>& input2_dims, int32 input2_offset,
- int32 input2_multiplier, int input2_shift,
- bool* output_data, const Dims<4>& output_dims) {
+ const Dims<4>& input2_dims, bool* output_data,
+ const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+ // No parameters needed.
+ BroadcastComparison4DSlowImpl<T, F>(op_params, DimsToShape(input1_dims),
+ input1_data, DimsToShape(input2_dims),
+ input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison4DSlowWithScaling(
+ const ComparisonParams& op_params,
+ const RuntimeShape& unextended_input1_shape, const T* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data) {
+ gemmlowp::ScopedProfilingLabel label("BroadcastComparison4DSlowWithScaling");
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
- for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
- for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
- for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
- for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
+
+ int left_shift = op_params.left_shift;
+ int32 input1_offset = op_params.input1_offset;
+ int32 input1_multiplier = op_params.input1_multiplier;
+ int input1_shift = op_params.input1_shift;
+ int32 input2_offset = op_params.input2_offset;
+ int32 input2_multiplier = op_params.input2_multiplier;
+ int input2_shift = op_params.input2_shift;
+
+ for (int b = 0; b < output_shape.Dims(0); ++b) {
+ for (int y = 0; y < output_shape.Dims(1); ++y) {
+ for (int x = 0; x < output_shape.Dims(2); ++x) {
+ for (int c = 0; c < output_shape.Dims(3); ++c) {
const int32 input1_val =
- input1_offset + input1_data[SubscriptToIndex(desc1, c, x, y, b)];
+ input1_offset + input1_data[SubscriptToIndex(desc1, b, y, x, c)];
const int32 input2_val =
- input2_offset + input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+ input2_offset + input2_data[SubscriptToIndex(desc2, b, y, x, c)];
const int32 shifted_input1_val = input1_val * (1 << left_shift);
const int32 shifted_input2_val = input2_val * (1 << left_shift);
const int32 scaled_input1_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input1_val, input1_multiplier,
- kReverseShift * input1_shift);
+ shifted_input1_val, input1_multiplier, input1_shift);
const int32 scaled_input2_val =
MultiplyByQuantizedMultiplierSmallerThanOneExp(
- shifted_input2_val, input2_multiplier,
- kReverseShift * input2_shift);
- output_data[Offset(output_dims, c, x, y, b)] =
+ shifted_input2_val, input2_multiplier, input2_shift);
+ output_data[Offset(output_shape, b, y, x, c)] =
F(scaled_input1_val, scaled_input2_val);
}
}
@@ -4409,51 +4676,117 @@ inline void BroadcastComparison(int left_shift, const T* input1_data,
}
}
-#define TFLITE_COMPARISON_OP(name) \
- template <typename T> \
- inline void name(const T* input1_data, const Dims<4>& input1_dims, \
- const T* input2_data, const Dims<4>& input2_dims, \
- bool* output_data, const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name); \
- Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
- Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, input1_shift, \
- input2_data, input2_dims, input2_offset, \
- input2_multiplier, input2_shift, output_data, \
- output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
- const Dims<4>& input2_dims, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
- BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
- input2_dims, output_data, output_dims); \
- } \
- template <typename T> \
- inline void Broadcast##name( \
- int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
- int32 input1_offset, int32 input1_multiplier, int input1_shift, \
- const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
- int32 input2_multiplier, int input2_shift, bool* output_data, \
- const Dims<4>& output_dims) { \
- gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
- BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
- input1_offset, input1_multiplier, \
- input1_shift, input2_data, input2_dims, \
- input2_offset, input2_multiplier, \
- input2_shift, output_data, output_dims); \
+// TODO(b/80418076): Move to legacy ops file, update invocations.
+// Legacy.
+template <typename T, ComparisonFn<int32> F>
+inline void BroadcastComparison(int left_shift, const T* input1_data,
+ const Dims<4>& input1_dims, int32 input1_offset,
+ int32 input1_multiplier, int input1_shift,
+ const T* input2_data,
+ const Dims<4>& input2_dims, int32 input2_offset,
+ int32 input2_multiplier, int input2_shift,
+ bool* output_data, const Dims<4>& output_dims) {
+ ComparisonParams op_params;
+
+ op_params.left_shift = left_shift;
+ op_params.input1_offset = input1_offset;
+ op_params.input1_multiplier = input1_multiplier;
+ op_params.input1_shift = kReverseShift * input1_shift;
+ op_params.input2_offset = input2_offset;
+ op_params.input2_multiplier = input2_multiplier;
+ op_params.input2_shift = kReverseShift * input2_shift;
+
+ BroadcastComparison4DSlowWithScaling<T, F>(
+ op_params, DimsToShape(input1_dims), input1_data,
+ DimsToShape(input2_dims), input2_data, DimsToShape(output_dims),
+ output_data);
+}
+
+#define TFLITE_COMPARISON_OP(name) \
+ template <typename T> \
+ inline void name(const T* input1_data, const Dims<4>& input1_dims, \
+ const T* input2_data, const Dims<4>& input2_dims, \
+ bool* output_data, const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ Comparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, input1_shift, \
+ input2_data, input2_dims, input2_offset, \
+ input2_multiplier, input2_shift, output_data, \
+ output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, \
+ const Dims<4>& input2_dims, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison<T, name##Fn>(input1_data, input1_dims, input2_data, \
+ input2_dims, output_data, output_dims); \
+ } \
+ template <typename T> \
+ inline void Broadcast##name( \
+ int left_shift, const T* input1_data, const Dims<4>& input1_dims, \
+ int32 input1_offset, int32 input1_multiplier, int input1_shift, \
+ const T* input2_data, const Dims<4>& input2_dims, int32 input2_offset, \
+ int32 input2_multiplier, int input2_shift, bool* output_data, \
+ const Dims<4>& output_dims) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison<T, name##Fn>(left_shift, input1_data, input1_dims, \
+ input1_offset, input1_multiplier, \
+ input1_shift, input2_data, input2_dims, \
+ input2_offset, input2_multiplier, \
+ input2_shift, output_data, output_dims); \
+ } \
+ inline void name(const ComparisonParams& op_params, \
+ const RuntimeShape& input1_shape, const float* input1_data, \
+ const RuntimeShape& input2_shape, const float* input2_data, \
+ const RuntimeShape& output_shape, bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name); \
+ Comparison<name##Fn>(op_params, input1_shape, input1_data, input2_shape, \
+ input2_data, output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label(#name "/8bit"); \
+ ComparisonWithScaling<T, name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ inline void Broadcast4DSlow##name( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const float* input1_data, const RuntimeShape& input2_shape, \
+ const float* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name); \
+ BroadcastComparison4DSlow<name##Fn>(op_params, input1_shape, input1_data, \
+ input2_shape, input2_data, \
+ output_shape, output_data); \
+ } \
+ template <typename T> \
+ inline void Broadcast4DSlow##name##WithScaling( \
+ const ComparisonParams& op_params, const RuntimeShape& input1_shape, \
+ const T* input1_data, const RuntimeShape& input2_shape, \
+ const T* input2_data, const RuntimeShape& output_shape, \
+ bool* output_data) { \
+ gemmlowp::ScopedProfilingLabel label("Broadcast" #name "/8bit"); \
+ BroadcastComparison4DSlowWithScaling<T, name##Fn>( \
+ op_params, input1_shape, input1_data, input2_shape, input2_data, \
+ output_shape, output_data); \
}
TFLITE_COMPARISON_OP(Equal);
TFLITE_COMPARISON_OP(NotEqual);
@@ -4543,15 +4876,6 @@ inline void Pow(const RuntimeShape& input1_shape, const T* input1_data,
}
}
-// Legacy Dims<4> version.
-template <typename T>
-inline void Pow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- Pow(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims), output_data);
-}
-
template <typename T>
inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
const T* input1_data,
@@ -4580,16 +4904,6 @@ inline void BroadcastPow4DSlow(const RuntimeShape& input1_shape,
}
}
-// Legacy Dims<4> version.
-template <typename T>
-inline void BroadcastPow(const T* input1_data, const Dims<4>& input1_dims,
- const T* input2_data, const Dims<4>& input2_dims,
- T* output_data, const Dims<4>& output_dims) {
- BroadcastPow4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data);
-}
-
inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
const RuntimeShape& input2_shape, const bool* input2_data,
const RuntimeShape& output_shape, bool* output_data,
@@ -4601,24 +4915,21 @@ inline void Logical(const RuntimeShape& input1_shape, const bool* input1_data,
}
}
-// Legacy Dims<4> version.
-inline void Logical(const bool* input1_data, const Dims<4>& input1_dims,
- const bool* input2_data, const Dims<4>& input2_dims,
- bool* output_data, const Dims<4>& output_dims,
- const std::function<bool(bool, bool)>& func) {
- Logical(DimsToShape(input1_dims), input1_data, DimsToShape(input2_dims),
- input2_data, DimsToShape(output_dims), output_data, func);
-}
-
inline void BroadcastLogical4DSlow(
- const RuntimeShape& input1_shape, const bool* input1_data,
- const RuntimeShape& input2_shape, const bool* input2_data,
- const RuntimeShape& output_shape, bool* output_data,
+ const RuntimeShape& unextended_input1_shape, const bool* input1_data,
+ const RuntimeShape& unextended_input2_shape, const bool* input2_data,
+ const RuntimeShape& unextended_output_shape, bool* output_data,
const std::function<bool(bool, bool)>& func) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4636,18 +4947,6 @@ inline void BroadcastLogical4DSlow(
}
}
-// Legacy Dims<4> version.
-inline void BroadcastLogical(const bool* input1_data,
- const Dims<4>& input1_dims,
- const bool* input2_data,
- const Dims<4>& input2_dims, bool* output_data,
- const Dims<4>& output_dims,
- const std::function<bool(bool, bool)>& func) {
- BroadcastLogical4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, func);
-}
-
// TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
// generalized and efficient BroadcastBinaryFunction.
//
@@ -4655,16 +4954,21 @@ inline void BroadcastLogical(const bool* input1_data,
//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
- const T1* input1_data,
- const RuntimeShape& input2_shape,
- const T2* input2_data,
- const RuntimeShape& output_shape,
- R* output_data, R (*func)(T1, T2)) {
+inline void BroadcastBinaryFunction4DSlow(
+ const RuntimeShape& unextended_input1_shape, const T1* input1_data,
+ const RuntimeShape& unextended_input2_shape, const T2* input2_data,
+ const RuntimeShape& unextended_output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
+ RuntimeShape output_shape =
+ RuntimeShape::ExtendedShape(4, unextended_output_shape);
+
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
- NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
- &desc2);
+ NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
+ unextended_input2_shape, &desc1, &desc2);
for (int b = 0; b < output_shape.Dims(0); ++b) {
for (int y = 0; y < output_shape.Dims(1); ++y) {
@@ -4682,19 +4986,20 @@ inline void BroadcastBinaryFunction4DSlow(const RuntimeShape& input1_shape,
}
}
-// Legacy Dims<4> version.
-//
// R: Result type. T1: Input 1 type. T2: Input 2 type.
+// TODO(renjieliu): Refactor other binary functions to use this one.
template <typename R, typename T1, typename T2>
-inline void BroadcastBinaryFunction(const T1* input1_data,
- const Dims<4>& input1_dims,
- const T2* input2_data,
- const Dims<4>& input2_dims, R* output_data,
- const Dims<4>& output_dims,
- R (*func)(T1, T2)) {
- BroadcastBinaryFunction4DSlow(DimsToShape(input1_dims), input1_data,
- DimsToShape(input2_dims), input2_data,
- DimsToShape(output_dims), output_data, func);
+inline void BinaryFunction(const RuntimeShape& input1_shape,
+ const T1* input1_data,
+ const RuntimeShape& input2_shape,
+ const T2* input2_data,
+ const RuntimeShape& output_shape, R* output_data,
+ R (*func)(T1, T2)) {
+ const int flat_size =
+ MatchingFlatSize(input1_shape, input2_shape, output_shape);
+ for (int i = 0; i < flat_size; ++i) {
+ output_data[i] = func(input1_data[i], input2_data[i]);
+ }
}
} // namespace reference_ops
diff --git a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
index 3d8765f11b..15df31f75a 100644
--- a/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/resize_bilinear_test.cc
@@ -28,14 +28,12 @@ template <typename T>
void TestOneResizeBilinear(int batch, int depth, int input_width,
int input_height, int output_width,
int output_height, float error_threshold) {
- Dims<4> input_dims_inference =
- MakeDimsForInference(depth, input_width, input_height, batch);
- Dims<4> output_dims_inference =
- MakeDimsForInference(depth, output_width, output_height, batch);
+ RuntimeShape input_dims_inference({batch, input_height, input_width, depth});
+ RuntimeShape output_dims_inference(
+ {batch, output_height, output_width, depth});
- const int input_buffer_size = RequiredBufferSizeForDims(input_dims_inference);
- const int output_buffer_size =
- RequiredBufferSizeForDims(output_dims_inference);
+ const int input_buffer_size = input_dims_inference.FlatSize();
+ const int output_buffer_size = output_dims_inference.FlatSize();
std::vector<T> input_data(input_buffer_size, 0);
std::vector<T> reference_output_data(output_buffer_size, 0);
@@ -47,15 +45,19 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
const T max_amplitude = static_cast<T>(255);
FillRandom(&input_data, min_amplitude, max_amplitude);
- Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
+ RuntimeShape output_size_dims({1, 1, 1, 2});
std::vector<int32> output_size_data = {output_height, output_width};
- reference_ops::ResizeBilinear(
- input_data.data(), input_dims_inference, output_size_data.data(),
- output_size_dims, reference_output_data.data(), output_dims_inference);
- optimized_ops::ResizeBilinear(input_data.data(), input_dims_inference,
- output_size_data.data(), output_size_dims,
- output_data.data(), output_dims_inference);
+ tflite::ResizeBilinearParams op_params;
+ op_params.align_corners = false;
+
+ reference_ops::ResizeBilinear(op_params, input_dims_inference,
+ input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference,
+ reference_output_data.data());
+ optimized_ops::ResizeBilinear(
+ op_params, input_dims_inference, input_data.data(), output_size_dims,
+ output_size_data.data(), output_dims_inference, output_data.data());
double sum_diff = 0;
float max_abs_val = 0;
diff --git a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
index 5994fad5c7..af5db1064c 100644
--- a/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
+++ b/tensorflow/contrib/lite/kernels/internal/strided_slice_logic.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <limits>
#include <vector>
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-
namespace strided_slice {
// Use until std::clamp() is available from C++17.
@@ -32,15 +32,51 @@ inline int Clamp(const int v, const int lo, const int hi) {
return v;
}
+inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
+ int dim_count) {
+ // Add indices and mask bits to fully include extra dimensions
+ TFLITE_CHECK_LE(dim_count, 4);
+ TFLITE_CHECK_GE(dim_count, p->start_indices_count);
+ TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
+ TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
+
+ const int pad_count = dim_count - p->start_indices_count;
+
+ // Pad indices at start, so move arrays by pad_count.
+ for (int i = p->start_indices_count - 1; i > 0; --i) {
+ p->strides[i + pad_count] = p->strides[i];
+ p->start_indices[i + pad_count] = p->start_indices[i];
+ p->stop_indices[i + pad_count] = p->stop_indices[i];
+ }
+ for (int i = 0; i < pad_count; ++i) {
+ p->start_indices[i] = 0;
+ p->stop_indices[i] = 0;
+ p->strides[i] = 1;
+ }
+
+ // Pad masks with 0s or 1s as required.
+ p->shrink_axis_mask <<= pad_count;
+ p->ellipsis_mask <<= pad_count;
+ p->new_axis_mask <<= pad_count;
+ p->begin_mask <<= pad_count;
+ p->end_mask <<= pad_count;
+ p->begin_mask |= (1 << pad_count) - 1;
+ p->end_mask |= (1 << pad_count) - 1;
+
+ p->start_indices_count = dim_count;
+ p->stop_indices_count = dim_count;
+ p->strides_count = dim_count;
+}
+
// Return the index for the first element along that axis. This index will be a
// positive integer between [0, axis_size - 1] that can be used to index
// directly into the data.
-template <typename IntType>
-inline int StartForAxis(int begin_mask,
- std::vector<IntType> const& start_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis) {
- // Begin with the specified index
+inline int StartForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis) {
+ const auto begin_mask = params.begin_mask;
+ const auto* start_indices = params.start_indices;
+ const auto* strides = params.strides;
+ // Begin with the specified index.
int start = start_indices[axis];
// begin_mask override
@@ -57,7 +93,7 @@ inline int StartForAxis(int begin_mask,
}
// Handle negative indices
- int axis_size = input_shape[axis];
+ int axis_size = input_shape.Dims(axis);
if (start < 0) {
start += axis_size;
}
@@ -73,11 +109,14 @@ inline int StartForAxis(int begin_mask,
// element. ie. So if you were iterating through all elements of a 1D array of
// size 4, this function would return 4 as the stop, because it is one past the
// "real" indices of 0, 1, 2 & 3.
-template <typename IntType>
-inline int StopForAxis(int end_mask, int shrink_axis_mask,
- std::vector<IntType> const& stop_indices,
- std::vector<IntType> const& strides,
- int const* input_shape, int axis, int start_for_axis) {
+inline int StopForAxis(const tflite::StridedSliceParams& params,
+ const RuntimeShape& input_shape, int axis,
+ int start_for_axis) {
+ const auto end_mask = params.end_mask;
+ const auto shrink_axis_mask = params.shrink_axis_mask;
+ const auto* stop_indices = params.stop_indices;
+ const auto* strides = params.strides;
+
// Begin with the specified index
const bool shrink_axis = shrink_axis_mask & (1 << axis);
int stop = stop_indices[axis];
@@ -103,7 +142,7 @@ inline int StopForAxis(int end_mask, int shrink_axis_mask,
}
// Handle negative indices
- const int axis_size = input_shape[axis];
+ const int axis_size = input_shape.Dims(axis);
if (stop < 0) {
stop += axis_size;
}
@@ -127,6 +166,31 @@ inline bool LoopCondition(int index, int stop, int stride) {
return stride > 0 ? index >= stop : index <= stop;
}
+inline tflite::StridedSliceParams BuildStridedSliceParams(
+ int begin_mask, int end_mask, int shrink_axis_mask,
+ const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
+ const std::vector<int>& strides) {
+ tflite::StridedSliceParams op_params;
+ const int dims_count = start_indices.size();
+
+ op_params.start_indices_count = dims_count;
+ op_params.stop_indices_count = dims_count;
+ op_params.strides_count = dims_count;
+ for (int i = 0; i < dims_count; ++i) {
+ op_params.start_indices[i] = start_indices[i];
+ op_params.stop_indices[i] = stop_indices[i];
+ op_params.strides[i] = strides[i];
+ }
+
+ op_params.begin_mask = begin_mask;
+ op_params.ellipsis_mask = 0;
+ op_params.end_mask = end_mask;
+ op_params.new_axis_mask = 0;
+ op_params.shrink_axis_mask = shrink_axis_mask;
+
+ return op_params;
+}
+
} // namespace strided_slice
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h
index ee2af5b460..13106456df 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor.h
@@ -17,44 +17,12 @@ limitations under the License.
#include <complex>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
-template <typename T>
-inline T* GetTensorData(TfLiteTensor* tensor);
-
-template <>
-inline float* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline int16_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline int32_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline int64_t* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline bool* GetTensorData(TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
return tensor != nullptr
@@ -62,39 +30,6 @@ inline std::complex<float>* GetTensorData(TfLiteTensor* tensor) {
: nullptr;
}
-template <typename T>
-inline const T* GetTensorData(const TfLiteTensor* tensor);
-
-template <>
-inline const float* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.f : nullptr;
-}
-
-template <>
-inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.uint8 : nullptr;
-}
-
-template <>
-inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i16 : nullptr;
-}
-
-template <>
-inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i32 : nullptr;
-}
-
-template <>
-inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.i64 : nullptr;
-}
-
-template <>
-inline const bool* GetTensorData(const TfLiteTensor* tensor) {
- return tensor != nullptr ? tensor->data.b : nullptr;
-}
-
template <>
inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
return tensor != nullptr
@@ -102,56 +37,14 @@ inline const std::complex<float>* GetTensorData(const TfLiteTensor* tensor) {
: nullptr;
}
-inline int RemapDim(int max_dimensions, int d) {
- return max_dimensions - d - 1;
-}
-
-// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
-// even if the original tensors were not 4D. We should consider rewriting them
-// to take a more generic 'shape' object.
-inline Dims<4> GetTensorDims(const int data[], const int size) {
- Dims<4> d;
- for (int i = 0; i < 4; ++i) {
- int src = size - i - 1;
- if (src >= 0) {
- d.sizes[i] = data[src];
- } else {
- d.sizes[i] = 1;
- }
- }
- d.strides[0] = 1;
- for (int i = 1; i < 4; i++) {
- d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
- }
- return d;
-}
-
inline Dims<4> GetTensorDims(std::vector<int32_t> data) {
return GetTensorDims(data.data(), data.size());
}
-inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return Dims<4>();
- }
-
- auto* dims = tensor->dims;
- return GetTensorDims(dims->data, dims->size);
-}
-
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
return RuntimeShape(data.size(), data.data());
}
-inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
- if (tensor == nullptr) {
- return RuntimeShape();
- }
-
- auto* dims = tensor->dims;
- return RuntimeShape(dims->size, dims->data);
-}
-
// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
new file mode 100644
index 0000000000..77e22a08b4
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+template <typename T>
+inline T* GetTensorData(TfLiteTensor* tensor);
+
+template <>
+inline float* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline uint8_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline int16_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline int32_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline int64_t* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline bool* GetTensorData(TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+template <typename T>
+inline const T* GetTensorData(const TfLiteTensor* tensor);
+
+template <>
+inline const float* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.f : nullptr;
+}
+
+template <>
+inline const uint8_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.uint8 : nullptr;
+}
+
+template <>
+inline const int16_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i16 : nullptr;
+}
+
+template <>
+inline const int32_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i32 : nullptr;
+}
+
+template <>
+inline const int64_t* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.i64 : nullptr;
+}
+
+template <>
+inline const bool* GetTensorData(const TfLiteTensor* tensor) {
+ return tensor != nullptr ? tensor->data.b : nullptr;
+}
+
+inline int RemapDim(int max_dimensions, int d) {
+ return max_dimensions - d - 1;
+}
+
+// TODO(ahentz): the implementations in kernels/internal/ take a Dims<4> object
+// even if the original tensors were not 4D. We should consider rewriting them
+// to take a more generic 'shape' object.
+inline Dims<4> GetTensorDims(const int data[], const int size) {
+ Dims<4> d;
+ for (int i = 0; i < 4; ++i) {
+ int src = size - i - 1;
+ if (src >= 0) {
+ d.sizes[i] = data[src];
+ } else {
+ d.sizes[i] = 1;
+ }
+ }
+ d.strides[0] = 1;
+ for (int i = 1; i < 4; i++) {
+ d.strides[i] = d.strides[i - 1] * d.sizes[i - 1];
+ }
+ return d;
+}
+
+inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return Dims<4>();
+ }
+
+ auto* dims = tensor->dims;
+ return GetTensorDims(dims->data, dims->size);
+}
+
+inline RuntimeShape GetTensorShape(const TfLiteTensor* tensor) {
+ if (tensor == nullptr) {
+ return RuntimeShape();
+ }
+
+ TfLiteIntArray* dims = tensor->dims;
+ const int dims_size = dims->size;
+ const int32_t* dims_data = dims->data;
+ return RuntimeShape(dims_size, dims_data);
+}
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_CTYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
index 1ff8cfe39c..b0fe5adf65 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#if defined(_MSC_VER)
#define __restrict__ __restrict
@@ -101,6 +101,11 @@ void BatchVectorBatchVectorDotProduct(const float* vector1,
int n_batch, float* result,
int result_stride);
+// Cwise product of a vector and a batch-vector.
+void VectorBatchVectorCwiseProduct(const float* vector, int v_size,
+ const float* batch_vector, int n_batch,
+ float* result);
+
// Cwise product and accumulate of a vector and a batch-vector. Since it's a MAC
// operation, the assumption here is that result array is initialized to valid
// values.
@@ -108,6 +113,10 @@ void VectorBatchVectorCwiseProductAccumulate(const float* vector, int v_size,
const float* batch_vector,
int n_batch, float* result);
+// Add another vector for each batch in the batch vector.
+void VectorBatchVectorAdd(const float* vector, int v_size, int n_batch,
+ float* batch_vector);
+
// Batch vector initialization with another vector.
void VectorBatchVectorAssign(const float* vector, int v_size, int n_batch,
float* batch_vector);
@@ -147,6 +156,12 @@ void VectorShiftLeft(float* vector, int v_size, float shift_value);
// added to get one element of output.
void ReductionSumVector(const float* input_vector, float* output_vector,
int output_size, int reduction_size);
+
+// Layer norm for each batch.
+// normalization_epsilon is added to avoid divergence.
+void MeanStddevNormalization(const float* input_vector, float* output_vector,
+ int v_size, int n_batch,
+ float normalization_epsilon);
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
index e8343f1223..6458af714b 100644
--- a/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils_test.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include <gmock/gmock.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
namespace tflite {
@@ -496,6 +496,16 @@ TEST(uKernels, VectorVectorCwiseProductAccumulateTest) {
{1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45})));
}
+TEST(uKernels, VectorBatchVectorAddTest) {
+ constexpr int kVectorSize = 3;
+ constexpr int kBatchSize = 2;
+ static float input[kVectorSize] = {0.0, -0.5, 1.0};
+ std::vector<float> output = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+ VectorBatchVectorAdd(input, kVectorSize, kBatchSize, output.data());
+ EXPECT_THAT(output,
+ testing::ElementsAreArray({1.0, 1.5, 4.0, 4.0, 4.5, 7.0}));
+}
+
TEST(uKernels, VectorBatchVectorAssignTest) {
constexpr int kVectorSize = 5;
constexpr int kBatchSize = 3;
@@ -555,6 +565,120 @@ TEST(uKernels, ZeroVectorTest) {
ElementsAreArray(ArrayFloatNear({0.0, 0.0, 0.0, 0.0, 0.0})));
}
+TEST(uKernels, VectorBatchVectorCwiseProductAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProductAccumulate(input, kVectorSize, output.data(),
+ kBatchSize, output.data());
+
+ // Expect output = input * output + output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 2.310000, 7.040000, 14.190000, 23.760000, 35.750000, 50.159996, 66.989998,
+ 86.240005, 107.909996, 112.110008, 134.542084, 159.014389, 185.526901,
+ 214.079605, 244.672485, 277.305603, 311.978912, 348.692413, 387.446136,
+ 428.240051, 471.074066, 515.948364, 562.862854, 611.817566, 662.812500,
+ 715.847595, 770.922974, 828.038452, 0.000000,
+ /* batch 1 */
+ -2.310000, -7.040000, -14.190000, -23.760000, -35.750000, -50.159996,
+ -66.989998, -86.240005, -107.909996, -112.110008, -134.542084,
+ -159.014389, -185.526901, -214.079605, -244.672485, -277.305603,
+ -311.978912, -348.692413, -387.446136, -428.240051, -471.074066,
+ -515.948364, -562.862854, -611.817566, -662.812500, -715.847595,
+ -770.922974, -828.038452, 0.000000,
+ /* batch 2 */
+ 2.310000, -7.040000, 14.190000, -23.760000, 35.750000, -50.159996,
+ 66.989998, -86.240005, 107.909996, -112.110008, 134.542084, -159.014389,
+ 185.526901, -214.079605, 244.672485, -277.305603, 311.978912, -348.692413,
+ 387.446136, -428.240051, 471.074066, -515.948364, 562.862854, -611.817566,
+ 662.812500, -715.847595, 770.922974, -828.038452, 0.000000,
+ /* batch 3 */
+ -2.310000, 7.040000, -14.190000, 23.760000, -35.750000, 50.159996,
+ -66.989998, 86.240005, -107.909996, 112.110008, -134.542084, 159.014389,
+ -185.526901, 214.079605, -244.672485, 277.305603, -311.978912, 348.692413,
+ -387.446136, 428.240051, -471.074066, 515.948364, -562.862854, 611.817566,
+ -662.812500, 715.847595, -770.922974, 828.038452, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, VectorBatchVectorCwiseProductNoAccumulate) {
+ constexpr int kVectorSize = 29;
+ constexpr int kBatchSize = 4;
+ static float input[kVectorSize] = {
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1,
+ 11.11, 12.12, 13.13, 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2,
+ 21.21, 22.22, 23.23, 24.24, 25.25, 26.26, 27.27, 28.28, 0};
+ std::vector<float> output = {
+ /* batch 0 */
+ 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9, 10.1, 11.11, 12.12, 13.13,
+ 14.14, 15.15, 16.16, 17.17, 18.18, 19.19, 20.2, 21.21, 22.22, 23.23,
+ 24.24, 25.25, 26.26, 27.27, 28.28, 0,
+ /* batch 1 */
+ -1.1, -2.2, -3.3, -4.4, -5.5, -6.6, -7.7, -8.8, -9.9, -10.1, -11.11,
+ -12.12, -13.13, -14.14, -15.15, -16.16, -17.17, -18.18, -19.19, -20.2,
+ -21.21, -22.22, -23.23, -24.24, -25.25, -26.26, -27.27, -28.28, 0,
+ /* batch 2 */
+ 1.1, -2.2, 3.3, -4.4, 5.5, -6.6, 7.7, -8.8, 9.9, -10.1, 11.11, -12.12,
+ 13.13, -14.14, 15.15, -16.16, 17.17, -18.18, 19.19, -20.2, 21.21, -22.22,
+ 23.23, -24.24, 25.25, -26.26, 27.27, -28.28, 0,
+ /* batch 3 */
+ -1.1, 2.2, -3.3, 4.4, -5.5, 6.6, -7.7, 8.8, -9.9, 10.1, -11.11, 12.12,
+ -13.13, 14.14, -15.15, 16.16, -17.17, 18.18, -19.19, 20.2, -21.21, 22.22,
+ -23.23, 24.24, -25.25, 26.26, -27.27, 28.28, 0};
+ VectorBatchVectorCwiseProduct(input, kVectorSize, output.data(), kBatchSize,
+ output.data());
+
+ // Expect output = input * output + output.
+ const std::vector<float> expected_output = {
+ /* batch 0 */
+ 1.210000, 4.840000, 10.889999, 19.360001, 30.250000, 43.559998, 59.289997,
+ 77.440002, 98.009995, 102.010010, 123.432091, 146.894394, 172.396896,
+ 199.939606, 229.522491, 261.145599, 294.808899, 330.512421, 368.256134,
+ 408.040039, 449.864075, 493.728363, 539.632874, 587.577576, 637.562500,
+ 689.587585, 743.652954, 799.758423, 0.000000,
+ /* batch 1 */
+ -1.210000, -4.840000, -10.889999, -19.360001, -30.250000, -43.559998,
+ -59.289997, -77.440002, -98.009995, -102.010010, -123.432091, -146.894394,
+ -172.396896, -199.939606, -229.522491, -261.145599, -294.808899,
+ -330.512421, -368.256134, -408.040039, -449.864075, -493.728363,
+ -539.632874, -587.577576, -637.562500, -689.587585, -743.652954,
+ -799.758423, 0.000000,
+ /* batch 2 */
+ 1.210000, -4.840000, 10.889999, -19.360001, 30.250000, -43.559998,
+ 59.289997, -77.440002, 98.009995, -102.010010, 123.432091, -146.894394,
+ 172.396896, -199.939606, 229.522491, -261.145599, 294.808899, -330.512421,
+ 368.256134, -408.040039, 449.864075, -493.728363, 539.632874, -587.577576,
+ 637.562500, -689.587585, 743.652954, -799.758423, 0.000000,
+ /* batch 3 */
+ -1.210000, 4.840000, -10.889999, 19.360001, -30.250000, 43.559998,
+ -59.289997, 77.440002, -98.009995, 102.010010, -123.432091, 146.894394,
+ -172.396896, 199.939606, -229.522491, 261.145599, -294.808899, 330.512421,
+ -368.256134, 408.040039, -449.864075, 493.728363, -539.632874, 587.577576,
+ -637.562500, 689.587585, -743.652954, 799.758423, 0.000000};
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
TEST(uKernels, BatchVectorBatchVectorDotProductTest) {
constexpr int kVectorSize = 5;
constexpr int kBatch = 2;
@@ -598,5 +722,85 @@ TEST(uKernels, ReductionSumVectorTest) {
EXPECT_THAT(result2, ElementsAreArray(ArrayFloatNear({1.0, 3.5})));
}
+TEST(uKernels, MeanStddevNormalizationNoneZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // None-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.1, 0.2, 0.3, 0.4, // batch 0
+ 0.9, 1.0, 1.1, 1.2, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 0
+ -1.34163153, -0.447210163, 0.447211236, 1.3416326, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationAllZeroInput) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.0, 0.0, 0.0, 0.0, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationMixed) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Mix of zero and non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ 0.1, 0.2, 0.3, 0.4, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 0.0, 0.0, 0.0, 0.0, // batch 0
+ -1.34164071, -0.447213531, 0.44721365, 1.34164071, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
+TEST(uKernels, MeanStddevNormalizationSmallValue) {
+ constexpr int kVectorSize = 4;
+ constexpr int kBatchSize = 2;
+ constexpr float kNormalizationEpsilon = 1e-8;
+
+ // Mix of zero and non-zero input.
+ static float input[kVectorSize * kBatchSize] = {
+ 3e-5, -7e-6, -9e-5, 1e-6, // batch 0
+ 4e-5, 9e-6, 2e-4, 0.0, // batch 1
+ };
+ std::vector<float> output(kVectorSize * kBatchSize);
+ MeanStddevNormalization(input, output.data(), kVectorSize, kBatchSize,
+ kNormalizationEpsilon);
+ const std::vector<float> expected_output = {
+ 1.04231524, 0.212946132, -1.64753067, 0.392269224, // batch 0
+ -0.275023013, -0.658201098, 1.70267045, -0.769446373, // batch 1
+ };
+ EXPECT_THAT(output, testing::ElementsAreArray(expected_output));
+}
+
} // namespace tensor_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index 204df9ab19..c4c7cf3842 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -668,9 +668,9 @@ static_assert(sizeof(MinMax) == 8, "");
struct ActivationParams {
FusedActivationFunctionType activation_type;
- // Quantized inference params.
- int32 activation_min;
- int32 activation_max;
+ // uint8, etc, activation params.
+ int32 quantized_activation_min;
+ int32 quantized_activation_max;
};
// For Add, Sub, Mul ops.
@@ -710,17 +710,22 @@ struct ArithmeticParams {
struct ConcatenationParams {
int8 axis;
+ const int32* input_zeropoint;
+ const float* input_scale;
+ uint16 inputs_count;
+ int32 output_zeropoint;
+ float output_scale;
};
struct ComparisonParams {
// uint8 inference params.
int left_shift;
- int32 input0_offset;
- int32 input0_multiplier;
- int input0_shift;
int32 input1_offset;
int32 input1_multiplier;
int input1_shift;
+ int32 input2_offset;
+ int32 input2_multiplier;
+ int input2_shift;
// Shape dependent / common to inference types.
bool is_broadcast;
};
@@ -745,7 +750,7 @@ struct ConvParams {
};
struct DepthToSpaceParams {
- int16 block_size;
+ int32 block_size;
};
struct DepthwiseParams {
@@ -764,6 +769,11 @@ struct DepthwiseParams {
int32 output_activation_max;
};
+struct DequantizationParams {
+ double scale;
+ int32 zero_point;
+};
+
struct FakeQuantParams {
MinMax minmax;
int32 num_bits;
@@ -871,14 +881,20 @@ struct SoftmaxParams {
int diff_min;
};
+struct SpaceToBatchParams {
+ // "Zero" padding for uint8 means padding with the output offset.
+ int32 output_offset;
+};
+
struct SpaceToDepthParams {
- int16 block_size;
+ int32 block_size;
};
struct SplitParams {
// Graphs that split into, say, 2000 nodes are encountered. The indices in
// OperatorEdges are of type uint16.
uint16 num_split;
+ int16 axis;
};
struct SqueezeParams {
@@ -908,23 +924,30 @@ struct TanhParams {
int input_left_shift;
};
-template <typename T>
-inline void SetActivationParams(T min, T max, ArithmeticParams* params);
-
-template <>
-inline void SetActivationParams(float min, float max,
- ArithmeticParams* params) {
+template <typename P>
+inline void SetActivationParams(float min, float max, P* params) {
params->float_activation_min = min;
params->float_activation_max = max;
}
-template <>
-inline void SetActivationParams(int32 min, int32 max,
- ArithmeticParams* params) {
+template <typename P>
+inline void SetActivationParams(int32 min, int32 max, P* params) {
params->quantized_activation_min = min;
params->quantized_activation_max = max;
}
+template <typename P>
+inline void GetActivationParams(const P& params, int32* min, int32* max) {
+ *min = params.quantized_activation_min;
+ *max = params.quantized_activation_max;
+}
+
+template <typename P>
+inline void GetActivationParams(const P& params, float* min, float* max) {
+ *min = params.float_activation_min;
+ *max = params.float_activation_max;
+}
+
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h
index c8ce3c917d..e9a5fd7a40 100644
--- a/tensorflow/contrib/lite/kernels/kernel_util.h
+++ b/tensorflow/contrib/lite/kernels/kernel_util.h
@@ -16,9 +16,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_
#include <algorithm>
+#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
@@ -30,6 +31,11 @@ inline const TfLiteTensor* GetInput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->inputs->data[index]];
}
+inline TfLiteTensor* GetVariableInput(TfLiteContext* context, TfLiteNode* node,
+ int index) {
+ TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]];
+ return (tensor->is_variable) ? tensor : nullptr;
+}
inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node,
int index) {
return &context->tensors[node->outputs->data[index]];
diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc
index a7b54c6b84..e02d7df9ef 100644
--- a/tensorflow/contrib/lite/kernels/l2norm.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -68,10 +68,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization<FusedActivationFunctionType::kNone>( \
- GetTensorData<float>(input), GetTensorShape(input), \
- GetTensorData<float>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = 0; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<float>(input), GetTensorShape(output), \
+ GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
@@ -81,10 +83,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#undef TF_LITE_L2NORM
} else if (output->type == kTfLiteUInt8) {
-#define TF_LITE_L2NORM(type) \
- type::L2Normalization(GetTensorData<uint8>(input), GetTensorShape(input), \
- input->params.zero_point, \
- GetTensorData<uint8>(output), GetTensorShape(output))
+#define TF_LITE_L2NORM(type) \
+ tflite::L2NormalizationParams op_params; \
+ op_params.input_zero_point = input->params.zero_point; \
+ type::L2Normalization(op_params, GetTensorShape(input), \
+ GetTensorData<uint8>(input), GetTensorShape(output), \
+ GetTensorData<uint8>(output))
if (kernel_type == kReference) {
TF_LITE_L2NORM(reference_ops);
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
new file mode 100644
index 0000000000..1bbea67b93
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm.cc
@@ -0,0 +1,1316 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Layer Normalization LSTM op that applies normalization by mean and standard
+// deviation to the activation of the LSTM layers. Please see
+// https://arxiv.org/abs/1607.06450 for details.
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace layer_norm_lstm {
+
+// Struct to hold Layer Norm LSTM option data.
+struct OpData {
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+ int scratch_tensor_index;
+};
+
+// Input Tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size: {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1; // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9; // Optional
+constexpr int kCellToForgetWeightsTensor = 10; // Optional
+constexpr int kCellToOutputWeightsTensor = 11; // Optional
+
+// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kInputLayerNormWeightsTensor = 12;
+constexpr int kForgetLayerNormWeightsTensor = 13;
+constexpr int kCellLayerNormWeightsTensor = 14;
+constexpr int kOutputLayerNormWeightsTensor = 15;
+
+// Gates bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 16; // Optional
+constexpr int kForgetGateBiasTensor = 17;
+constexpr int kCellGateBiasTensor = 18;
+constexpr int kOutputGateBiasTensor = 19;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 20; // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 21; // Optional
+
+// State tensors.
+constexpr int kInputActivationStateTensor = 22;
+constexpr int kInputCellStateTensor = 23;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Total number of scratch tensors for hybrid Op.
+constexpr int kTensorsToAdd = 7;
+
+// Small float to avoid divergence during calculation of deviation.
+const float kLayerNormEpsilon = 1e-8;
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+
+ // Turn custom option data into flexbuffer map format.
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+
+ // Get activation function, cell_clip and proj_clip from the flexbuffer.
+ // TODO(b/113824099): make activation more generic.
+ assert(m["fused_activation_function"].ToString() == "TANH");
+ data->activation = kTfLiteActTanh;
+ data->cell_clip = m["cell_clip"].AsFloat();
+ data->proj_clip = m["proj_clip"].AsFloat();
+
+ // Populate scratch_tensor_index.
+ context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd,
+ &data->scratch_tensor_index);
+ return data;
+}
+
+// Check that input tensor dimensions matches with each other.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+ TfLiteNode* node, int n_input,
+ int n_output, int n_cell) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ // Making sure clipping parameters have valid values.
+ // == 0 means no clipping
+ // > 0 means clipping
+ TF_LITE_ENSURE(context, op_data->cell_clip >= 0);
+ TF_LITE_ENSURE(context, op_data->proj_clip >= 0);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ if (input_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+ }
+
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ if (recurrent_to_input_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+ n_output);
+ }
+
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+ n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+ n_output);
+
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+ n_output);
+
+ // We make sure the input-gate's parameters are either both present (regular
+ // LSTM) or not at all (CIFG-LSTM).
+ const bool cifg_weights_all_or_none =
+ ((input_to_input_weights != nullptr) &&
+ (recurrent_to_input_weights != nullptr)) ||
+ ((input_to_input_weights == nullptr) &&
+ (recurrent_to_input_weights == nullptr));
+ TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ if (cell_to_input_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ if (cell_to_forget_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+ if (cell_to_output_weights) {
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+ }
+
+ // Making sure the peephole weights are there all or none.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool peephole_weights_all_or_none =
+ ((cell_to_input_weights != nullptr || use_cifg) &&
+ (cell_to_forget_weights != nullptr) &&
+ (cell_to_output_weights != nullptr)) ||
+ ((cell_to_input_weights == nullptr) &&
+ (cell_to_forget_weights == nullptr) &&
+ (cell_to_output_weights == nullptr));
+ TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+ // Making sure layer norm weights are not null and have the right dimension.
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+ TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell);
+
+ // Make sure the input gate bias is present only when not a CIFG-LSTM.
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ if (use_cifg) {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+ } else {
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+ }
+
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ if (projection_weights != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+ TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+ }
+
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+ if (projection_bias != nullptr) {
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+ TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+ }
+
+ // Making sure the projection tensors are consistent:
+ // 1) If projection weight is not present, then projection bias should not be
+ // present.
+ // 2) If projection weight is present, then projection bias is optional.
+ const bool projection_tensors_consistent =
+ ((projection_weights != nullptr) || (projection_bias == nullptr));
+ TF_LITE_ENSURE(context, projection_tensors_consistent == true);
+
+ return kTfLiteOk;
+}
+
+// Resize the output, state tensors based on the sizes of the input tensors.
+// Allocate a temporary scratch tensor. Also check that the sizes of the input
+// tensors match each other.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ // Inferring batch size, number of outputs and number of cells from the
+ // input tensors.
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE(context, input->dims->size > 1);
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+ const int n_cell = input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+ n_cell);
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Check that input tensor dimensions matches with each other.
+ TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+ n_output, n_cell));
+
+ // Get the pointer to output, activation_state and cell_state tensors.
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ const TfLiteTensor* activation_state =
+ GetInput(context, node, kInputActivationStateTensor);
+ const TfLiteTensor* cell_state =
+ GetInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+ // Resize the output tensors.
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
+ output_size->data[0] = n_batch;
+ output_size->data[1] = n_output;
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+
+ // The weights are of consistent type, so it suffices to check one.
+ const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
+ input->type == kTfLiteFloat32);
+
+ TfLiteIntArrayFree(node->temporaries);
+ if (is_hybrid_op) {
+ node->temporaries = TfLiteIntArrayCreate(7);
+ } else {
+ node->temporaries = TfLiteIntArrayCreate(1);
+ }
+ node->temporaries->data[0] = op_data->scratch_tensor_index;
+
+ // Create a scratch buffer tensor.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+ scratch_buffer->type = input->type;
+ scratch_buffer->allocation_type = kTfLiteArenaRw;
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+ scratch_buffer_size->data[0] = n_batch;
+ if (use_cifg) {
+ // Reserving space for Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 3;
+ } else {
+ // Reserving space for Input, Cell, Forget, Output gates
+ scratch_buffer_size->data[1] = n_cell * 4;
+ }
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+ scratch_buffer_size));
+
+ if (is_hybrid_op) {
+ // Allocate temporary tensors to store quantized values of input,
+ // activation_state and cell_state tensors.
+ node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+ node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
+ }
+ node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ cell_state_quantized->type = kTfLiteUInt8;
+ cell_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
+ TfLiteIntArray* cell_state_quantized_size =
+ TfLiteIntArrayCopy(cell_state->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, cell_state_quantized,
+ cell_state_quantized_size));
+ }
+
+ // Allocate temporary tensors to store scaling factors and product scaling
+ // factors. The latter is a convenience storage which allows to quantize
+ // a vector once (which produces the scaling factors) and multiply it with
+ // different matrices (which requires multiplying the scaling factors with
+ // the scaling factor of the matrix).
+ node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ prod_scaling_factors->type = kTfLiteFloat32;
+ prod_scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
+ prod_scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(prod_scaling_factors->dims,
+ prod_scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, prod_scaling_factors,
+ prod_scaling_factors_size));
+ }
+
+ // Allocate a temporary tensor to store the recovered weights. Since
+ // this is used for diagonal matrices, only need to store n_cell values.
+ node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
+ TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6);
+ recovered_weights->type = kTfLiteFloat32;
+ recovered_weights->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1);
+ recovered_weights_size->data[0] = n_cell;
+ if (!TfLiteIntArrayEqual(recovered_weights->dims, recovered_weights_size)) {
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, recovered_weights,
+ recovered_weights_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, float cell_clip, float proj_clip,
+ const TfLiteFusedActivation& activation, int n_batch, int n_cell,
+ int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+void LayerNormLstmStep(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale,
+ const float* input_layer_norm_weight_ptr,
+ const float* forget_layer_norm_weight_ptr,
+ const float* cell_layer_norm_weight_ptr,
+ const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ int n_batch, int n_cell, int n_input, int n_output,
+ float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
+ float* output_gate_scratch, float* scaling_factors,
+ float* product_scaling_factors, float* recovered_weights,
+ int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+
+ // Initialize scratch buffers with 0.
+ if (!use_cifg) {
+ tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
+ }
+ tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
+ tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // Save quantization and matmul computation for all zero input.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(input_gate_scratch,
+ input_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
+ n_cell, input_gate_scratch,
+ n_batch, input_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(forget_gate_scratch,
+ forget_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
+ n_cell, forget_gate_scratch,
+ n_batch, forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
+ n_batch, kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(
+ cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
+ tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
+ cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::MeanStddevNormalization(output_gate_scratch,
+ output_gate_scratch, n_cell, n_batch,
+ kLayerNormEpsilon);
+ tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
+ n_cell, output_gate_scratch,
+ n_batch, output_gate_scratch);
+ tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
+ output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// The LayerNormLSTM Op engine.
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
+ const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
+ const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
+ const float* recurrent_to_forget_weights_ptr =
+ recurrent_to_forget_weights->data.f;
+ const float* recurrent_to_cell_weights_ptr =
+ recurrent_to_cell_weights->data.f;
+ const float* recurrent_to_output_weights_ptr =
+ recurrent_to_output_weights->data.f;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
+ input_to_cell_weights_ptr, input_to_output_weights_ptr,
+ recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
+ recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
+ cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
+ cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
+ forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
+ output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, activation_state_ptr, cell_state_ptr,
+ input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights,
+ const TfLiteTensor* input_layer_norm_weights,
+ const TfLiteTensor* forget_layer_norm_weights,
+ const TfLiteTensor* cell_layer_norm_weights,
+ const TfLiteTensor* output_layer_norm_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ const int n_batch = input->dims->data[0];
+ const int n_input = input->dims->data[1];
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const float* input_ptr_batch = input->data.f;
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* input_layer_norm_weight_ptr = input_layer_norm_weights->data.f;
+ const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
+ const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
+ const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* activation_state_ptr = activation_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+ float* output_ptr_batch = output->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_weights_ptr = recovered_weights->data.f;
+
+ LayerNormLstmStep(
+ input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
+ recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
+ recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
+ recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
+ cell_to_input_weights_ptr, cell_to_input_weights_scale,
+ cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
+ cell_to_output_weights_ptr, cell_to_output_weights_scale,
+ input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
+ cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
+ input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
+ output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
+ projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
+ n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
+ output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_weights_ptr, quantized_input_ptr,
+ quantized_activation_state_ptr, quantized_cell_state_ptr,
+ activation_state_ptr, cell_state_ptr, output_ptr_batch);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+ const TfLiteTensor* input_to_input_weights =
+ GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+ const TfLiteTensor* input_to_forget_weights =
+ GetInput(context, node, kInputToForgetWeightsTensor);
+ const TfLiteTensor* input_to_cell_weights =
+ GetInput(context, node, kInputToCellWeightsTensor);
+ const TfLiteTensor* input_to_output_weights =
+ GetInput(context, node, kInputToOutputWeightsTensor);
+
+ const TfLiteTensor* recurrent_to_input_weights =
+ GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+ const TfLiteTensor* recurrent_to_forget_weights =
+ GetInput(context, node, kRecurrentToForgetWeightsTensor);
+ const TfLiteTensor* recurrent_to_cell_weights =
+ GetInput(context, node, kRecurrentToCellWeightsTensor);
+ const TfLiteTensor* recurrent_to_output_weights =
+ GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+ const TfLiteTensor* cell_to_input_weights =
+ GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+ const TfLiteTensor* cell_to_forget_weights =
+ GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+ const TfLiteTensor* cell_to_output_weights =
+ GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+ const TfLiteTensor* input_layer_norm_weights =
+ GetInput(context, node, kInputLayerNormWeightsTensor);
+ const TfLiteTensor* forget_layer_norm_weights =
+ GetInput(context, node, kForgetLayerNormWeightsTensor);
+ const TfLiteTensor* cell_layer_norm_weights =
+ GetInput(context, node, kCellLayerNormWeightsTensor);
+ const TfLiteTensor* output_layer_norm_weights =
+ GetInput(context, node, kOutputLayerNormWeightsTensor);
+
+ const TfLiteTensor* input_gate_bias =
+ GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+ const TfLiteTensor* forget_gate_bias =
+ GetInput(context, node, kForgetGateBiasTensor);
+ const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+ const TfLiteTensor* output_gate_bias =
+ GetInput(context, node, kOutputGateBiasTensor);
+
+ const TfLiteTensor* projection_weights =
+ GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+ const TfLiteTensor* projection_bias =
+ GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
+
+ TfLiteTensor* activation_state =
+ &context->tensors[node->inputs->data[kInputActivationStateTensor]];
+ TfLiteTensor* cell_state =
+ &context->tensors[node->inputs->data[kInputCellStateTensor]];
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (input_to_output_weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights,
+ cell_to_output_weights, input_layer_norm_weights,
+ forget_layer_norm_weights, cell_layer_norm_weights,
+ output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias,
+ projection_weights, projection_bias, op_data->cell_clip,
+ op_data->proj_clip, op_data->activation, scratch_buffer,
+ activation_state, cell_state, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
+ TfLiteTensor* activation_state_quantized =
+ GetTemporary(context, node, /*index=*/2);
+ TfLiteTensor* cell_state_quantized =
+ GetTemporary(context, node, /*index=*/3);
+ TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
+ TfLiteTensor* prod_scaling_factors =
+ GetTemporary(context, node, /*index=*/5);
+ TfLiteTensor* recovered_weights =
+ GetTemporary(context, node, /*index=*/6);
+ return EvalHybrid(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ input_layer_norm_weights, forget_layer_norm_weights,
+ cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, op_data->cell_clip, op_data->proj_clip,
+ op_data->activation, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_weights, input_quantized,
+ activation_state_quantized, cell_state_quantized, activation_state,
+ cell_state, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ input_to_output_weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+} // namespace layer_norm_lstm
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM() {
+ static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free,
+ layer_norm_lstm::Prepare,
+ layer_norm_lstm::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
new file mode 100644
index 0000000000..abc229f85a
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/layer_norm_lstm_test.cc
@@ -0,0 +1,664 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite Layer Norm LSTM op.
+
+#include <memory>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LayerNormLSTMOpModel : public SingleOpModel {
+ public:
+ LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights, bool use_projection_bias,
+ float cell_clip, float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes,
+ const TensorType& weight_type = TensorType_FLOAT32)
+ : n_batch_(n_batch),
+ n_input_(n_input),
+ n_cell_(n_cell),
+ n_output_(n_output) {
+ input_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_to_input_weights_ = AddNullInput();
+ } else {
+ input_to_input_weights_ = AddInput(weight_type);
+ }
+
+ input_to_forget_weights_ = AddInput(weight_type);
+ input_to_cell_weights_ = AddInput(weight_type);
+ input_to_output_weights_ = AddInput(weight_type);
+
+ if (use_cifg) {
+ recurrent_to_input_weights_ = AddNullInput();
+ } else {
+ recurrent_to_input_weights_ = AddInput(weight_type);
+ }
+
+ recurrent_to_forget_weights_ = AddInput(weight_type);
+ recurrent_to_cell_weights_ = AddInput(weight_type);
+ recurrent_to_output_weights_ = AddInput(weight_type);
+
+ if (use_peephole) {
+ if (use_cifg) {
+ cell_to_input_weights_ = AddNullInput();
+ } else {
+ cell_to_input_weights_ = AddInput(weight_type);
+ }
+ cell_to_forget_weights_ = AddInput(weight_type);
+ cell_to_output_weights_ = AddInput(weight_type);
+ } else {
+ cell_to_input_weights_ = AddNullInput();
+ cell_to_forget_weights_ = AddNullInput();
+ cell_to_output_weights_ = AddNullInput();
+ }
+
+ input_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+ output_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
+
+ if (use_cifg) {
+ input_gate_bias_ = AddNullInput();
+ } else {
+ input_gate_bias_ = AddInput(TensorType_FLOAT32);
+ }
+ forget_gate_bias_ = AddInput(TensorType_FLOAT32);
+ cell_bias_ = AddInput(TensorType_FLOAT32);
+ output_gate_bias_ = AddInput(TensorType_FLOAT32);
+
+ if (use_projection_weights) {
+ projection_weights_ = AddInput(weight_type);
+ if (use_projection_bias) {
+ projection_bias_ = AddInput(TensorType_FLOAT32);
+ } else {
+ projection_bias_ = AddNullInput();
+ }
+ } else {
+ projection_weights_ = AddNullInput();
+ projection_bias_ = AddNullInput();
+ }
+
+ // Adding the 2 state tensors.
+ output_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
+ cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
+
+ output_ = AddOutput(TensorType_FLOAT32);
+
+ // Set up and pass in custom options using flexbuffer.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("cell_clip", cell_clip);
+ fbb.Int("proj_clip", proj_clip);
+ fbb.String("fused_activation_function", "TANH");
+ });
+ fbb.Finish();
+ SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM);
+ BuildInterpreter(input_shapes);
+ }
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetInputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(input_gate_bias_, f);
+ }
+
+ void SetForgetGateBias(std::initializer_list<float> f) {
+ PopulateTensor(forget_gate_bias_, f);
+ }
+
+ void SetCellBias(std::initializer_list<float> f) {
+ PopulateTensor(cell_bias_, f);
+ }
+
+ void SetOutputGateBias(std::initializer_list<float> f) {
+ PopulateTensor(output_gate_bias_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ PopulateTensor(projection_weights_, f);
+ }
+
+ void SetProjectionBias(std::initializer_list<float> f) {
+ PopulateTensor(projection_bias_, f);
+ }
+
+ void SetInput(int offset, const float* begin, const float* end) {
+ PopulateTensor(input_, offset, const_cast<float*>(begin),
+ const_cast<float*>(end));
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ int num_inputs() { return n_input_; }
+ int num_outputs() { return n_output_; }
+ int num_cells() { return n_cell_; }
+ int num_batches() { return n_batch_; }
+
+ protected:
+ int input_;
+ int input_to_input_weights_;
+ int input_to_forget_weights_;
+ int input_to_cell_weights_;
+ int input_to_output_weights_;
+
+ int recurrent_to_input_weights_;
+ int recurrent_to_forget_weights_;
+ int recurrent_to_cell_weights_;
+ int recurrent_to_output_weights_;
+
+ int cell_to_input_weights_;
+ int cell_to_forget_weights_;
+ int cell_to_output_weights_;
+
+ int input_layer_norm_weights_;
+ int forget_layer_norm_weights_;
+ int cell_layer_norm_weights_;
+ int output_layer_norm_weights_;
+
+ int input_gate_bias_;
+ int forget_gate_bias_;
+ int cell_bias_;
+ int output_gate_bias_;
+
+ int projection_weights_;
+ int projection_bias_;
+
+ int output_state_;
+ int cell_state_;
+
+ int output_;
+
+ int n_batch_;
+ int n_input_;
+ int n_cell_;
+ int n_output_;
+};
+
+class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
+ public:
+ HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
+ bool use_cifg, bool use_peephole,
+ bool use_projection_weights,
+ bool use_projection_bias, float cell_clip,
+ float proj_clip,
+ const std::vector<std::vector<int>>& input_shapes)
+ : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
+ use_peephole, use_projection_weights,
+ use_projection_bias, cell_clip, proj_clip,
+ input_shapes, TensorType_UINT8) {}
+
+ void SetInputToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
+ }
+
+ void SetInputToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
+ }
+
+ void SetInputToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
+ }
+
+ void SetInputToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
+ }
+
+ void SetRecurrentToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
+ }
+
+ void SetRecurrentToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
+ }
+
+ void SetRecurrentToCellWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
+ }
+
+ void SetRecurrentToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
+ }
+
+ void SetCellToInputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
+ }
+
+ void SetCellToForgetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
+ }
+
+ void SetCellToOutputWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
+ }
+
+ void SetInputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(input_layer_norm_weights_, f);
+ }
+
+ void SetForgetLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(forget_layer_norm_weights_, f);
+ }
+
+ void SetCellLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(cell_layer_norm_weights_, f);
+ }
+
+ void SetOutputLayerNormWeights(std::initializer_list<float> f) {
+ PopulateTensor(output_layer_norm_weights_, f);
+ }
+
+ void SetProjectionWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(projection_weights_, f);
+ }
+};
+
+class BaseLayerNormLstmTest : public ::testing::Test {
+ protected:
+ // Weights of the Layer Norm LSTM model. Some are optional.
+ std::initializer_list<float> input_to_input_weights_;
+ std::initializer_list<float> input_to_cell_weights_;
+ std::initializer_list<float> input_to_forget_weights_;
+ std::initializer_list<float> input_to_output_weights_;
+ std::initializer_list<float> input_gate_bias_;
+ std::initializer_list<float> cell_gate_bias_;
+ std::initializer_list<float> forget_gate_bias_;
+ std::initializer_list<float> output_gate_bias_;
+ std::initializer_list<float> recurrent_to_input_weights_;
+ std::initializer_list<float> recurrent_to_cell_weights_;
+ std::initializer_list<float> recurrent_to_forget_weights_;
+ std::initializer_list<float> recurrent_to_output_weights_;
+ std::initializer_list<float> cell_to_input_weights_;
+ std::initializer_list<float> cell_to_forget_weights_;
+ std::initializer_list<float> cell_to_output_weights_;
+ std::initializer_list<float> input_layer_norm_weights_;
+ std::initializer_list<float> forget_layer_norm_weights_;
+ std::initializer_list<float> cell_layer_norm_weights_;
+ std::initializer_list<float> output_layer_norm_weights_;
+ std::initializer_list<float> projection_weights_;
+
+ // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
+ std::vector<std::vector<float>> layer_norm_lstm_input_;
+
+ // Compares output up to tolerance to the result of the layer_norm_lstm given
+ // the input.
+ void VerifyGoldens(const std::vector<std::vector<float>>& input,
+ const std::vector<std::vector<float>>& output,
+ LayerNormLSTMOpModel* layer_norm_lstm,
+ float tolerance = 1e-5) {
+ const int num_batches = input.size();
+ EXPECT_GT(num_batches, 0);
+ const int num_inputs = layer_norm_lstm->num_inputs();
+ EXPECT_GT(num_inputs, 0);
+ const int input_sequence_size = input[0].size() / num_inputs;
+ EXPECT_GT(input_sequence_size, 0);
+ for (int i = 0; i < input_sequence_size; ++i) {
+ for (int b = 0; b < num_batches; ++b) {
+ const float* batch_start = input[b].data() + i * num_inputs;
+ const float* batch_end = batch_start + num_inputs;
+
+ layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
+ batch_start, batch_end);
+ }
+
+ layer_norm_lstm->Invoke();
+
+ const int num_outputs = layer_norm_lstm->num_outputs();
+ std::vector<float> expected;
+ for (int b = 0; b < num_batches; ++b) {
+ const float* golden_start_batch = output[b].data() + i * num_outputs;
+ const float* golden_end_batch = golden_start_batch + num_outputs;
+ expected.insert(expected.end(), golden_start_batch, golden_end_batch);
+ }
+ EXPECT_THAT(layer_norm_lstm->GetOutput(),
+ ElementsAreArray(ArrayFloatNear(expected, tolerance)));
+ }
+ }
+};
+
+class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
+ : public BaseLayerNormLstmTest {
+ void SetUp() override {
+ input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
+ 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
+ -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
+
+ input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
+ -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
+ -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
+
+ input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
+ -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
+ -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
+
+ input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
+ -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
+ -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
+
+ input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
+
+ forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
+
+ cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
+
+ output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
+
+ recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
+ -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
+
+ recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
+ -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
+
+ recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
+ 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
+
+ recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
+ -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
+
+ cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
+
+ cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
+
+ cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
+
+ input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
+ forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
+ cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
+ output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
+
+ projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
+ 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
+
+ layer_norm_lstm_input_ = {
+ {// Batch0: 3 (input_sequence_size) * 5 (n_input)
+ 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
+ 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
+ 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
+
+ {// Batch1: 3 (input_sequence_size) * 5 (n_input)
+ 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
+ 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
+ 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
+ };
+ }
+};
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ LayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+ const float ceil_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ LayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false, ceil_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ // Verify the final output.
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244077, 0.128027, -0.00170918, // seq 0
+ 0.0137642, 0.140751, 0.0395835, // seq 1
+ -0.00459231, 0.155278, 0.0837377, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00692428, 0.0848741, 0.063445, // seq 0
+ -0.00403912, 0.139963, 0.072681, // seq 1
+ 0.00752706, 0.161903, 0.0561371, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
+ HybridLayerNormLstmBlackBoxTest) {
+ const int n_batch = 2;
+ const int n_input = 5;
+ const int n_cell = 4;
+ const int n_output = 3;
+ const float ceil_clip = 0.0;
+ const float proj_clip = 0.0;
+
+ HybridLayerNormLSTMOpModel layer_norm_lstm(
+ n_batch, n_input, n_cell, n_output,
+ /*use_cifg=*/false, /*use_peephole=*/true,
+ /*use_projection_weights=*/true,
+ /*use_projection_bias=*/false, ceil_clip, proj_clip,
+ {
+ {n_batch, n_input}, // input tensor
+
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {n_cell}, // cell_to_input_weight tensor
+ {n_cell}, // cell_to_forget_weight tensor
+ {n_cell}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_layer_norm_weight tensor
+ {n_cell}, // forget_layer_norm_weight tensor
+ {n_cell}, // cell_layer_norm_weight tensor
+ {n_cell}, // output_layer_norm_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {n_output, n_cell}, // projection_weight tensor
+ {0}, // projection_bias tensor
+ });
+
+ layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
+ layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
+ layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
+ layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
+
+ layer_norm_lstm.SetInputGateBias(input_gate_bias_);
+ layer_norm_lstm.SetCellBias(cell_gate_bias_);
+ layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
+ layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
+
+ layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
+ layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
+ layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
+ layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
+
+ layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
+ layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
+ layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
+
+ layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
+ layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
+ layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
+ layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
+
+ layer_norm_lstm.SetProjectionWeights(projection_weights_);
+
+ const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
+ {
+ // Batch0: 3 (input_sequence_size) * 3 (n_output)
+ 0.0244576, 0.127847, -0.00181765, // seq 0
+ 0.0137518, 0.140892, 0.0402234, // seq 1
+ -0.0048839, 0.155096, 0.0840309, // seq 2
+ },
+ {
+ // Batch1: 3 (input_sequence_size) * 3 (n_output)
+ -0.00728636, 0.0843957, 0.0634786, // seq 0
+ -0.00448382, 0.139278, 0.0737372, // seq 1
+ 0.00734616, 0.161793, 0.0560238, // seq 2
+ }};
+
+ VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
+ &layer_norm_lstm);
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/local_response_norm.cc b/tensorflow/contrib/lite/kernels/local_response_norm.cc
index 36dca299d0..334d2a2788 100644
--- a/tensorflow/contrib/lite/kernels/local_response_norm.cc
+++ b/tensorflow/contrib/lite/kernels/local_response_norm.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -64,11 +64,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
- type::LocalResponseNormalization( \
- GetTensorData<float>(input), GetTensorDims(input), params->radius, \
- params->bias, params->alpha, params->beta, GetTensorData<float>(output), \
- GetTensorDims(output))
+#define TF_LITE_LOCAL_RESPONSE_NORM(type) \
+ tflite::LocalResponseNormalizationParams op_params; \
+ op_params.range = params->radius; \
+ op_params.bias = params->bias; \
+ op_params.alpha = params->alpha; \
+ op_params.beta = params->beta; \
+ type::LocalResponseNormalization( \
+ op_params, GetTensorShape(input), GetTensorData<float>(input), \
+ GetTensorShape(output), GetTensorData<float>(output))
if (kernel_type == kReference) {
TF_LITE_LOCAL_RESPONSE_NORM(reference_ops);
}
diff --git a/tensorflow/contrib/lite/kernels/logical.cc b/tensorflow/contrib/lite/kernels/logical.cc
index 87c2fee667..f770cb35d1 100644
--- a/tensorflow/contrib/lite/kernels/logical.cc
+++ b/tensorflow/contrib/lite/kernels/logical.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -86,14 +86,14 @@ TfLiteStatus LogicalImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (data->requires_broadcast) {
- reference_ops::BroadcastLogical(
- GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output), func);
+ reference_ops::BroadcastLogical4DSlow(
+ GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output), func);
} else {
- reference_ops::Logical(GetTensorData<bool>(input1), GetTensorDims(input1),
- GetTensorData<bool>(input2), GetTensorDims(input2),
- GetTensorData<bool>(output), GetTensorDims(output),
+ reference_ops::Logical(GetTensorShape(input1), GetTensorData<bool>(input1),
+ GetTensorShape(input2), GetTensorData<bool>(input2),
+ GetTensorShape(output), GetTensorData<bool>(output),
func);
}
diff --git a/tensorflow/contrib/lite/kernels/lsh_projection.cc b/tensorflow/contrib/lite/kernels/lsh_projection.cc
index 69523b02cc..9fa1c5f100 100644
--- a/tensorflow/contrib/lite/kernels/lsh_projection.cc
+++ b/tensorflow/contrib/lite/kernels/lsh_projection.cc
@@ -59,8 +59,8 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include <farmhash.h>
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index ba251c451e..aaa3ce966e 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
@@ -37,7 +37,7 @@ namespace builtin {
namespace lstm {
struct OpData {
- // Which kernel type to use. Full kernel (18 or 20 inputs) or basic kernel
+ // Which kernel type to use. Full kernel (20 inputs) or basic kernel
// (5 inputs).
TfLiteLSTMKernelType kernel_type;
@@ -47,7 +47,7 @@ struct OpData {
int scratch_tensor_index;
};
-// For full inputs kernel (18 or 20 inputs).
+// For full inputs kernel (20-inputs).
namespace full {
// Input Tensors of size {n_batch, n_input}
@@ -81,19 +81,13 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
-// If the node has 20 inputs, the following 2 tensors are used as state tensors.
-// These are defined as variable tensors, and will be modified by this op.
+// These state tensors are defined as variable tensors, and will be modified by
+// this op.
constexpr int kInputActivationStateTensor = 18;
constexpr int kInputCellStateTensor = 19;
// Output tensors.
-// * If the node has 18 inputs, these 2 tensors are used as state tensors.
-// * If the node has 20 inputs, these 2 tensors are ignored.
-// TODO(ycling): Make the 2 output state tensors optional, and propagate the
-// state to output tensors when the 2 tensors present.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
@@ -258,30 +252,12 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
-
- // True if the node is using input variable state tensors. It means:
- // * The state tensors are defined as inputs. In this case it would be the
- // 19th and 20th input tensors.
- // * Otherwise, the output tensors are used to store states.
- bool use_input_variable_states;
- if (node->inputs->size == 20) {
- use_input_variable_states = true;
- op_data->activation_state_tensor_index =
- node->inputs->data[kInputActivationStateTensor];
- op_data->cell_state_tensor_index =
- node->inputs->data[kInputCellStateTensor];
- } else if (node->inputs->size == 18) {
- use_input_variable_states = false;
- op_data->activation_state_tensor_index =
- node->outputs->data[kOutputStateTensor];
- op_data->cell_state_tensor_index = node->outputs->data[kCellStateTensor];
- } else {
- context->ReportError(
- context, "The LSTM Full kernel expects 18 or 20 inputs. Got %d inputs",
- node->inputs->size);
- return kTfLiteError;
- }
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
+ op_data->cell_state_tensor_index = node->inputs->data[kInputCellStateTensor];
// Inferring batch size, number of outputs and number of cells from the
// input tensors.
@@ -316,31 +292,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* cell_state =
&context->tensors[op_data->cell_state_tensor_index];
- if (use_input_variable_states) {
- // Check the shape of input state tensors.
- // These tensor may be 1D or 2D. It's fine as long as the total size is
- // correct.
- TF_LITE_ENSURE_EQ(context, NumElements(activation_state),
- n_batch * n_output);
- TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
- } else {
- // If the state tensors are outputs, this function takes the
- // responsibility to resize the state tensors.
- TfLiteIntArray* activation_state_size = TfLiteIntArrayCreate(2);
- activation_state_size->data[0] = n_batch;
- activation_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
- activation_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
- // Mark state tensors as persistent tensors.
- activation_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
- }
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
// Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc
index 0266f5fe57..e7ddfceb45 100644
--- a/tensorflow/contrib/lite/kernels/lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/lstm_test.cc
@@ -106,14 +106,13 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
cell_clip, proj_clip)
.Union());
+
BuildInterpreter(input_shapes);
}
@@ -185,22 +184,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -469,10 +452,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -529,10 +508,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -637,10 +612,6 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -698,14 +669,10 @@ TEST_F(CifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
-class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
+class NoCifgPeepholeProjectionNoClippingLstmTest : public BaseLstmTest {
void SetUp() override {
input_to_input_weights_ = {
0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463,
@@ -1304,7 +1271,7 @@ class NoCifgPeepholeProjectionClippingLstmTest : public BaseLstmTest {
}
};
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, LstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -1362,14 +1329,10 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
-TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
+TEST_F(NoCifgPeepholeProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
const int n_batch = 2;
const int n_input = 5;
const int n_cell = 20;
@@ -1428,10 +1391,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
diff --git a/tensorflow/contrib/lite/kernels/maximum_minimum.cc b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
index 8d676218bd..7cb01465ee 100644
--- a/tensorflow/contrib/lite/kernels/maximum_minimum.cc
+++ b/tensorflow/contrib/lite/kernels/maximum_minimum.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -86,13 +86,14 @@ struct MinimumOp {
template <typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
const OpContext& op_context) {
- reference_ops::TensorFlowMaximumMinimum<data_type>(
+ reference_ops::MaximumMinimumBroadcast4DSlow(
+ GetTensorShape(op_context.input1),
GetTensorData<data_type>(op_context.input1),
- GetTensorDims(op_context.input1),
+ GetTensorShape(op_context.input2),
GetTensorData<data_type>(op_context.input2),
- GetTensorDims(op_context.input2),
+ GetTensorShape(op_context.output),
GetTensorData<data_type>(op_context.output),
- GetTensorDims(op_context.output), op_type::template op<data_type>);
+ op_type::template op<data_type>);
}
template <KernelType kernel_type, typename OpType>
diff --git a/tensorflow/contrib/lite/kernels/mfcc.cc b/tensorflow/contrib/lite/kernels/mfcc.cc
index 3f5bc4d68a..66cf147d75 100644
--- a/tensorflow/contrib/lite/kernels/mfcc.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/kernels/internal/mfcc.h"
-#include "flatbuffers/flexbuffers.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_dct.h"
#include "tensorflow/contrib/lite/kernels/internal/mfcc_mel_filterbank.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/mfcc_test.cc b/tensorflow/contrib/lite/kernels/mfcc_test.cc
index 0291ca8c1c..c9124adcaf 100644
--- a/tensorflow/contrib/lite/kernels/mfcc_test.cc
+++ b/tensorflow/contrib/lite/kernels/mfcc_test.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <vector>
#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h"
+#include "flatbuffers/flexbuffers.h" // flatbuffers
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index 561e39cfc6..e0aac8a842 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
@@ -102,24 +102,28 @@ template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
const OpData* data, const TfLiteTensor* input1,
const TfLiteTensor* input2, TfLiteTensor* output) {
-#define TF_LITE_MUL(type, opname, data_type) \
- data_type output_activation_min, output_activation_max; \
- CalculateActivationRange(params->activation, &output_activation_min, \
- &output_activation_max); \
- type::opname(GetTensorData<data_type>(input1), GetTensorDims(input1), \
- GetTensorData<data_type>(input2), GetTensorDims(input2), \
- output_activation_min, output_activation_max, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+#define TF_LITE_MUL(type, opname, data_type) \
+ data_type output_activation_min, output_activation_max; \
+ CalculateActivationRange(params->activation, &output_activation_min, \
+ &output_activation_max); \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(output_activation_min, output_activation_max, \
+ &op_params); \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<data_type>(input1), GetTensorShape(input2), \
+ GetTensorData<data_type>(input2), GetTensorShape(output), \
+ GetTensorData<data_type>(output))
+
if (output->type == kTfLiteInt32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(reference_ops, Mul, int32_t);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, int32_t);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
} else {
TF_LITE_MUL(optimized_ops, Mul, int32_t);
}
@@ -127,13 +131,13 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
} else if (output->type == kTfLiteFloat32) {
if (kernel_type == kReference) {
if (data->requires_broadcast) {
- TF_LITE_MUL(reference_ops, BroadcastMul, float);
+ TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(reference_ops, Mul, float);
}
} else {
if (data->requires_broadcast) {
- TF_LITE_MUL(optimized_ops, BroadcastMul, float);
+ TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, float);
} else {
TF_LITE_MUL(optimized_ops, Mul, float);
}
@@ -149,14 +153,20 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* input2, TfLiteTensor* output) {
if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), -input2->params.zero_point, \
- output->params.zero_point, data->output_multiplier, \
- data->output_shift, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.input1_offset = -input1->params.zero_point; \
+ op_params.input2_offset = -input2->params.zero_point; \
+ op_params.output_offset = output->params.zero_point; \
+ op_params.output_multiplier = data->output_multiplier; \
+ op_params.output_shift = data->output_shift; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<uint8_t>(input1), GetTensorShape(input2), \
+ GetTensorData<uint8_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
+
// The quantized version of Mul doesn't support activations, so we
// always use BroadcastMul.
if (kernel_type == kReference) {
@@ -167,10 +177,12 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteInt16) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- GetTensorData<int16_t>(output), GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<int16_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
@@ -179,12 +191,15 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
#undef TF_LITE_MUL
} else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
output->type == kTfLiteUInt8) {
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
- GetTensorData<int16_t>(input2), GetTensorDims(input2), \
- output->params.zero_point, data->output_activation_min, \
- data->output_activation_max, GetTensorData<uint8_t>(output), \
- GetTensorDims(output));
+#define TF_LITE_MUL(type, opname) \
+ tflite::ArithmeticParams op_params; \
+ SetActivationParams(data->output_activation_min, \
+ data->output_activation_max, &op_params); \
+ op_params.output_offset = output->params.zero_point; \
+ type::opname(op_params, GetTensorShape(input1), \
+ GetTensorData<int16_t>(input1), GetTensorShape(input2), \
+ GetTensorData<int16_t>(input2), GetTensorShape(output), \
+ GetTensorData<uint8_t>(output))
if (kernel_type == kReference) {
TF_LITE_MUL(reference_ops, Mul);
} else {
diff --git a/tensorflow/contrib/lite/kernels/neg.cc b/tensorflow/contrib/lite/kernels/neg.cc
index 4124c05388..0ddd0644f5 100644
--- a/tensorflow/contrib/lite/kernels/neg.cc
+++ b/tensorflow/contrib/lite/kernels/neg.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/one_hot.cc b/tensorflow/contrib/lite/kernels/one_hot.cc
index 9ff3dca932..910aed6f14 100644
--- a/tensorflow/contrib/lite/kernels/one_hot.cc
+++ b/tensorflow/contrib/lite/kernels/one_hot.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
index 1c728a4733..90a915bb02 100644
--- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
+++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc
@@ -101,8 +101,6 @@ class LSTMOpModel : public SingleOpModel {
input_cell_state_ =
AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_LSTM, BuiltinOptions_LSTMOptions,
@@ -180,22 +178,6 @@ class LSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, float* begin, float* end) {
PopulateTensor(input_, offset, begin, end);
}
@@ -238,8 +220,6 @@ class LSTMOpModel : public SingleOpModel {
int input_cell_state_;
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -324,10 +304,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
lstm.SetCellToOutputWeights(
{-0.17135078, 0.82760304, 0.85573703, -0.77109635});
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
// Verify the model by unpacking it.
lstm.Verify();
}
diff --git a/tensorflow/contrib/lite/kernels/pack.cc b/tensorflow/contrib/lite/kernels/pack.cc
index cc326a7d51..4cb98fdd19 100644
--- a/tensorflow/contrib/lite/kernels/pack.cc
+++ b/tensorflow/contrib/lite/kernels/pack.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc
index 4be8c243c1..0d939405f6 100644
--- a/tensorflow/contrib/lite/kernels/pad.cc
+++ b/tensorflow/contrib/lite/kernels/pad.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -92,8 +92,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
op_context.constant_values->type);
}
- // TODO(nupurgarg): Our current implementations rely on the inputs being 4D.
- TF_LITE_ENSURE_EQ(context, op_context.dims, 4);
+ // TODO(nupurgarg): Current implementations rely on the inputs being <= 4D.
+ TF_LITE_ENSURE(context, op_context.dims <= 4);
// Exit early if paddings is a non-const tensor. Set output tensor to
// dynamic so output size can be determined in Eval.
@@ -134,12 +134,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
after_padding.push_back(paddings_data[idx * 2 + 1]);
}
-#define TF_LITE_PAD(type, scalar, pad_value) \
- type::PadV2(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), before_padding, after_padding, \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
-
+#define TF_LITE_PAD(type, scalar, pad_value) \
+ TF_LITE_ENSURE(context, before_padding.size() <= 4); \
+ TF_LITE_ENSURE(context, after_padding.size() <= 4); \
+ tflite::PadParams op_params; \
+ op_params.left_padding_count = before_padding.size(); \
+ op_params.right_padding_count = after_padding.size(); \
+ for (int i = 0; i < op_context.dims; ++i) { \
+ op_params.left_padding[i] = before_padding[op_context.dims - 1 - i]; \
+ op_params.right_padding[i] = after_padding[op_context.dims - 1 - i]; \
+ } \
+ const scalar pad_value_copy = pad_value; \
+ \
+ type::Pad(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), &pad_value_copy, \
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: {
float pad_value = op_context.constant_values == nullptr
diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc
index f8b9064fbb..f663899713 100644
--- a/tensorflow/contrib/lite/kernels/pad_test.cc
+++ b/tensorflow/contrib/lite/kernels/pad_test.cc
@@ -193,7 +193,7 @@ TEST(PadOpTest, TooManyDimensions) {
PadOpConstModel({TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9},
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadOpTest, UnequalDimensions) {
@@ -221,6 +221,15 @@ TEST(PadOpTest, SimpleConstTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1}));
}
+TEST(PadOpTest, SimpleConst1DTest) {
+ PadOpConstModel m({TensorType_FLOAT32, {2}}, {1, 2}, {1, 2},
+ {TensorType_FLOAT32});
+ m.SetInput({2, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 3, 0, 0}));
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
+}
+
TEST(PadOpTest, SimpleDynamicTest) {
PadOpDynamicModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, {4, 2},
{TensorType_FLOAT32});
@@ -334,7 +343,7 @@ TEST(PadV2OpTest, TooManyDimensions) {
{TensorType_FLOAT32, {1, 2, 3, 4, 5, 6, 7, 8, 9}}, {9, 2},
{1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}, 0.0,
{TensorType_FLOAT32}),
- "dims != 4");
+ "dims <= 4");
}
TEST(PadV2OpTest, UnequalDimensions) {
diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h
index 3cb55f19a9..42b6b45d3b 100644
--- a/tensorflow/contrib/lite/kernels/padding.h
+++ b/tensorflow/contrib/lite/kernels/padding.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/kernels/pooling.cc b/tensorflow/contrib/lite/kernels/pooling.cc
index 29a5be0683..6451142391 100644
--- a/tensorflow/contrib/lite/kernels/pooling.cc
+++ b/tensorflow/contrib/lite/kernels/pooling.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/pow.cc b/tensorflow/contrib/lite/kernels/pow.cc
index 4a539c47a8..1e96cc80b1 100644
--- a/tensorflow/contrib/lite/kernels/pow.cc
+++ b/tensorflow/contrib/lite/kernels/pow.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -80,14 +80,14 @@ template <typename T>
void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
if (requires_broadcast) {
- reference_ops::BroadcastPow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output),
- GetTensorDims(output));
+ reference_ops::BroadcastPow4DSlow(
+ GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
} else {
- reference_ops::Pow(GetTensorData<T>(input1), GetTensorDims(input1),
- GetTensorData<T>(input2), GetTensorDims(input2),
- GetTensorData<T>(output), GetTensorDims(output));
+ reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
+ GetTensorShape(input2), GetTensorData<T>(input2),
+ GetTensorShape(output), GetTensorData<T>(output));
}
}
diff --git a/tensorflow/contrib/lite/kernels/reduce.cc b/tensorflow/contrib/lite/kernels/reduce.cc
index 29374a0c27..d94d821e87 100644
--- a/tensorflow/contrib/lite/kernels/reduce.cc
+++ b/tensorflow/contrib/lite/kernels/reduce.cc
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
+#include <limits>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -177,6 +178,9 @@ TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
case kTfLiteUInt8:
temp_sum->type = kTfLiteInt32;
break;
+ case kTfLiteBool:
+ temp_sum->type = kTfLiteBool;
+ break;
default:
return kTfLiteError;
}
@@ -204,6 +208,13 @@ TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
+TfLiteStatus PrepareAny(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteBool);
+ return PrepareSimple(context, node);
+}
+
TfLiteStatus PrepareMean(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
@@ -256,11 +267,27 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, int64_t, int64_t));
break;
case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
+ if (op_context.input->params.zero_point ==
+ op_context.output->params.zero_point &&
+ op_context.input->params.scale == op_context.output->params.scale) {
+ TF_LITE_ENSURE(context, TF_LITE_MEAN(reference_ops, uint8_t, int));
+ } else {
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::Mean<>(
+ GetTensorData<uint8_t>(op_context.input),
+ op_context.input->params.zero_point,
+ op_context.input->params.scale, op_context.input->dims->data,
+ op_context.input->dims->size,
+ GetTensorData<uint8_t>(op_context.output),
+ op_context.output->params.zero_point,
+ op_context.output->params.scale,
+ op_context.output->dims->data, op_context.output->dims->size,
+ GetTensorData<int>(op_context.axis), num_axis,
+ op_context.params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis),
+ GetTensorData<int>(temp_sum)));
+ }
break;
default:
return kTfLiteError;
@@ -270,196 +297,125 @@ TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int num_axis = static_cast<int>(NumElements(op_context.axis));
+// The underlying logic for Reduce Sum/Prod/Max/Min/Any
+template <typename T>
+TfLiteStatus EvalLogic(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, T init_value,
+ T reducer(const T current, const T in)) {
+ int64_t num_axis = NumElements(op_context->axis);
TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
// Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
+ if (IsDynamicTensor(op_context->output)) {
TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+ ResizeTempAxis(context, op_context, resolved_axis));
+ TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
}
-
-#define TF_LITE_SUM(kernel_type, data_type) \
- kernel_type::Sum<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_SUM(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+ if (op_context->input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.scale,
+ op_context->output->params.scale);
+ TF_LITE_ENSURE_EQ(context, op_context->input->params.zero_point,
+ op_context->output->params.zero_point);
}
-#undef TF_LITE_SUM
+ TF_LITE_ENSURE(
+ context,
+ reference_ops::ReduceGeneric<T>(
+ GetTensorData<T>(op_context->input), op_context->input->dims->data,
+ op_context->input->dims->size, GetTensorData<T>(op_context->output),
+ op_context->output->dims->data, op_context->output->dims->size,
+ GetTensorData<int>(op_context->axis), num_axis,
+ op_context->params->keep_dims, GetTensorData<int>(temp_index),
+ GetTensorData<int>(resolved_axis), init_value, reducer));
return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_PROD(kernel_type, data_type) \
- kernel_type::ReduceProd<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
+enum ReduceType {
+ kSum,
+ kProd,
+ kMax,
+ kMin,
+ kAny,
+};
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_PROD(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- // TODO(wangtz): uint8 reduce_prod is not yet supported.
- default:
- return kTfLiteError;
- }
+// Eval for determined input type and reduce type.
+template <typename T>
+TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kSum:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(0),
+ [](const T current, const T in) -> T { return in + current; });
+ break;
+ case kProd:
+ return EvalLogic<T>(
+ context, node, op_context, static_cast<T>(1),
+ [](const T current, const T in) -> T { return in * current; });
+ break;
+ case kMax:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::lowest(),
+ [](const T current, const T in) -> T {
+ return (in > current) ? in : current;
+ });
+ break;
+ case kMin:
+ return EvalLogic<T>(context, node, op_context,
+ std::numeric_limits<T>::max(),
+ [](const T current, const T in) -> T {
+ return (in < current) ? in : current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_PROD
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMax(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
- }
-
-#define TF_LITE_MAX(kernel_type, data_type) \
- kernel_type::ReduceMax<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MAX(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+// Template specialization for bool type
+template <>
+TfLiteStatus EvalType<bool>(TfLiteContext* context, TfLiteNode* node,
+ OpContext* op_context, ReduceType reduce_type) {
+ switch (reduce_type) {
+ case kAny:
+ return EvalLogic<bool>(context, node, op_context, false,
+ [](const bool current, const bool in) -> bool {
+ return in || current;
+ });
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_MAX
- return kTfLiteOk;
}
-template <KernelType kernel_type>
-TfLiteStatus EvalMin(TfLiteContext* context, TfLiteNode* node) {
- OpContext op_context(context, node);
- int64_t num_axis = NumElements(op_context.axis);
- TfLiteTensor* temp_index = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* resolved_axis = GetTemporary(context, node, /*index=*/1);
- // Resize the output tensor if the output tensor is dynamic.
- if (IsDynamicTensor(op_context.output)) {
- TF_LITE_ENSURE_OK(context,
- ResizeTempAxis(context, &op_context, resolved_axis));
- TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
+// The entry point that handles input types and then calls template functions to
+// handle ReduceType.
+template <KernelType kernel_type, ReduceType reduce_type>
+TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
+ if (kernel_type != kReference) {
+ return kTfLiteOk;
}
-
-#define TF_LITE_MIN(kernel_type, data_type) \
- kernel_type::ReduceMin<>( \
- GetTensorData<data_type>(op_context.input), \
- op_context.input->dims->data, op_context.input->dims->size, \
- GetTensorData<data_type>(op_context.output), \
- op_context.output->dims->data, op_context.output->dims->size, \
- GetTensorData<int>(op_context.axis), num_axis, \
- op_context.params->keep_dims, GetTensorData<int>(temp_index), \
- GetTensorData<int>(resolved_axis))
-
- if (kernel_type == kReference) {
- switch (op_context.input->type) {
- case kTfLiteFloat32:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, float));
- break;
- case kTfLiteInt32:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int));
- break;
- case kTfLiteInt64:
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, int64_t));
- break;
- case kTfLiteUInt8:
- TF_LITE_ENSURE_EQ(context, op_context.input->params.scale,
- op_context.output->params.scale);
- TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point,
- op_context.output->params.zero_point);
- TF_LITE_ENSURE(context, TF_LITE_MIN(reference_ops, uint8_t));
- break;
- default:
- return kTfLiteError;
- }
+ OpContext op_context(context, node);
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ return EvalType<float>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt32:
+ return EvalType<int>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteInt64:
+ return EvalType<int64_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteUInt8:
+ return EvalType<uint8_t>(context, node, &op_context, reduce_type);
+ break;
+ case kTfLiteBool:
+ return EvalType<bool>(context, node, &op_context, reduce_type);
+ break;
+ default:
+ return kTfLiteError;
}
-#undef TF_LITE_MIN
- return kTfLiteOk;
}
+
} // namespace reduce
TfLiteRegistration* Register_MEAN_REF() {
@@ -470,30 +426,37 @@ TfLiteRegistration* Register_MEAN_REF() {
}
TfLiteRegistration* Register_SUM_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalSum<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kSum>};
return &r;
}
TfLiteRegistration* Register_REDUCE_PROD_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalProd<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kProd>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MAX_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMax<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMax>};
return &r;
}
TfLiteRegistration* Register_REDUCE_MIN_REF() {
- static TfLiteRegistration r = {reduce::Init, reduce::Free,
- reduce::PrepareSimple,
- reduce::EvalMin<reduce::kReference>};
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareSimple,
+ reduce::EvalGeneric<reduce::kReference, reduce::kMin>};
+ return &r;
+}
+
+TfLiteRegistration* Register_REDUCE_ANY_REF() {
+ static TfLiteRegistration r = {
+ reduce::Init, reduce::Free, reduce::PrepareAny,
+ reduce::EvalGeneric<reduce::kReference, reduce::kAny>};
return &r;
}
@@ -505,6 +468,7 @@ TfLiteRegistration* Register_REDUCE_PROD() {
}
TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_REF(); }
TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_REF(); }
+TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_REF(); }
} // namespace builtin
} // namespace ops
diff --git a/tensorflow/contrib/lite/kernels/reduce_test.cc b/tensorflow/contrib/lite/kernels/reduce_test.cc
index d9aca64356..6d289b14d8 100644
--- a/tensorflow/contrib/lite/kernels/reduce_test.cc
+++ b/tensorflow/contrib/lite/kernels/reduce_test.cc
@@ -198,6 +198,35 @@ class MinOpDynamicModel : public BaseOpModel {
}
};
+// Model for the tests case where axis is a const tensor.
+class AnyOpConstModel : public BaseOpModel {
+ public:
+ AnyOpConstModel(const TensorData& input, const TensorData& output,
+ std::initializer_list<int> axis_shape,
+ std::initializer_list<int> axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddConstInput(TensorType_INT32, axis, axis_shape);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
+// Model for the tests case where axis is a dynamic tensor.
+class AnyOpDynamicModel : public BaseOpModel {
+ public:
+ AnyOpDynamicModel(const TensorData& input, const TensorData& output,
+ const TensorData& axis, bool keep_dims) {
+ input_ = AddInput(input);
+ axis_ = AddInput(axis);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_REDUCE_ANY, BuiltinOptions_ReducerOptions,
+ CreateReducerOptions(builder_, keep_dims).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+};
+
// for quantized Add, the error shouldn't exceed step
float GetTolerance(int min, int max) { return (max - min) / 255.0; }
@@ -338,6 +367,33 @@ TEST(DynamicUint8MeanOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({9.2815, 0.3695}, kQuantizedTolerance)));
}
+TEST(DynamicUint8MeanOpTest, QuantizedScalar) {
+ float kQuantizedTolerance = GetTolerance(-10.0, 12.0);
+ std::vector<float> data = {0.643};
+ MeanOpDynamicModel m({TensorType_UINT8, {}, 0.0, 1.0},
+ {TensorType_UINT8, {}, -10.0, 12.0},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.643}, kQuantizedTolerance)));
+}
+
+TEST(ConstUint8MeanOpTest, QuantizedKeepDims) {
+ float kQuantizedTolerance = GetTolerance(-5.0, 5.0);
+ std::vector<float> data = {0.4, 0.2, 0.3, 0.4, 0.5, 0.6};
+ MeanOpConstModel m({TensorType_UINT8, {3, 2}, 0.0, 1.0},
+ {TensorType_UINT8, {3}, -5.0, 5.0}, {1}, {1}, true);
+ m.QuantizeAndPopulate<uint8_t>(m.Input(), data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1}));
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({0.3, 0.35, 0.55}, kQuantizedTolerance)));
+}
+
// Tests for reduce_sum
TEST(ConstFloatSumOpTest, NotKeepDims) {
@@ -751,7 +807,7 @@ TEST(DynamicFloatMinOpTest, KeepDims) {
ElementsAreArray(ArrayFloatNear({1, 3, 5})));
}
-TEST(DynamicFloatMinOpTest, Scale) {
+TEST(DynamicFloatMinOpTest, Scalar) {
std::vector<float> data = {9.527};
MinOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}},
{TensorType_INT32, {1}}, true);
@@ -835,6 +891,68 @@ TEST(DynamicUint8MinOpTest, Scalar) {
ElementsAreArray(ArrayFloatNear({11.1294}, kQuantizedTolerance)));
}
+// Tests for reduce_any
+
+TEST(ConstAnyOpTest, NotKeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}}, {4},
+ {1, 0, -3, -3}, false);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false, true}));
+}
+
+TEST(ConstAnyOpTest, KeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpConstModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}}, {2},
+ {0, 2}, true);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
+}
+
+TEST(DynamicAnyOpTest, NotKeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {2}},
+ {TensorType_INT32, {4}}, false);
+ std::vector<int> axis = {1, 0, -3, -3};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false, true}));
+}
+
+TEST(DynamicAnyOpTest, KeepDims) {
+ std::vector<bool> data = {false, false, false, false, false, false,
+ false, true, false, false, false, true};
+ AnyOpDynamicModel m({TensorType_BOOL, {2, 3, 2}}, {TensorType_BOOL, {3}},
+ {TensorType_INT32, {2}}, true);
+ std::vector<int> axis = {0, 2};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({true, false, true}));
+}
+
+TEST(DynamicAnyOpTest, Scalar) {
+ std::vector<bool> data = {false};
+ AnyOpDynamicModel m({TensorType_BOOL, {1}}, {TensorType_BOOL, {1}},
+ {TensorType_INT32, {1}}, true);
+ std::vector<int> axis = {0};
+ m.SetAxis(axis);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1}));
+ EXPECT_THAT(m.GetOutput<bool>(), ElementsAreArray({false}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc
index 10d1fcc5bc..c66959fdf4 100644
--- a/tensorflow/contrib/lite/kernels/register.cc
+++ b/tensorflow/contrib/lite/kernels/register.cc
@@ -22,8 +22,10 @@ namespace ops {
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
+TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
+TfLiteRegistration* Register_RELU_1();
} // namespace custom
@@ -95,6 +97,7 @@ TfLiteRegistration* Register_SUM();
TfLiteRegistration* Register_REDUCE_PROD();
TfLiteRegistration* Register_REDUCE_MAX();
TfLiteRegistration* Register_REDUCE_MIN();
+TfLiteRegistration* Register_REDUCE_ANY();
TfLiteRegistration* Register_SELECT();
TfLiteRegistration* Register_SLICE();
TfLiteRegistration* Register_SIN();
@@ -113,6 +116,8 @@ TfLiteRegistration* Register_ONE_HOT();
TfLiteRegistration* Register_LOGICAL_OR();
TfLiteRegistration* Register_LOGICAL_AND();
TfLiteRegistration* Register_LOGICAL_NOT();
+TfLiteRegistration* Register_UNPACK();
+TfLiteRegistration* Register_FLOOR_DIV();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -221,6 +226,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_REDUCE_PROD, Register_REDUCE_PROD());
AddBuiltin(BuiltinOperator_REDUCE_MAX, Register_REDUCE_MAX());
AddBuiltin(BuiltinOperator_REDUCE_MIN, Register_REDUCE_MIN());
+ AddBuiltin(BuiltinOperator_REDUCE_ANY, Register_REDUCE_ANY());
AddBuiltin(BuiltinOperator_EXPAND_DIMS, Register_EXPAND_DIMS());
AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE());
AddBuiltin(BuiltinOperator_EQUAL, Register_EQUAL());
@@ -235,12 +241,16 @@ BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_LOGICAL_OR, Register_LOGICAL_OR());
AddBuiltin(BuiltinOperator_LOGICAL_AND, Register_LOGICAL_AND());
AddBuiltin(BuiltinOperator_LOGICAL_NOT, Register_LOGICAL_NOT());
+ AddBuiltin(BuiltinOperator_UNPACK, Register_UNPACK());
+ AddBuiltin(BuiltinOperator_FLOOR_DIV, Register_FLOOR_DIV());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
+ AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
+ AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
diff --git a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h
index 0296152d68..61856ab9de 100644
--- a/tensorflow/contrib/lite/kernels/register.h
+++ b/tensorflow/contrib/lite/kernels/register.h
@@ -16,8 +16,9 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_
#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
namespace ops {
diff --git a/tensorflow/contrib/lite/kernels/relu1.cc b/tensorflow/contrib/lite/kernels/relu1.cc
new file mode 100644
index 0000000000..abafee2d57
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace relu1 {
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ output->type = input->type;
+ return context->ResizeTensor(context, output,
+ TfLiteIntArrayCopy(input->dims));
+}
+
+// This is derived from lite/kernels/activations.cc.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+ const int elements = NumElements(input);
+ const float* in = input->data.f;
+ const float* in_end = in + elements;
+ float* out = output->data.f;
+ for (; in < in_end; ++in, ++out) {
+ *out = std::min(std::max(0.f, *in), 1.f);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace relu1
+
+TfLiteRegistration* Register_RELU_1() {
+ static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+ relu1::Prepare, relu1::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/relu1_test.cc b/tensorflow/contrib/lite/kernels/relu1_test.cc
new file mode 100644
index 0000000000..c1e0149c20
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/relu1_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_RELU_1();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseActivationsOpModel : public SingleOpModel {
+ public:
+ explicit BaseActivationsOpModel(const TensorData& input) {
+ input_ = AddInput(input);
+ output_ = AddOutput({input.type, {}});
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+ SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1);
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ protected:
+ int input_;
+ int output_;
+};
+
+class FloatActivationsOpModel : public BaseActivationsOpModel {
+ public:
+ using BaseActivationsOpModel::BaseActivationsOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+TEST(FloatActivationsOpTest, Relu1) {
+ FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
+ m.SetInput({
+ 0.0, -0.6, 0.2, -0.4, //
+ 0.3, -2.0, 1.1, -0.1, //
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+ 0.0, 0.0, 0.2, 0.0, //
+ 0.3, 0.0, 1.0, 0.0, //
+ }));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/reshape.cc b/tensorflow/contrib/lite/kernels/reshape.cc
index 49ba0571e2..f41147b2d6 100644
--- a/tensorflow/contrib/lite/kernels/reshape.cc
+++ b/tensorflow/contrib/lite/kernels/reshape.cc
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
index 86c4cd3ee8..fb045d15f3 100644
--- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc
+++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -88,11 +88,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
- type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<datatype>(output), GetTensorDims(output), \
- params->align_corners)
+#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
+ tflite::ResizeBilinearParams op_params; \
+ op_params.align_corners = params->align_corners; \
+ type::ResizeBilinear(op_params, GetTensorShape(input), \
+ GetTensorData<datatype>(input), GetTensorShape(size), \
+ GetTensorData<int32>(size), GetTensorShape(output), \
+ GetTensorData<datatype>(output))
if (kernel_type == kReference) {
TF_LITE_RESIZE_BILINEAR(reference_ops, float);
diff --git a/tensorflow/contrib/lite/kernels/select.cc b/tensorflow/contrib/lite/kernels/select.cc
index 3cdb5db209..3959502d91 100644
--- a/tensorflow/contrib/lite/kernels/select.cc
+++ b/tensorflow/contrib/lite/kernels/select.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/shape.cc b/tensorflow/contrib/lite/kernels/shape.cc
index dbcd2ef004..66d4c9e5c1 100644
--- a/tensorflow/contrib/lite/kernels/shape.cc
+++ b/tensorflow/contrib/lite/kernels/shape.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/skip_gram.cc b/tensorflow/contrib/lite/kernels/skip_gram.cc
index c90a15b3a2..de80a4016e 100644
--- a/tensorflow/contrib/lite/kernels/skip_gram.cc
+++ b/tensorflow/contrib/lite/kernels/skip_gram.cc
@@ -33,8 +33,8 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
#include "tensorflow/contrib/lite/string_util.h"
diff --git a/tensorflow/contrib/lite/kernels/slice.cc b/tensorflow/contrib/lite/kernels/slice.cc
index 6a20e802a9..ccfee41b9c 100644
--- a/tensorflow/contrib/lite/kernels/slice.cc
+++ b/tensorflow/contrib/lite/kernels/slice.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -159,10 +159,28 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
sizes.push_back(1);
}
-#define TF_LITE_SLICE(data_type) \
- optimized_ops::Slice<data_type>( \
- GetTensorData<data_type>(input), GetTensorDims(input), begins, sizes, \
- GetTensorData<data_type>(output), GetTensorDims(output))
+ // The original Slice op implementation only accepted 4-D sizes. That
+ // constraint is, for the present, maintained here.
+ //
+ // The dimensions in the kernel used to be in reverse-order, and TFLite
+ // arranged the begins and sizes vectors accordingly. This macro incorporates
+ // the needed reversing.
+#define TF_LITE_SLICE(data_type) \
+ { \
+ TF_LITE_ENSURE_EQ(context, begins.size(), 4); \
+ TF_LITE_ENSURE_EQ(context, sizes.size(), 4); \
+ tflite::SliceParams op_params; \
+ op_params.begin_count = 4; \
+ op_params.size_count = 4; \
+ for (int i = 0; i < 4; ++i) { \
+ op_params.begin[i] = begins[3 - i]; \
+ op_params.size[i] = sizes[3 - i]; \
+ } \
+ \
+ optimized_ops::Slice<data_type>( \
+ op_params, GetTensorShape(input), GetTensorData<data_type>(input), \
+ GetTensorShape(output), GetTensorData<data_type>(output)); \
+ }
switch (input->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
index 03079f1c3b..3a10d2e60c 100644
--- a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -114,14 +114,16 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
}
#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar, pad_value) \
- type::SpaceToBatchND(GetTensorData<scalar>(op_context.input), \
- GetTensorDims(op_context.input), \
+ tflite::SpaceToBatchParams op_params; \
+ op_params.output_offset = pad_value; \
+ type::SpaceToBatchND(op_params, GetTensorShape(op_context.input), \
+ GetTensorData<scalar>(op_context.input), \
+ GetTensorShape(op_context.block_shape), \
GetTensorData<int32_t>(op_context.block_shape), \
- GetTensorDims(op_context.block_shape), \
+ GetTensorShape(op_context.paddings), \
GetTensorData<int32_t>(op_context.paddings), \
- GetTensorDims(op_context.paddings), \
- GetTensorData<scalar>(op_context.output), \
- GetTensorDims(op_context.output), pad_value)
+ GetTensorShape(op_context.output), \
+ GetTensorData<scalar>(op_context.output))
switch (op_context.input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth.cc b/tensorflow/contrib/lite/kernels/space_to_depth.cc
index 9dbe9b9eda..64c56c017b 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
@@ -79,10 +79,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
- type::SpaceToDepth<scalar>( \
- GetTensorData<scalar>(input), GetTensorDims(input), params->block_size, \
- GetTensorData<scalar>(output), GetTensorDims(output))
+#define TF_LITE_SPACE_TO_DEPTH(type, scalar) \
+ tflite::SpaceToDepthParams op_params; \
+ op_params.block_size = params->block_size; \
+ type::SpaceToDepth(op_params, GetTensorShape(input), \
+ GetTensorData<scalar>(input), GetTensorShape(output), \
+ GetTensorData<scalar>(output))
switch (input->type) { // Already know in/out types are same.
case kTfLiteFloat32:
if (kernel_type == kReference) {
diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
index fec2a6f0d9..178568e07c 100644
--- a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
+++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc
index b144486041..719e2dc606 100644
--- a/tensorflow/contrib/lite/kernels/split.cc
+++ b/tensorflow/contrib/lite/kernels/split.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
index 09a5662fd9..080c51cd18 100644
--- a/tensorflow/contrib/lite/kernels/squeeze.cc
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index bed2117f9a..87ffcc4110 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <string.h>
#include <cmath>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index 77a1f59689..1be0c83f17 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc
index 9e8ed3cbf3..9903fd5c35 100644
--- a/tensorflow/contrib/lite/kernels/svdf.cc
+++ b/tensorflow/contrib/lite/kernels/svdf.cc
@@ -23,8 +23,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -105,16 +105,11 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
-
-// * If the node has 5 inputs the following tensor is used as state tensor.
-// This is defined to be a variable tensor, and will be modified by this op.
+// This is a variable tensor, and will be modified by this op.
constexpr int kInputActivationStateTensor = 4;
-// Output tensors.
-// * If node has 4 inputs, kStateTensor will be used as state tensor.
-// * If node has 5 inputs, kStateTensor is ignored.
-constexpr int kStateTensor = 0;
-constexpr int kOutputTensor = 1;
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* op_data = new OpData();
@@ -134,21 +129,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int scratch_tensor_index = op_data->scratch_tensor_index;
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
- bool use_input_variable_states;
- if (node->inputs->size == 5) {
- use_input_variable_states = true;
- op_data->activation_state_tensor_index =
- node->inputs->data[kInputActivationStateTensor];
- } else if (node->inputs->size == 4) {
- use_input_variable_states = false;
- op_data->activation_state_tensor_index = node->outputs->data[kStateTensor];
- } else {
- context->ReportError(context,
- "The SVDF kernel expects 4 or 5 inputs. Got %d inputs",
- node->inputs->size);
- return kTfLiteError;
- }
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ op_data->activation_state_tensor_index =
+ node->inputs->data[kInputActivationStateTensor];
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* weights_feature =
@@ -178,28 +162,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
&context->tensors[op_data->activation_state_tensor_index];
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- if (use_input_variable_states) {
- // Check the shape of input state tensors.
- TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0),
- batch_size);
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
- memory_size * num_filters);
- } else {
- // Resize activation_state.
- // For each batch, the state is a 2-D tensor: memory_size * num_filters
- // The left most column is used to save current cycle activation.
- // The right most column is used to save temporary output which will be
- // reduced to num_units outputs.
- TfLiteIntArray* state_size_array = TfLiteIntArrayCreate(2);
- state_size_array->data[0] = batch_size;
- state_size_array->data[1] = memory_size * num_filters;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_state,
- state_size_array));
-
- // Mark state as a persistent tensor.
- activation_state->allocation_type = kTfLiteArenaRwPersistent;
- }
+ // Check the shape of input state tensors.
+ TF_LITE_ENSURE_EQ(context, NumDimensions(activation_state), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 0), batch_size);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(activation_state, 1),
+ memory_size * num_filters);
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc
index e485938343..6d60dc63f4 100644
--- a/tensorflow/contrib/lite/kernels/svdf_test.cc
+++ b/tensorflow/contrib/lite/kernels/svdf_test.cc
@@ -145,7 +145,6 @@ class BaseSVDFOpModel : public SingleOpModel {
activation_state_ = AddInput(
TensorData{TensorType_FLOAT32, {batches, memory_size * num_filters}},
/*is_variable=*/true);
- state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(
BuiltinOperator_SVDF, BuiltinOptions_SVDFOptions,
@@ -187,7 +186,6 @@ class BaseSVDFOpModel : public SingleOpModel {
int weights_time_;
int bias_;
int activation_state_;
- int state_;
int output_;
int batches_;
diff --git a/tensorflow/contrib/lite/kernels/tile.cc b/tensorflow/contrib/lite/kernels/tile.cc
index 5181a8f89a..49421eb870 100644
--- a/tensorflow/contrib/lite/kernels/tile.cc
+++ b/tensorflow/contrib/lite/kernels/tile.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/tile_test.cc b/tensorflow/contrib/lite/kernels/tile_test.cc
index 4f78c224e5..e73ca7b750 100644
--- a/tensorflow/contrib/lite/kernels/tile_test.cc
+++ b/tensorflow/contrib/lite/kernels/tile_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc
index 2dd760bbfe..6c38b6739e 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <algorithm>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
index 2abb89b617..16106fdafe 100644
--- a/tensorflow/contrib/lite/kernels/topk_v2_test.cc
+++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc
@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
index 800b0563d7..95359962e0 100644
--- a/tensorflow/contrib/lite/kernels/transpose.cc
+++ b/tensorflow/contrib/lite/kernels/transpose.cc
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/transpose_conv.cc b/tensorflow/contrib/lite/kernels/transpose_conv.cc
index a9baa5c698..6f2d98ede8 100644
--- a/tensorflow/contrib/lite/kernels/transpose_conv.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_conv.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 0acd705950..63817bd886 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
@@ -64,10 +64,14 @@ constexpr int kProjectionWeightsTensor = 16; // Optional
// Projection bias tensor of size {n_output}
constexpr int kProjectionBiasTensor = 17; // Optional
+// Stateful input tensors that are variables and will be modified by the Op.
+// Activation state tensor of size {n_batch, n_output}
+constexpr int kInputActivationStateTensor = 18;
+// Cell state tensor of size {n_batch, n_cell}
+constexpr int kInputCellStateTensor = 19;
+
// Output tensors.
-constexpr int kOutputStateTensor = 0;
-constexpr int kCellStateTensor = 1;
-constexpr int kOutputTensor = 2;
+constexpr int kOutputTensor = 0;
// Temporary tensors
enum TemporaryTensor {
@@ -82,7 +86,7 @@ enum TemporaryTensor {
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* scratch_tensor_index = new int;
+ auto* scratch_tensor_index = new int();
context->AddTensors(context, kNumTemporaryTensors, scratch_tensor_index);
return scratch_tensor_index;
}
@@ -247,8 +251,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 3);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 20);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -276,12 +280,21 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
n_output, n_cell));
- // Get the pointer to output, output_state and cell_state buffer tensors.
+ // Get the pointer to output, activation_state and cell_state buffer tensors.
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
- // Resize the output, output_state and cell_state tensors.
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
+ // Check the shape of input state tensors.
+ // These tensor may be 1D or 2D. It's fine as long as the total size is
+ // correct.
+ TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
+ TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
+
+ // Resize the output tensors.
TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
output_size->data[0] = max_time;
output_size->data[1] = n_batch;
@@ -289,22 +302,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, output, output_size));
- TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
- output_state_size->data[0] = n_batch;
- output_state_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, output_state, output_state_size));
-
- TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
- cell_size->data[0] = n_batch;
- cell_size->data[1] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state, cell_size));
-
- // Mark state tensors as persistent tensors.
- output_state->allocation_type = kTfLiteArenaRwPersistent;
- cell_state->allocation_type = kTfLiteArenaRwPersistent;
-
// The weights are of consistent type, so it suffices to check one.
// TODO(mirkov): create a utility/macro for this check, so all Ops can use it.
const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
@@ -340,7 +337,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
if (is_hybrid_op) {
// Allocate temporary tensors to store quantized values of input,
- // output_state and cell_state tensors.
+ // activation_state and cell_state tensors.
node->temporaries->data[kInputQuantized] =
*scratch_tensor_index + kInputQuantized;
TfLiteTensor* input_quantized =
@@ -354,17 +351,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
node->temporaries->data[kOutputStateQuantized] =
*scratch_tensor_index + kOutputStateQuantized;
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, kOutputStateQuantized);
- output_state_quantized->type = kTfLiteUInt8;
- output_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(output_state_quantized->dims,
- output_state->dims)) {
- TfLiteIntArray* output_state_quantized_size =
- TfLiteIntArrayCopy(output_state->dims);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, output_state_quantized,
- output_state_quantized_size));
+ activation_state_quantized->type = kTfLiteUInt8;
+ activation_state_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
+ activation_state->dims)) {
+ TfLiteIntArray* activation_state_quantized_size =
+ TfLiteIntArrayCopy(activation_state->dims);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, activation_state_quantized,
+ activation_state_quantized_size));
}
node->temporaries->data[kCellStateQuantized] =
*scratch_tensor_index + kCellStateQuantized;
@@ -449,7 +446,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* activation_state, TfLiteTensor* cell_state,
TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
@@ -510,7 +507,7 @@ TfLiteStatus EvalFloat(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
// Feed the sequence into the LSTM step-by-step.
@@ -527,7 +524,7 @@ TfLiteStatus EvalFloat(
cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, output_state_ptr,
+ params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
output_gate_scratch, output_ptr_batch);
}
@@ -552,9 +549,9 @@ TfLiteStatus EvalHybrid(
const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
+ TfLiteTensor* activation_state_quantized,
+ TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
const int max_time = input->dims->data[0];
const int n_batch = input->dims->data[1];
const int n_input = input->dims->data[2];
@@ -655,14 +652,14 @@ TfLiteStatus EvalHybrid(
const float* cell_bias_ptr = cell_bias->data.f;
const float* output_gate_bias_ptr = output_gate_bias->data.f;
- float* output_state_ptr = output_state->data.f;
+ float* activation_state_ptr = activation_state->data.f;
float* cell_state_ptr = cell_state->data.f;
// Temporary storage for quantized values and scaling factors.
int8_t* quantized_input_ptr =
reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_activation_state_ptr =
+ reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
int8_t* quantized_cell_state_ptr =
reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
@@ -692,8 +689,8 @@ TfLiteStatus EvalHybrid(
n_input, n_output, input_gate_scratch, forget_gate_scratch,
cell_scratch, output_gate_scratch, scaling_factors_ptr,
prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
+ quantized_input_ptr, quantized_activation_state_ptr,
+ quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
output_ptr_batch);
}
return kTfLiteOk;
@@ -744,8 +741,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// Index the scratch buffers pointers to the global scratch buffer.
TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
- TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+ TfLiteTensor* activation_state =
+ GetVariableInput(context, node, kInputActivationStateTensor);
+ TfLiteTensor* cell_state =
+ GetVariableInput(context, node, kInputCellStateTensor);
+
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_to_output_weights->type) {
@@ -758,11 +758,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
cell_to_output_weights, input_gate_bias,
forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params,
- scratch_buffer, output_state, cell_state, output);
+ scratch_buffer, activation_state, cell_state, output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* output_state_quantized =
+ TfLiteTensor* activation_state_quantized =
GetTemporary(context, node, /*index=*/2);
TfLiteTensor* cell_state_quantized =
GetTemporary(context, node, /*index=*/3);
@@ -780,8 +780,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
projection_weights, projection_bias, params, scratch_buffer,
scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, output_state_quantized, cell_state_quantized,
- output_state, cell_state, output);
+ input_quantized, activation_state_quantized, cell_state_quantized,
+ activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index de38bdef6f..cd3aac0532 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -100,8 +100,14 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
projection_bias_ = AddNullInput();
}
- output_state_ = AddOutput(TensorType_FLOAT32);
- cell_state_ = AddOutput(TensorType_FLOAT32);
+ // Adding the 2 input state tensors.
+ input_activation_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}},
+ /*is_variable=*/true);
+ input_cell_state_ =
+ AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}},
+ /*is_variable=*/true);
+
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
@@ -180,22 +186,6 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
PopulateTensor(projection_bias_, f);
}
- void ResetOutputState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(output_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
- void ResetCellState() {
- const int zero_buffer_size = n_cell_ * n_batch_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(cell_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
void SetInput(int offset, const float* begin, const float* end) {
PopulateTensor(input_, offset, const_cast<float*>(begin),
const_cast<float*>(end));
@@ -233,9 +223,10 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
int projection_weights_;
int projection_bias_;
+ int input_activation_state_;
+ int input_cell_state_;
+
int output_;
- int output_state_;
- int cell_state_;
int n_batch_;
int n_input_;
@@ -458,6 +449,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -475,10 +469,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -519,6 +509,9 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -536,10 +529,6 @@ TEST_F(NoCifgNoPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm,
/*tolerance=*/0.0157651);
}
@@ -629,6 +618,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -646,10 +638,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, LstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -691,6 +679,9 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
{0, 0}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToCellWeights(input_to_cell_weights_);
@@ -708,10 +699,6 @@ TEST_F(CifgPeepholeNoProjectionNoClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetCellToForgetWeights(cell_to_forget_weights_);
lstm.SetCellToOutputWeights(cell_to_output_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.03573);
}
@@ -1351,6 +1338,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1374,10 +1364,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, LstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm);
}
@@ -1418,6 +1404,9 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
{n_output, n_cell}, // projection_weight tensor
{0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
});
lstm.SetInputToInputWeights(input_to_input_weights_);
@@ -1441,10 +1430,6 @@ TEST_F(NoCifgPeepholeProjectionClippingLstmTest, HybridLstmBlackBoxTest) {
lstm.SetProjectionWeights(projection_weights_);
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
-
VerifyGoldens(lstm_input_, lstm_golden_output_, &lstm, /*tolerance=*/0.00467);
}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
index 0d6d29a171..744ee7c109 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include <iostream>
#include <limits>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/kernels/activation_functor.h"
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
@@ -31,12 +31,15 @@ namespace ops {
namespace builtin {
namespace unidirectional_sequence_rnn {
+// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kRecurrentWeightsTensor = 2;
constexpr int kBiasTensor = 3;
-constexpr int kHiddenStateTensor = 0;
-constexpr int kOutputTensor = 1;
+constexpr int kHiddenStateTensor = 4;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* scratch_tensor_index = new int;
@@ -50,14 +53,16 @@ void Free(TfLiteContext* context, void* buffer) {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_weights = GetInput(context, node, kWeightsTensor);
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ const TfLiteTensor* hidden_state =
+ GetInput(context, node, kHiddenStateTensor);
// Check all the parameters of tensor match within themselves and match the
// input configuration.
@@ -74,20 +79,12 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input_weights->type, recurrent_weights->type);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(hidden_state), 2);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[0], batch_size);
+ TF_LITE_ENSURE_EQ(context, hidden_state->dims->data[1], num_units);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
- // Resize state.
- TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2);
- hidden_state_size_array->data[0] = batch_size;
- hidden_state_size_array->data[1] = num_units;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state,
- hidden_state_size_array));
-
- // Mark hidden state as a persistent tensor.
- hidden_state->allocation_type = kTfLiteArenaRwPersistent;
-
// Resize output.
TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3);
output_size_array->data[0] = (time_major) ? max_time : batch_size;
@@ -276,7 +273,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* recurrent_weights =
GetInput(context, node, kRecurrentWeightsTensor);
const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
- TfLiteTensor* hidden_state = GetOutput(context, node, kHiddenStateTensor);
+ // The hidden_state is a variable input tensor that can be modified.
+ TfLiteTensor* hidden_state =
+ const_cast<TfLiteTensor*>(GetInput(context, node, kHiddenStateTensor));
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
switch (input_weights->type) {
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
index 0adab837b0..6b48e3fff7 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc
@@ -183,7 +183,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
weights_ = AddInput(weights);
recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
- hidden_state_ = AddOutput(TensorType_FLOAT32);
+ hidden_state_ = AddInput(TensorType_FLOAT32, true);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOptions_SequenceRNNOptions,
@@ -194,12 +194,14 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
BuildInterpreter({{sequence_len_, batches_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units}});
} else {
BuildInterpreter({{batches_, sequence_len_, input_size_},
{units_, input_size_},
{units_, units_},
- {units_}});
+ {units_},
+ {batches_, units_}});
}
}
@@ -221,14 +223,6 @@ class UnidirectionalRNNOpModel : public SingleOpModel {
PopulateTensor(input_, offset, begin, end);
}
- void ResetHiddenState() {
- const int zero_buffer_size = units_ * batches_;
- std::unique_ptr<float[]> zero_buffer(new float[zero_buffer_size]);
- memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float));
- PopulateTensor(hidden_state_, 0, zero_buffer.get(),
- zero_buffer.get() + zero_buffer_size);
- }
-
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
int input_size() { return input_size_; }
@@ -273,7 +267,6 @@ TEST(UnidirectionalRNNOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -299,7 +292,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
float* batch_start = rnn_input;
@@ -326,7 +318,6 @@ TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
@@ -356,7 +347,6 @@ TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) {
rnn.SetWeights(rnn_weights);
rnn.SetBias(rnn_bias);
rnn.SetRecurrentWeights(rnn_recurrent_weights);
- rnn.ResetHiddenState();
for (int i = 0; i < rnn.sequence_len(); i++) {
float* batch_start = rnn_input + i * rnn.input_size();
diff --git a/tensorflow/contrib/lite/kernels/unpack.cc b/tensorflow/contrib/lite/kernels/unpack.cc
new file mode 100644
index 0000000000..9ff06f8331
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack.cc
@@ -0,0 +1,130 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unpack {
+namespace {
+
+constexpr int kInputTensor = 0;
+
+// Op data for unpack op.
+struct OpData {
+ int num;
+ int axis;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ data->axis = 0;
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), data->num);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
+ TF_LITE_ENSURE(context, NumDimensions(input) > 1);
+ TF_LITE_ENSURE(context, NumDimensions(input) > data->axis);
+ // TODO(renjieliu): Support negative axis.
+ TF_LITE_ENSURE(context, data->axis >= 0);
+ if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32) {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+
+ const TfLiteIntArray* input_shape = input->dims;
+ // Num should be equal to the shape[axis].
+ // Resize outputs. rank will be R - 1.
+ TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(input) - 1);
+ int o = 0;
+ for (int index = 0; index < NumDimensions(input); ++index) {
+ if (index != data->axis) {
+ output_shape->data[o++] = input_shape->data[index];
+ }
+ }
+
+ TF_LITE_ENSURE_EQ(context, data->num, input_shape->data[data->axis]);
+ for (int i = 0; i < data->num; ++i) {
+ TfLiteIntArray* copied_output_shape = TfLiteIntArrayCopy(output_shape);
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, output->type, input->type);
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, output, copied_output_shape));
+ }
+
+ TfLiteIntArrayFree(output_shape);
+ return kTfLiteOk;
+}
+
+template <typename T>
+void UnpackImpl(TfLiteContext* context, TfLiteNode* node,
+ const TfLiteTensor* input, int output_count, int axis) {
+ VectorOfTensors<T> all_outputs(*context, *node->outputs);
+ reference_ops::Unpack<T>(axis, GetTensorData<T>(input), GetTensorDims(input),
+ NumDimensions(input), output_count,
+ all_outputs.data(), **all_outputs.dims());
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* data = reinterpret_cast<OpData*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ UnpackImpl<float>(context, node, input, data->num, data->axis);
+ break;
+ }
+ case kTfLiteInt32: {
+ UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
+ break;
+ }
+ default: {
+ context->ReportError(context,
+ "Currently pack only supports int32 and float32.");
+ return kTfLiteError;
+ }
+ }
+
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace unpack
+
+TfLiteRegistration* Register_UNPACK() {
+ static TfLiteRegistration r = {unpack::Init, unpack::Free, unpack::Prepare,
+ unpack::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unpack_test.cc b/tensorflow/contrib/lite/kernels/unpack_test.cc
new file mode 100644
index 0000000000..4efc92a0fd
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unpack_test.cc
@@ -0,0 +1,225 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAre;
+
+template <typename T>
+class UnpackOpModel : public SingleOpModel {
+ public:
+ UnpackOpModel(const TensorData& input, int axis) {
+ CHECK_LE(axis, input.shape.size());
+ const int num_outputs = input.shape[axis];
+ input_ = AddInput(input);
+ for (int i = 0; i < num_outputs; ++i) {
+ outputs_.push_back(AddOutput(input.type));
+ }
+ SetBuiltinOp(BuiltinOperator_UNPACK, BuiltinOptions_UnpackOptions,
+ CreatePackOptions(builder_, num_outputs, axis).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<T> data) {
+ PopulateTensor<T>(input_, data);
+ }
+
+ std::vector<std::vector<T>> GetOutputDatas() {
+ std::vector<std::vector<T>> output_datas;
+ for (const int output : outputs_) {
+ std::cerr << "the output is " << output << std::endl;
+ output_datas.push_back(ExtractVector<T>(output));
+ }
+ return output_datas;
+ }
+
+ std::vector<std::vector<int>> GetOutputShapes() {
+ std::vector<std::vector<int>> output_shapes;
+ for (const int output : outputs_) {
+ output_shapes.push_back(GetTensorShape(output));
+ }
+ return output_shapes;
+ }
+
+ private:
+ int input_;
+ std::vector<int> outputs_;
+};
+
+// float32 tests.
+TEST(UnpackOpTest, FloatThreeOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeOutputsAxisOne) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, FloatOneOutput) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, FloatThreeDimensionsOutputs) {
+ UnpackOpModel<float> model({TensorType_FLOAT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<float>>& output_datas = model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+// int32 tests.
+TEST(UnpackOpTest, IntThreeOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 3);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2));
+ EXPECT_THAT(output_shapes[2], ElementsAre(2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 3);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2));
+ EXPECT_THAT(output_datas[1], ElementsAre(3, 4));
+ EXPECT_THAT(output_datas[2], ElementsAre(5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeOutputsAxisOne) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {3, 2}}, 1);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(3));
+ EXPECT_THAT(output_shapes[1], ElementsAre(3));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6));
+}
+
+TEST(UnpackOpTest, IntOneOutput) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {1, 6}}, 0);
+ model.SetInput({1, 2, 3, 4, 5, 6});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 1);
+ EXPECT_THAT(output_shapes[0], ElementsAre(6));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 1);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 2, 3, 4, 5, 6));
+}
+
+TEST(UnpackOpTest, IntThreeDimensionsOutputs) {
+ UnpackOpModel<int32_t> model({TensorType_INT32, {2, 2, 2}}, 2);
+ model.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+ model.Invoke();
+
+ // Check outputs shapes.
+ const std::vector<std::vector<int>>& output_shapes = model.GetOutputShapes();
+ EXPECT_EQ(output_shapes.size(), 2);
+ EXPECT_THAT(output_shapes[0], ElementsAre(2, 2));
+ EXPECT_THAT(output_shapes[1], ElementsAre(2, 2));
+
+ // Check outputs values.
+ const std::vector<std::vector<int32_t>>& output_datas =
+ model.GetOutputDatas();
+ EXPECT_EQ(output_datas.size(), 2);
+ EXPECT_THAT(output_datas[0], ElementsAre(1, 3, 5, 7));
+ EXPECT_THAT(output_datas[1], ElementsAre(2, 4, 6, 8));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
index b58ae26601..6195426d6d 100755
--- a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
+++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh
@@ -14,6 +14,7 @@
# limitations under the License.
# ==============================================================================
+# TODO(ycling): Refactoring - Move this script into `tools/make`.
set -e
echo "Starting"
@@ -32,7 +33,7 @@ echo "Headers, populating: TensorFlow Lite"
cd $TFLITE_DIR/../../..
find tensorflow/contrib/lite -name '*.h' \
- -not -path 'tensorflow/contrib/lite/downloads/*' \
+ -not -path 'tensorflow/contrib/lite/tools/*' \
-not -path 'tensorflow/contrib/lite/examples/*' \
-not -path 'tensorflow/contrib/lite/gen/*' \
-not -path 'tensorflow/contrib/lite/toco/*' \
@@ -44,7 +45,7 @@ tar xf tmp.tar
rm -f tmp.tar
echo "Headers, populating: Flatbuffer"
-cd $TFLITE_DIR/downloads/flatbuffers/include/
+cd $TFLITE_DIR/tools/make/downloads/flatbuffers/include/
find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T -
cd $FW_DIR_TFLITE_HDRS
tar xf tmp.tar
@@ -57,7 +58,7 @@ cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tens
$FW_DIR_TFLITE
echo "Copying static libraries"
-cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \
+cp $TFLITE_DIR/tools/make/gen/lib/libtensorflow-lite.a \
$FW_DIR_TFLITE/tensorflow_lite
# This is required, otherwise they interfere with the documentation of the
diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h
index 0294ec815c..2d4707f849 100644
--- a/tensorflow/contrib/lite/memory_planner.h
+++ b/tensorflow/contrib/lite/memory_planner.h
@@ -15,7 +15,7 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mmap_allocation.cc b/tensorflow/contrib/lite/mmap_allocation.cc
index fa9a3cd1d8..92934d1fd1 100644
--- a/tensorflow/contrib/lite/mmap_allocation.cc
+++ b/tensorflow/contrib/lite/mmap_allocation.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include <unistd.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 5f8d5c318a..241865b3d8 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -20,8 +20,9 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/contrib/lite/model.h"
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -42,41 +43,6 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
-TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
- ErrorReporter* error_reporter) {
- switch (tensor_type) {
- case TensorType_FLOAT32:
- *type = kTfLiteFloat32;
- break;
- case TensorType_INT16:
- *type = kTfLiteInt16;
- break;
- case TensorType_INT32:
- *type = kTfLiteInt32;
- break;
- case TensorType_UINT8:
- *type = kTfLiteUInt8;
- break;
- case TensorType_INT64:
- *type = kTfLiteInt64;
- break;
- case TensorType_STRING:
- *type = kTfLiteString;
- break;
- case TensorType_BOOL:
- *type = kTfLiteBool;
- break;
- case TensorType_COMPLEX64:
- *type = kTfLiteComplex64;
- break;
- default:
- error_reporter->Report("Unimplemented data type %s (%d) in tensor\n",
- EnumNameTensorType(tensor_type), tensor_type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -198,39 +164,10 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
auto opcodes = model_->operator_codes();
for (const OperatorCode* opcode : *opcodes) {
const TfLiteRegistration* registration = nullptr;
- auto builtin_code = opcode->builtin_code();
- int version = opcode->version();
-
- if (builtin_code > BuiltinOperator_MAX ||
- builtin_code < BuiltinOperator_MIN) {
- error_reporter_->Report(
- "Op builtin_code out or range: %d. Are you using old TFLite binary "
- "with newer model?",
- builtin_code);
- status = kTfLiteError;
- } else if (builtin_code != BuiltinOperator_CUSTOM) {
- registration = op_resolver_.FindOp(builtin_code, version);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find op for builtin opcode '%s' version '%d'\n",
- EnumNameBuiltinOperator(builtin_code), version);
- status = kTfLiteError;
- }
- } else if (!opcode->custom_code()) {
- error_reporter_->Report(
- "Operator with CUSTOM builtin_code has no custom_code.\n");
- status = kTfLiteError;
- } else {
- const char* name = opcode->custom_code()->c_str();
- registration = op_resolver_.FindOp(name, version);
- flatbuffer_op_index_to_registration_types_.push_back(
- BuiltinOperator_CUSTOM);
- if (registration == nullptr) {
- error_reporter_->Report(
- "Didn't find custom op for name '%s' with version %d\n", name,
- version);
- status = kTfLiteError;
- }
+ status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+ &registration);
+ if (status != kTfLiteOk) {
+ return status;
}
flatbuffer_op_index_to_registration_.push_back(registration);
}
@@ -247,561 +184,16 @@ std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
return ret;
}
-// Copies the contents from the flatbuffer int vector `flatbuffer` into the
-// int array `buffer`. `flat_vector` and `buffer` represent the same
-// configuration operation for a given operation.
-void FlatBufferIntVectorToArray(int max_size_of_buffer,
- const flatbuffers::Vector<int32_t>* flat_vector,
- int* buffer, ErrorReporter* error_reporter) {
- if (!flat_vector) {
- error_reporter->Report("Input array not provided for operation.\n");
- } else {
- int num_dimensions = flat_vector->Length();
- if (num_dimensions > max_size_of_buffer / sizeof(int)) {
- error_reporter->Report(
- "Found too many dimensions in the operation's input array.\n");
- } else {
- for (int i = 0; i < num_dimensions; ++i) {
- buffer[i] = flat_vector->Get(i);
- }
- }
- }
-}
-
-// Allocate a structure using C malloc, but make sure the structure is a
-// POD structure that doesn't require constructors to run. The reason we do
-// this, is that Interpreter's C extension part will take ownership and wants
-// to use malloc() and free().
-template <class T>
-T* MallocPOD() {
- static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
- return static_cast<T*>(malloc(sizeof(T)));
-}
-
-// Parse the appropriate data out of the op.
-//
-// This handles builtin data explicitly as there are flatbuffer schemas.
-// If it returns kTfLiteOk, it passes the data out with `builtin_data`, which
-// need to be released by calling `free`.`
-// If it returns kTfLiteError, `builtin_data` will be `nullptr`.
-TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
- ErrorReporter* error_reporter, void** builtin_data) {
- auto parse_padding = [](Padding padding) {
- switch (padding) {
- case Padding_SAME:
- return kTfLitePaddingSame;
- case Padding_VALID:
- return kTfLitePaddingValid;
- }
- return kTfLitePaddingUnknown;
- };
- auto parse_activation = [](ActivationFunctionType activation) {
- switch (activation) {
- case ActivationFunctionType_NONE:
- return kTfLiteActNone;
- case ActivationFunctionType_RELU:
- return kTfLiteActRelu;
- case ActivationFunctionType_RELU_N1_TO_1:
- return kTfLiteActRelu1;
- case ActivationFunctionType_RELU6:
- return kTfLiteActRelu6;
- case ActivationFunctionType_TANH:
- return kTfLiteActTanh;
- case ActivationFunctionType_SIGN_BIT:
- return kTfLiteActSignBit;
- }
- return kTfLiteActNone;
- };
- auto parseLSHProjectionType = [](LSHProjectionType type) {
- switch (type) {
- case LSHProjectionType_SPARSE:
- return kTfLiteLshProjectionSparse;
- case LSHProjectionType_DENSE:
- return kTfLiteLshProjectionDense;
- default:
- return kTfLiteLshProjectionUnknown;
- }
- };
- auto parseCombinerType = [](CombinerType type) {
- switch (type) {
- case CombinerType_MEAN:
- return kTfLiteCombinerTypeMean;
- case CombinerType_SQRTN:
- return kTfLiteCombinerTypeSqrtn;
- case CombinerType_SUM:
- default:
- return kTfLiteCombinerTypeSum;
- }
- };
-
- *builtin_data = nullptr;
- switch (op_type) {
- case BuiltinOperator_CONV_2D: {
- TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
- if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
-
- params->dilation_width_factor = conv_params->dilation_w_factor();
- params->dilation_height_factor = conv_params->dilation_h_factor();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CAST: {
- TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
- if (auto* schema_params = op->builtin_options_as_CastOptions()) {
- auto in_status =
- ConvertTensorType(schema_params->in_data_type(),
- &params->in_data_type, error_reporter);
- auto out_status =
- ConvertTensorType(schema_params->out_data_type(),
- &params->out_data_type, error_reporter);
- if (in_status != kTfLiteOk || out_status != kTfLiteOk) {
- free(params);
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LSH_PROJECTION: {
- TfLiteLSHProjectionParams* params =
- MallocPOD<TfLiteLSHProjectionParams>();
- if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) {
- params->type = parseLSHProjectionType(lshParams->type());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_AVERAGE_POOL_2D:
- case BuiltinOperator_MAX_POOL_2D:
- case BuiltinOperator_L2_POOL_2D: {
- TfLitePoolParams* params = MallocPOD<TfLitePoolParams>();
- if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
- params->padding = parse_padding(pool_params->padding());
- params->stride_width = pool_params->stride_w();
- params->stride_height = pool_params->stride_h();
- params->filter_width = pool_params->filter_width();
- params->filter_height = pool_params->filter_height();
- params->activation =
- parse_activation(pool_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DEPTHWISE_CONV_2D: {
- TfLiteDepthwiseConvParams* params =
- MallocPOD<TfLiteDepthwiseConvParams>();
- if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) {
- params->padding = parse_padding(conv_params->padding());
- params->stride_width = conv_params->stride_w();
- params->stride_height = conv_params->stride_h();
- params->depth_multiplier = conv_params->depth_multiplier();
- params->activation =
- parse_activation(conv_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SVDF: {
- TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>();
- if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) {
- params->rank = svdf_params->rank();
- params->activation =
- parse_activation(svdf_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params = MallocPOD<TfLiteSequenceRNNParams>();
- if (auto* sequence_rnn_params =
- op->builtin_options_as_SequenceRNNOptions()) {
- params->activation =
- parse_activation(sequence_rnn_params->fused_activation_function());
- params->time_major = sequence_rnn_params->time_major();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RNN: {
- TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>();
- if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
- params->activation =
- parse_activation(rnn_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
- TfLiteEmbeddingLookupSparseParams* params =
- MallocPOD<TfLiteEmbeddingLookupSparseParams>();
- if (auto* embedding_params =
- op->builtin_options_as_EmbeddingLookupSparseOptions()) {
- params->combiner = parseCombinerType(embedding_params->combiner());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_FULLY_CONNECTED: {
- TfLiteFullyConnectedParams* params =
- MallocPOD<TfLiteFullyConnectedParams>();
- if (auto* fully_connected_params =
- op->builtin_options_as_FullyConnectedOptions()) {
- params->activation = parse_activation(
- fully_connected_params->fused_activation_function());
- switch (fully_connected_params->weights_format()) {
- case FullyConnectedOptionsWeightsFormat_DEFAULT:
- params->weights_format = kTfLiteFullyConnectedWeightsFormatDefault;
- break;
- case FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8:
- params->weights_format =
- kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8;
- break;
- default:
- error_reporter->Report("Unhandled fully-connected weights format.");
- return kTfLiteError;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_HASHTABLE_LOOKUP:
- // no-op.
- break;
- case BuiltinOperator_SOFTMAX: {
- TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>();
- if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) {
- params->beta = softmax_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_CONCATENATION: {
- TfLiteConcatenationParams* params =
- MallocPOD<TfLiteConcatenationParams>();
- if (auto* concatenation_params =
- op->builtin_options_as_ConcatenationOptions()) {
- params->activation =
- parse_activation(concatenation_params->fused_activation_function());
- params->axis = concatenation_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MUL: {
- auto* params = MallocPOD<TfLiteMulParams>();
- if (auto* schema_params = op->builtin_options_as_MulOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ADD: {
- auto* params = MallocPOD<TfLiteAddParams>();
- if (auto* schema_params = op->builtin_options_as_AddOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DIV: {
- auto* params = MallocPOD<TfLiteDivParams>();
- if (auto* schema_params = op->builtin_options_as_DivOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SUB: {
- auto* params = MallocPOD<TfLiteSubParams>();
- if (auto* schema_params = op->builtin_options_as_SubOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_L2_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteL2NormParams>();
- if (auto* schema_params = op->builtin_options_as_L2NormOptions()) {
- params->activation =
- parse_activation(schema_params->fused_activation_function());
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: {
- auto* params = MallocPOD<TfLiteLocalResponseNormParams>();
- if (auto* schema_params =
- op->builtin_options_as_LocalResponseNormalizationOptions()) {
- params->radius = schema_params->radius();
- params->bias = schema_params->bias();
- params->alpha = schema_params->alpha();
- params->beta = schema_params->beta();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
- case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>();
- if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
- params->activation =
- parse_activation(lstm_params->fused_activation_function());
- params->cell_clip = lstm_params->cell_clip();
- params->proj_clip = lstm_params->proj_clip();
- switch (lstm_params->kernel_type()) {
- case LSTMKernelType_FULL:
- params->kernel_type = kTfLiteLSTMFullKernel;
- break;
- case LSTMKernelType_BASIC:
- params->kernel_type = kTfLiteLSTMBasicKernel;
- break;
- }
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESIZE_BILINEAR: {
- auto* params = MallocPOD<TfLiteResizeBilinearParams>();
- if (auto* schema_params =
- op->builtin_options_as_ResizeBilinearOptions()) {
- params->align_corners = schema_params->align_corners();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_RESHAPE: {
- auto* params = MallocPOD<TfLiteReshapeParams>();
- if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
- auto* new_shape = schema_params->new_shape();
- FlatBufferIntVectorToArray(sizeof(params->shape), new_shape,
- params->shape, error_reporter);
- params->num_dimensions = new_shape->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SKIP_GRAM: {
- TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>();
- if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) {
- params->ngram_size = skip_gram_params->ngram_size();
- params->max_skip_size = skip_gram_params->max_skip_size();
- params->include_all_ngrams = skip_gram_params->include_all_ngrams();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPACE_TO_DEPTH: {
- auto* params = MallocPOD<TfLiteSpaceToDepthParams>();
- if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) {
- params->block_size = schema_params->block_size();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_GATHER: {
- TfLiteGatherParams* params = MallocPOD<TfLiteGatherParams>();
- params->axis = 0;
- if (auto* gather_params = op->builtin_options_as_GatherOptions()) {
- params->axis = gather_params->axis();
- }
-
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_MEAN:
- case BuiltinOperator_REDUCE_MAX:
- case BuiltinOperator_REDUCE_MIN:
- case BuiltinOperator_REDUCE_PROD:
- case BuiltinOperator_SUM: {
- auto* params = MallocPOD<TfLiteReducerParams>();
- if (auto* schema_params = op->builtin_options_as_ReducerOptions()) {
- params->keep_dims = schema_params->keep_dims();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPLIT: {
- auto* params = MallocPOD<TfLiteSplitParams>();
- if (auto* schema_params = op->builtin_options_as_SplitOptions()) {
- params->num_splits = schema_params->num_splits();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SQUEEZE: {
- auto* params = MallocPOD<TfLiteSqueezeParams>();
- if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) {
- const auto& squeeze_dims = schema_params->squeeze_dims();
- FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims,
- params->squeeze_dims, error_reporter);
- params->num_squeeze_dims = squeeze_dims->Length();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_STRIDED_SLICE: {
- auto* params = MallocPOD<TfLiteStridedSliceParams>();
- if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) {
- params->begin_mask = schema_params->begin_mask();
- params->end_mask = schema_params->end_mask();
- params->ellipsis_mask = schema_params->ellipsis_mask();
- params->new_axis_mask = schema_params->new_axis_mask();
- params->shrink_axis_mask = schema_params->shrink_axis_mask();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MAX: {
- auto* params = MallocPOD<TfLiteArgMaxParams>();
- if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ARG_MIN: {
- auto* params = MallocPOD<TfLiteArgMinParams>();
- if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
- ConvertTensorType(schema_params->output_type(), &params->output_type,
- error_reporter);
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_TRANSPOSE_CONV: {
- TfLiteTransposeConvParams* params =
- MallocPOD<TfLiteTransposeConvParams>();
- if (auto* transpose_conv_params =
- op->builtin_options_as_TransposeConvOptions()) {
- params->padding = parse_padding(transpose_conv_params->padding());
- params->stride_width = transpose_conv_params->stride_w();
- params->stride_height = transpose_conv_params->stride_h();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SPARSE_TO_DENSE: {
- TfLiteSparseToDenseParams* params =
- MallocPOD<TfLiteSparseToDenseParams>();
- if (auto* sparse_to_dense_params =
- op->builtin_options_as_SparseToDenseOptions()) {
- params->validate_indices = sparse_to_dense_params->validate_indices();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_SHAPE: {
- auto* params = MallocPOD<TfLiteShapeParams>();
- if (auto* schema_params = op->builtin_options_as_ShapeOptions()) {
- ConvertTensorType(schema_params->out_type(), &params->out_type,
- error_reporter);
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_PACK: {
- TfLitePackParams* params = MallocPOD<TfLitePackParams>();
- if (auto* pack_params = op->builtin_options_as_PackOptions()) {
- params->values_count = pack_params->values_count();
- params->axis = pack_params->axis();
- }
- *builtin_data = reinterpret_cast<void*>(params);
- break;
- }
- case BuiltinOperator_DELEGATE: {
- // TODO(ycling): Revisit when supporting saving delegated models.
- error_reporter->Report("DELEGATE op shouldn't exist in model.");
- return kTfLiteError;
- }
- case BuiltinOperator_FAKE_QUANT: {
- auto* params = MallocPOD<TfLiteFakeQuantParams>();
- if (auto* schema_params = op->builtin_options_as_FakeQuantOptions()) {
- params->min = schema_params->min();
- params->max = schema_params->max();
- params->num_bits = schema_params->num_bits();
- params->narrow_range = schema_params->narrow_range();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
- case BuiltinOperator_ONE_HOT: {
- auto* params = MallocPOD<TfLiteOneHotParams>();
- if (auto* schema_params = op->builtin_options_as_OneHotOptions()) {
- params->axis = schema_params->axis();
- }
- *builtin_data = static_cast<void*>(params);
- break;
- }
-
- // Below are the ops with no builtin_data strcture.
- case BuiltinOperator_BATCH_TO_SPACE_ND:
- // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
- // ok for now, since there is no call implementation either.
- case BuiltinOperator_CALL:
- case BuiltinOperator_CONCAT_EMBEDDINGS:
- case BuiltinOperator_CUSTOM:
- case BuiltinOperator_DEQUANTIZE:
- case BuiltinOperator_EMBEDDING_LOOKUP:
- case BuiltinOperator_EQUAL:
- case BuiltinOperator_EXP:
- case BuiltinOperator_EXPAND_DIMS:
- case BuiltinOperator_FLOOR:
- case BuiltinOperator_GREATER:
- case BuiltinOperator_GREATER_EQUAL:
- case BuiltinOperator_LESS:
- case BuiltinOperator_LESS_EQUAL:
- case BuiltinOperator_LOG:
- case BuiltinOperator_LOGISTIC:
- case BuiltinOperator_LOG_SOFTMAX:
- case BuiltinOperator_MAXIMUM:
- case BuiltinOperator_MINIMUM:
- case BuiltinOperator_NEG:
- case BuiltinOperator_NOT_EQUAL:
- case BuiltinOperator_PAD:
- case BuiltinOperator_PADV2:
- case BuiltinOperator_PRELU:
- case BuiltinOperator_RELU:
- case BuiltinOperator_RELU6:
- case BuiltinOperator_RELU_N1_TO_1:
- case BuiltinOperator_RSQRT:
- case BuiltinOperator_SELECT:
- case BuiltinOperator_SIN:
- case BuiltinOperator_SLICE:
- case BuiltinOperator_SPACE_TO_BATCH_ND:
- case BuiltinOperator_SQRT:
- case BuiltinOperator_TANH:
- case BuiltinOperator_TILE:
- case BuiltinOperator_TOPK_V2:
- case BuiltinOperator_TRANSPOSE:
- case BuiltinOperator_POW:
- case BuiltinOperator_LOGICAL_OR:
- case BuiltinOperator_LOGICAL_AND:
- case BuiltinOperator_LOGICAL_NOT:
- case BuiltinOperator_UNPACK:
- break;
- }
- return kTfLiteOk;
-}
-
} // namespace
TfLiteStatus InterpreterBuilder::ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
Interpreter* interpreter) {
TfLiteStatus status = kTfLiteOk;
+
+ // Reduce the number of redundant allocations
+ interpreter->ReserveNodes(operators->Length());
+
for (int i = 0; i < operators->Length(); ++i) {
const auto* op = operators->Get(i);
int index = op->opcode_index();
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 8bc9ecd7ce..6abdfcd079 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -35,9 +35,10 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_MODEL_H_
#include <memory>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
#include "tensorflow/contrib/lite/interpreter.h"
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index df4f60d4ad..ec7d46af7c 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc
index 206de1962d..8ecf0b6154 100644
--- a/tensorflow/contrib/lite/models/speech_test.cc
+++ b/tensorflow/contrib/lite/models/speech_test.cc
@@ -102,7 +102,7 @@ class SpeechTest : public ::testing::TestWithParam<int> {
int GetMaxInvocations() { return GetParam(); }
};
-TEST_P(SpeechTest, HotwordOkGoogleRank1Test) {
+TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank1Test) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv",
@@ -114,7 +114,7 @@ TEST_P(SpeechTest, HotwordOkGoogleRank1Test) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, HotwordOkGoogleRank2Test) {
+TEST_P(SpeechTest, DISABLED_HotwordOkGoogleRank2Test) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv",
@@ -126,7 +126,7 @@ TEST_P(SpeechTest, HotwordOkGoogleRank2Test) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
+TEST_P(SpeechTest, DISABLED_SpeakerIdOkGoogleTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_speakerid_model.tflite", "speech_speakerid_model_in.csv",
@@ -139,7 +139,7 @@ TEST_P(SpeechTest, SpeakerIdOkGoogleTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, AsrAmTest) {
+TEST_P(SpeechTest, DISABLED_AsrAmTest) {
std::stringstream os;
ASSERT_TRUE(
ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv",
@@ -156,7 +156,7 @@ TEST_P(SpeechTest, AsrAmTest) {
// through the interpreter and stored the sum of all the output, which was them
// compared for correctness. In this test we are comparing all the intermediate
// results.
-TEST_P(SpeechTest, AsrLmTest) {
+TEST_P(SpeechTest, DISABLED_AsrLmTest) {
std::ifstream in_file;
testing::TfLiteDriver test_driver(/*use_nnapi=*/false);
ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file));
@@ -165,7 +165,7 @@ TEST_P(SpeechTest, AsrLmTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, EndpointerTest) {
+TEST_P(SpeechTest, DISABLED_EndpointerTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData(
"speech_endpointer_model.tflite", "speech_endpointer_model_in.csv",
@@ -178,7 +178,7 @@ TEST_P(SpeechTest, EndpointerTest) {
<< test_driver.GetErrorMessage();
}
-TEST_P(SpeechTest, TtsTest) {
+TEST_P(SpeechTest, DISABLED_TtsTest) {
std::stringstream os;
ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite",
"speech_tts_model_in.csv",
diff --git a/tensorflow/contrib/lite/op_resolver.cc b/tensorflow/contrib/lite/mutable_op_resolver.cc
index f6e435e982..8ee63d2a02 100644
--- a/tensorflow/contrib/lite/op_resolver.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver.cc
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/mutable_op_resolver.h b/tensorflow/contrib/lite/mutable_op_resolver.h
new file mode 100644
index 0000000000..c319041e9b
--- /dev/null
+++ b/tensorflow/contrib/lite/mutable_op_resolver.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
+
+#include <unordered_map>
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/util.h"
+
+namespace tflite {
+
+// Some versions of gcc doesn't support partial specialization in class scope,
+// so these are defined in a namescope.
+namespace op_resolver_hasher {
+template <typename V>
+struct ValueHasher {
+ size_t operator()(const V& v) const { return std::hash<V>()(v); }
+};
+
+template <>
+struct ValueHasher<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& v) const {
+ return std::hash<int>()(static_cast<int>(v));
+ }
+};
+
+template <typename T>
+struct OperatorKeyHasher {
+ size_t operator()(const T& x) const {
+ size_t a = ValueHasher<typename T::first_type>()(x.first);
+ size_t b = ValueHasher<typename T::second_type>()(x.second);
+ return CombineHashes({a, b});
+ }
+};
+} // namespace op_resolver_hasher
+
+// An OpResolver that is mutable, also used as the op in gen_op_registration.
+// A typical usage:
+// MutableOpResolver resolver;
+// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
+// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
+// InterpreterBuilder(model, resolver)(&interpreter);
+class MutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
+ typedef std::pair<std::string, int> CustomOperatorKey;
+
+ std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
+ builtins_;
+ std::unordered_map<CustomOperatorKey, TfLiteRegistration,
+ op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
+ custom_ops_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/op_resolver_test.cc b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
index 10b7e31972..db690eaab9 100644
--- a/tensorflow/contrib/lite/op_resolver_test.cc
+++ b/tensorflow/contrib/lite/mutable_op_resolver_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index d287aa635c..817486e898 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -18,8 +18,8 @@ limitations under the License.
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
-#include "tensorflow/contrib/lite/builtin_op_data.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
@@ -64,6 +64,14 @@ void logError(const char* format, ...) {
__LINE__); \
}
+#define RETURN_ERROR_IF_TFLITE_FAILED(x) \
+ if (x != kTfLiteOk) { \
+ logError( \
+ "Returning error since TFLite returned failure nnapi_delegate.cc:%d.", \
+ __LINE__); \
+ return kTfLiteError; \
+ }
+
#define RETURN_ERROR_IF_NN_FAILED(x) \
if (x != ANEURALNETWORKS_NO_ERROR) { \
logError( \
@@ -98,7 +106,10 @@ int32_t GetAndroidSdkVersion() {
return 0;
}
-static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+int32_t GetAndroidSdkVersionCached() {
+ static int32_t androidSdkVersion = GetAndroidSdkVersion();
+ return androidSdkVersion;
+}
} // namespace
@@ -296,17 +307,21 @@ TfLiteStatus AddOpsAndParams(
};
auto check_and_add_activation = [&add_scalar_int32](int activation) {
if (activation > kTfLiteActRelu6) {
- FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ return kTfLiteError;
}
add_scalar_int32(activation);
+ return kTfLiteOk;
};
auto add_add_params = [&add_scalar_int32](void* data) {
auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
if (builtin->activation > kTfLiteActRelu6) {
- FATAL("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ logError("NNAPI only supports RELU, RELU1 and RELU6 activations");
+ return kTfLiteError;
}
add_scalar_int32(builtin->activation);
+ return kTfLiteOk;
};
auto add_pooling_params = [&add_scalar_int32,
@@ -317,7 +332,7 @@ TfLiteStatus AddOpsAndParams(
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->filter_width);
add_scalar_int32(builtin->filter_height);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_convolution_params = [&add_scalar_int32,
@@ -326,7 +341,7 @@ TfLiteStatus AddOpsAndParams(
add_scalar_int32(builtin->padding);
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_depthwise_conv_params = [&add_scalar_int32,
@@ -336,20 +351,22 @@ TfLiteStatus AddOpsAndParams(
add_scalar_int32(builtin->stride_width);
add_scalar_int32(builtin->stride_height);
add_scalar_int32(builtin->depth_multiplier);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_fully_connected_params = [&check_and_add_activation](void* data) {
auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(data);
- check_and_add_activation(builtin->activation);
+ return check_and_add_activation(builtin->activation);
};
auto add_concatenation_params = [&add_scalar_int32](void* data) {
auto builtin = reinterpret_cast<TfLiteConcatenationParams*>(data);
add_scalar_int32(builtin->axis);
if (builtin->activation != kTfLiteActNone) {
- FATAL("Concatenation does not support fused activation in NNAPI");
+ logError("Concatenation does not support fused activation in NNAPI");
+ return kTfLiteError;
}
+ return kTfLiteOk;
};
auto add_softmax_params = [&add_scalar_float32](void* data) {
@@ -430,22 +447,22 @@ TfLiteStatus AddOpsAndParams(
switch (builtin) {
case tflite::BuiltinOperator_ADD:
nn_op_type = ANEURALNETWORKS_ADD;
- add_add_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
break;
case tflite::BuiltinOperator_MUL:
nn_op_type = ANEURALNETWORKS_MUL;
- add_add_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_add_params(node.builtin_data));
break;
case tflite::BuiltinOperator_AVERAGE_POOL_2D:
- add_pooling_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_AVERAGE_POOL_2D;
break;
case tflite::BuiltinOperator_MAX_POOL_2D:
- add_pooling_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_MAX_POOL_2D;
break;
case tflite::BuiltinOperator_L2_POOL_2D:
- add_pooling_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(add_pooling_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_L2_POOL_2D;
break;
case tflite::BuiltinOperator_CONV_2D: {
@@ -456,7 +473,8 @@ TfLiteStatus AddOpsAndParams(
return kTfLiteError;
}
}
- add_convolution_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_convolution_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_CONV_2D;
break;
case tflite::BuiltinOperator_RELU:
@@ -475,11 +493,13 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_LOGISTIC;
break;
case tflite::BuiltinOperator_DEPTHWISE_CONV_2D:
- add_depthwise_conv_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_depthwise_conv_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_DEPTHWISE_CONV_2D;
break;
case tflite::BuiltinOperator_CONCATENATION:
- add_concatenation_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_concatenation_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_CONCATENATION;
break;
case tflite::BuiltinOperator_SOFTMAX:
@@ -487,7 +507,8 @@ TfLiteStatus AddOpsAndParams(
nn_op_type = ANEURALNETWORKS_SOFTMAX;
break;
case tflite::BuiltinOperator_FULLY_CONNECTED:
- add_fully_connected_params(node.builtin_data);
+ RETURN_ERROR_IF_TFLITE_FAILED(
+ add_fully_connected_params(node.builtin_data));
nn_op_type = ANEURALNETWORKS_FULLY_CONNECTED;
break;
case tflite::BuiltinOperator_RESHAPE:
@@ -541,14 +562,14 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_DIV:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_DIV;
- check_and_add_activation(
- reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation);
+ RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
+ reinterpret_cast<TfLiteDivParams*>(node.builtin_data)->activation));
break;
case tflite::BuiltinOperator_SUB:
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_SUB;
- check_and_add_activation(
- reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation);
+ RETURN_ERROR_IF_TFLITE_FAILED(check_and_add_activation(
+ reinterpret_cast<TfLiteSubParams*>(node.builtin_data)->activation));
break;
case tflite::BuiltinOperator_SQUEEZE:
nnapi_version = 11; // requires NNAPI 1.1
@@ -649,6 +670,8 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_LOGICAL_AND:
case tflite::BuiltinOperator_LOGICAL_NOT:
case tflite::BuiltinOperator_UNPACK:
+ case tflite::BuiltinOperator_FLOOR_DIV:
+ case tflite::BuiltinOperator_REDUCE_ANY:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
@@ -658,8 +681,9 @@ TfLiteStatus AddOpsAndParams(
break;
}
- if (nnapi_version == 11 && kAndroidSdkVersion < 28) {
- FATAL("Op %d needs NNAPI1.1", builtin);
+ if (nnapi_version == 11 && GetAndroidSdkVersionCached() < 28) {
+ logError("Op %d needs NNAPI1.1", builtin);
+ return kTfLiteError;
}
// Add the operation.
@@ -707,9 +731,9 @@ TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
interpreter->outputs().size());
uint32_t next_id = 0;
- RETURN_ERROR_IF_NN_FAILED(addTensorOperands(
+ RETURN_ERROR_IF_TFLITE_FAILED(addTensorOperands(
interpreter, nn_model_, &next_id, &tensor_id_to_nnapi_id));
- RETURN_ERROR_IF_NN_FAILED(
+ RETURN_ERROR_IF_TFLITE_FAILED(
AddOpsAndParams(interpreter, nn_model_, next_id, &model_states_inputs_,
&model_states_outputs_, tensor_id_to_nnapi_id));
diff --git a/tensorflow/contrib/lite/nnapi_delegate.h b/tensorflow/contrib/lite/nnapi_delegate.h
index 2bdb2cc5c8..22359d557e 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.h
+++ b/tensorflow/contrib/lite/nnapi_delegate.h
@@ -16,8 +16,8 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_NNAPI_DELEGATE_H_
#include "tensorflow/contrib/lite/allocation.h"
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/interpreter.h"
class ANeuralNetworksModel;
diff --git a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
index efde72b1a7..e3536d3db6 100644
--- a/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate_disabled.cc
@@ -27,7 +27,13 @@ NNAPIAllocation::NNAPIAllocation(const char* filename,
NNAPIAllocation::~NNAPIAllocation() {}
-NNAPIDelegate::~NNAPIDelegate() {}
+NNAPIDelegate::~NNAPIDelegate() {
+#define UNUSED_MEMBER(x) (void)(x)
+ UNUSED_MEMBER(nn_model_);
+ UNUSED_MEMBER(nn_compiled_model_);
+ UNUSED_MEMBER(model_status_);
+#undef UNUSED_MEMBER
+}
TfLiteStatus NNAPIDelegate::BuildGraph(Interpreter* interpreter) {
return kTfLiteError;
diff --git a/tensorflow/contrib/lite/op_resolver.h b/tensorflow/contrib/lite/op_resolver.h
index 9d7e3f2085..e93134cbde 100644
--- a/tensorflow/contrib/lite/op_resolver.h
+++ b/tensorflow/contrib/lite/op_resolver.h
@@ -12,83 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// Compatibility shim for moved header location.
#ifndef TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
#define TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
-#include <unordered_map>
-#include "tensorflow/contrib/lite/context.h"
-#include "tensorflow/contrib/lite/schema/schema_generated.h"
-#include "tensorflow/contrib/lite/util.h"
-
-namespace tflite {
-
-// Abstract interface that returns TfLiteRegistrations given op codes or custom
-// op names. This is the mechanism that ops being referenced in the flatbuffer
-// model are mapped to executable function pointers (TfLiteRegistrations).
-class OpResolver {
- public:
- // Finds the op registration for a builtin operator by enum code.
- virtual const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const = 0;
- // Finds the op registration of a custom operator by op name.
- virtual const TfLiteRegistration* FindOp(const char* op,
- int version) const = 0;
- virtual ~OpResolver() {}
-};
-
-// Some versions of gcc doesn't support partial specialization in class scope,
-// so these are defined in a namescope.
-namespace op_resolver_hasher {
-template <typename V>
-struct ValueHasher {
- size_t operator()(const V& v) const { return std::hash<V>()(v); }
-};
-
-template <>
-struct ValueHasher<tflite::BuiltinOperator> {
- size_t operator()(const tflite::BuiltinOperator& v) const {
- return std::hash<int>()(static_cast<int>(v));
- }
-};
-
-template <typename T>
-struct OperatorKeyHasher {
- size_t operator()(const T& x) const {
- size_t a = ValueHasher<typename T::first_type>()(x.first);
- size_t b = ValueHasher<typename T::second_type>()(x.second);
- return CombineHashes({a, b});
- }
-};
-} // namespace op_resolver_hasher
-
-// An OpResolver that is mutable, also used as the op in gen_op_registration.
-// A typical usage:
-// MutableOpResolver resolver;
-// resolver.AddBuiltin(BuiltinOperator_ADD, Register_ADD());
-// resolver.AddCustom("CustomOp", Register_CUSTOM_OP());
-// InterpreterBuilder(model, resolver)(&interpreter);
-class MutableOpResolver : public OpResolver {
- public:
- const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
- int version) const override;
- const TfLiteRegistration* FindOp(const char* op, int version) const override;
- void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
- void AddCustom(const char* name, TfLiteRegistration* registration,
- int min_version = 1, int max_version = 1);
-
- private:
- typedef std::pair<tflite::BuiltinOperator, int> BuiltinOperatorKey;
- typedef std::pair<std::string, int> CustomOperatorKey;
-
- std::unordered_map<BuiltinOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<BuiltinOperatorKey> >
- builtins_;
- std::unordered_map<CustomOperatorKey, TfLiteRegistration,
- op_resolver_hasher::OperatorKeyHasher<CustomOperatorKey> >
- custom_ops_;
-};
-
-} // namespace tflite
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/mutable_op_resolver.h"
#endif // TENSORFLOW_CONTRIB_LITE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 47f0c8e9a2..57e1290e07 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -70,7 +70,7 @@ py_library(
py_test(
name = "lite_test",
srcs = ["lite_test.py"],
- data = [":interpreter_test_data"],
+ data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pb"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
@@ -130,6 +130,7 @@ py_test(
],
deps = [
":convert",
+ ":interpreter",
":op_hint",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 12cc66dc55..1c5516ae7c 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -126,7 +126,7 @@ def build_toco_convert_protos(input_tensors,
reorder_across_fake_quant=False,
allow_custom_ops=False,
change_concat_input_ranges=False,
- quantize_weights=False,
+ post_training_quantize=False,
dump_graphviz_dir=None,
dump_graphviz_video=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.
@@ -149,9 +149,11 @@ def build_toco_convert_protos(input_tensors,
as `input_tensors`, or None. (default None)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
- quantized_input_stats: List of tuples of integers representing the mean and
+ quantized_input_stats: List of tuples of floats representing the mean and
standard deviation. Each tuple maps to the corresponding input tensor.
- Only need if `inference_type` is `QUANTIZED_UINT8`. (default None)
+ Only need if `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -171,9 +173,9 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges: Boolean to change behavior of min/max ranges for
inputs and outputs of the concat operator for quantized models. Changes
the ranges of concat operator overlap when true. (default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -197,10 +199,12 @@ def build_toco_convert_protos(input_tensors,
toco.inference_type = inference_type
if inference_input_type:
toco.inference_input_type = inference_input_type
+ else:
+ toco.inference_input_type = toco.inference_type
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
- toco.quantize_weights = quantize_weights
+ toco.post_training_quantize = post_training_quantize
if default_ranges_stats:
toco.default_ranges_min = default_ranges_stats[0]
toco.default_ranges_max = default_ranges_stats[1]
@@ -212,7 +216,7 @@ def build_toco_convert_protos(input_tensors,
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
input_array = model.input_arrays.add()
- if inference_type == lite_constants.QUANTIZED_UINT8:
+ if toco.inference_input_type == lite_constants.QUANTIZED_UINT8:
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
input_array.name = tensor_name(input_tensor)
if input_shapes is None:
@@ -226,6 +230,54 @@ def build_toco_convert_protos(input_tensors,
return model, toco
+def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
+ *args, **kwargs):
+ """"Convert a model using TOCO.
+
+ This function is used to convert GraphDefs that cannot be loaded into
+ TensorFlow to TFLite. Conversion can be customized by providing arguments
+ that are forwarded to `build_toco_convert_protos` (see documentation for
+ details).
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`),
+ input_arrays_with_shape: Tuple of strings representing input tensor names
+ and list of integers representing input shapes
+ (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
+ into TensorFlow and when `input_tensors` is None. (default None)
+ output_arrays: List of output tensors to freeze graph with. Use only when
+ graph cannot be loaded into TensorFlow and when `output_tensors` is None.
+ (default None)
+ *args: See `build_toco_convert_protos`,
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+ The converted data. For example if TFLite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors=[], output_tensors=[], *args, **kwargs)
+
+ for idx, (name, shape) in enumerate(input_arrays_with_shape):
+ input_array = model_flags.input_arrays.add()
+ if kwargs["inference_type"] == lite_constants.QUANTIZED_UINT8:
+ input_array.mean_value, input_array.std_value = kwargs[
+ "quantized_input_stats"][idx]
+ input_array.name = name
+ input_array.shape.dims.extend(map(int, shape))
+
+ for name in output_arrays:
+ model_flags.output_arrays.append(name)
+
+ data = toco_convert_protos(model_flags.SerializeToString(),
+ toco_flags.SerializeToString(),
+ input_data.SerializeToString())
+ return data
+
+
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
**kwargs):
""""Convert a model using TOCO.
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index bc05514cec..40a8b5fafb 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.lite.python import convert
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python import op_hint
+from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -37,9 +40,12 @@ class ConvertTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
+
# Try running on valid graph
- result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
- self.assertTrue(result)
+ tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
+ [out_tensor])
+ self.assertTrue(tflite_model)
+
# TODO(aselle): remove tests that fail (we must get TOCO to not fatal
# all the time).
# Try running on identity graph (known fail)
@@ -52,11 +58,85 @@ class ConvertTest(test_util.TensorFlowTestCase):
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
min=0., max=1.)
sess = session.Session()
- result = convert.toco_convert(
+
+ tflite_model = convert.toco_convert(
sess.graph_def, [in_tensor], [out_tensor],
inference_type=lite_constants.QUANTIZED_UINT8,
quantized_input_stats=[(0., 1.)])
- self.assertTrue(result)
+ self.assertTrue(tflite_model)
+
+ def testGraphDefBasic(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
+ inference_type=lite_constants.FLOAT)
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual("input", input_details[0]["name"])
+ self.assertEqual(np.float32, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), input_details[0]["quantization"])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("add", output_details[0]["name"])
+ self.assertEqual(np.float32, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), output_details[0]["quantization"])
+
+ def testGraphDefQuantization(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
+ _ = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
+ sess = session.Session()
+
+ input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
+ output_arrays = ["output"]
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def,
+ input_arrays_map,
+ output_arrays,
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ quantized_input_stats=[(0., 1.), (0., 1.)])
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual("inputA", input_details[0]["name"])
+ self.assertEqual(np.uint8, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[0]["quantization"]) # scale, zero_point
+
+ self.assertEqual("inputB", input_details[1]["name"])
+ self.assertEqual(np.uint8, input_details[1]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[1]["quantization"]) # scale, zero_point
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("output", output_details[0]["name"])
+ self.assertEqual(np.uint8, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
class ConvertTestOpHint(test_util.TensorFlowTestCase):
@@ -108,7 +188,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
return output
output = array_ops.identity(_swish(image, swish_scale), name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# check if identities have been put into the graph (2 input, 1 output,
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
@@ -135,7 +215,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# make sure one identity for each input (3) and output (2) => 3 + 2 = 5
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
@@ -162,7 +242,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
output = array_ops.identity(
math_ops.add(_double_values(a), _double_values(b)), name="ModelOutput")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
@@ -199,7 +279,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
aggregate=op_hint.OpHint.AGGREGATE_STACK)
res = math_ops.add(math_ops.mul(a, b), math_ops.mul(c, b))
custom.add_outputs([res])
- with self.test_session():
+ with self.cached_session():
self.assertEqual(self._get_input_index(a), 0)
self.assertEqual(self._get_sort_index(a), 0)
self.assertEqual(self._get_input_index(b), 1)
@@ -214,7 +294,7 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
b = custom.add_input(b) # should auto assign 0
a = custom.add_input(a, index_override=1)
c = custom.add_input(c) # should auto assign 2
- with self.test_session():
+ with self.cached_session():
self.assertEqual(self._get_input_index(a), 1)
self.assertEqual(self._get_input_index(b), 0)
self.assertEqual(self._get_input_index(c), 2)
@@ -240,10 +320,9 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
curr = array_ops.stack([c0, c1])
output = array_ops.identity(curr, name="FINAL_OUTPUT")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
graph_def=sess.graph_def)
- print(stubbed_graphdef)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2313bfa3b6..44dfb97b84 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -42,6 +42,7 @@ from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
@@ -55,7 +56,9 @@ from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -76,9 +79,11 @@ class TocoConverter(object):
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
- mapped to tuple of integers representing the mean and standard deviation
+ mapped to tuple of floats representing the mean and standard deviation
of the training data (e.g., {"foo" : (0., 1.)}). Only need if
- `inference_type` is `QUANTIZED_UINT8`. (default {})
+ `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default {})
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -98,9 +103,9 @@ class TocoConverter(object):
created for any op that is unknown. The developer will need to provide
these to the TensorFlow Lite runtime with a custom resolver.
(default False)
- quantize_weights: Boolean indicating whether to store weights as quantized
- weights followed by dequantize operations. Computation is still done in
- float, but reduces model size (at the cost of accuracy and latency).
+ post_training_quantize: Boolean indicating whether to quantize the weights
+ of the converted float model. Model size will be reduced and there will be
+ latency improvements (at the cost of accuracy).
(default False)
dump_graphviz_dir: Full filepath of folder to dump the graphs at various
stages of processing GraphViz .dot files. Preferred over
@@ -133,7 +138,12 @@ class TocoConverter(object):
```
"""
- def __init__(self, graph_def, input_tensors, output_tensors):
+ def __init__(self,
+ graph_def,
+ input_tensors,
+ output_tensors,
+ input_arrays_with_shape=None,
+ output_arrays=None):
"""Constructor for TocoConverter.
Args:
@@ -142,6 +152,17 @@ class TocoConverter(object):
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
+ input_arrays_with_shape: Tuple of strings representing input tensor names
+ and list of integers representing input shapes
+ (e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
+ into TensorFlow and when `input_tensors` and `output_tensors` are None.
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Use only when
+ graph cannot be loaded into TensorFlow and when `input_tensors` and
+ `output_tensors` are None. (default None)
+
+ Raises:
+ ValueError: Invalid arguments.
"""
self._graph_def = graph_def
self._input_tensors = input_tensors
@@ -155,10 +176,19 @@ class TocoConverter(object):
self.reorder_across_fake_quant = False
self.change_concat_input_ranges = False
self.allow_custom_ops = False
- self.quantize_weights = False
+ self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ # Attributes are used by models that cannot be loaded into TensorFlow.
+ if not self._has_valid_tensors():
+ if not input_arrays_with_shape or not output_arrays:
+ raise ValueError(
+ "If input_tensors and output_tensors are None, both "
+ "input_arrays_with_shape and output_arrays must be defined.")
+ self._input_arrays_with_shape = input_arrays_with_shape
+ self._output_arrays = output_arrays
+
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
@@ -196,18 +226,24 @@ class TocoConverter(object):
TocoConverter class.
Raises:
- ValueError:
+ IOError:
+ File not found.
Unable to parse input file.
+ ValueError:
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
+ input_shapes is not correctly defined when required
"""
with _ops.Graph().as_default():
with _session.Session() as sess:
# Read GraphDef from file.
- graph_def = _graph_pb2.GraphDef()
- with open(graph_def_file, "rb") as f:
+ if not _file_io.file_exists(graph_def_file):
+ raise IOError("File '{0}' does not exist.".format(graph_def_file))
+ with _file_io.FileIO(graph_def_file, "rb") as f:
file_content = f.read()
+
try:
+ graph_def = _graph_pb2.GraphDef()
graph_def.ParseFromString(file_content)
except (_text_format.ParseError, DecodeError):
try:
@@ -218,24 +254,49 @@ class TocoConverter(object):
file_content = file_content.decode("utf-8")
else:
file_content = file_content.encode("utf-8")
+ graph_def = _graph_pb2.GraphDef()
_text_format.Merge(file_content, graph_def)
except (_text_format.ParseError, DecodeError):
- raise ValueError(
+ raise IOError(
"Unable to parse input file '{}'.".format(graph_def_file))
- _import_graph_def(graph_def, name="")
- # Get input and output tensors.
- input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = _get_tensors_from_tensor_names(sess.graph,
- output_arrays)
- _set_tensor_shapes(input_tensors, input_shapes)
-
- # Check if graph is frozen.
- if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py.")
-
- # Create TocoConverter class.
- return cls(sess.graph_def, input_tensors, output_tensors)
+ # Handles models with custom TFLite ops that cannot be resolved in
+ # TensorFlow.
+ load_model_in_session = True
+ try:
+ _import_graph_def(graph_def, name="")
+ except _NotFoundError:
+ load_model_in_session = False
+
+ if load_model_in_session:
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
+
+ # Get input and output tensors.
+ input_tensors = _get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
+
+ return cls(sess.graph_def, input_tensors, output_tensors)
+ else:
+ if not input_shapes:
+ raise ValueError("input_shapes must be defined for this model.")
+ if set(input_arrays) != set(input_shapes.keys()):
+ raise ValueError("input_shapes must contain a value for each item "
+ "in input_array.")
+
+ input_arrays_with_shape = [
+ (name, input_shapes[name]) for name in input_arrays
+ ]
+ return cls(
+ graph_def,
+ input_tensors=None,
+ output_tensors=None,
+ input_arrays_with_shape=input_arrays_with_shape,
+ output_arrays=output_arrays)
@classmethod
def from_saved_model(cls,
@@ -330,25 +391,25 @@ class TocoConverter(object):
None value for dimension in input_tensor.
"""
# Checks dimensions in input tensor.
- for tensor in self._input_tensors:
- if not tensor.get_shape():
- raise ValueError("Provide an input shape for input array '{0}'.".format(
- _tensor_name(tensor)))
- shape = tensor.get_shape().as_list()
- if None in shape[1:]:
- raise ValueError(
- "None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
- elif shape[0] is None:
- self._set_batch_size(batch_size=1)
+ if self._has_valid_tensors():
+ for tensor in self._input_tensors:
+ if not tensor.get_shape():
+ raise ValueError("Provide an input shape for input array "
+ "'{0}'.".format(_tensor_name(tensor)))
+ shape = tensor.get_shape().as_list()
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
+ elif shape[0] is None:
+ self._set_batch_size(batch_size=1)
# Get quantization stats. Ensures there is one stat per name if the stats
# are specified.
if self.quantized_input_stats:
quantized_stats = []
invalid_stats = []
- for tensor in self._input_tensors:
- name = _tensor_name(tensor)
+ for name in self.get_input_arrays():
if name in self.quantized_input_stats:
quantized_stats.append(self.quantized_input_stats[name])
else:
@@ -360,24 +421,35 @@ class TocoConverter(object):
else:
quantized_stats = None
+ converter_kwargs = {
+ "inference_type": self.inference_type,
+ "inference_input_type": self.inference_input_type,
+ "input_format": constants.TENSORFLOW_GRAPHDEF,
+ "output_format": self.output_format,
+ "quantized_input_stats": quantized_stats,
+ "default_ranges_stats": self.default_ranges_stats,
+ "drop_control_dependency": self.drop_control_dependency,
+ "reorder_across_fake_quant": self.reorder_across_fake_quant,
+ "change_concat_input_ranges": self.change_concat_input_ranges,
+ "allow_custom_ops": self.allow_custom_ops,
+ "post_training_quantize": self.post_training_quantize,
+ "dump_graphviz_dir": self.dump_graphviz_dir,
+ "dump_graphviz_video": self.dump_graphviz_video
+ }
+
# Converts model.
- result = _toco_convert_impl(
- input_data=self._graph_def,
- input_tensors=self._input_tensors,
- output_tensors=self._output_tensors,
- inference_type=self.inference_type,
- inference_input_type=self.inference_input_type,
- input_format=constants.TENSORFLOW_GRAPHDEF,
- output_format=self.output_format,
- quantized_input_stats=quantized_stats,
- default_ranges_stats=self.default_ranges_stats,
- drop_control_dependency=self.drop_control_dependency,
- reorder_across_fake_quant=self.reorder_across_fake_quant,
- change_concat_input_ranges=self.change_concat_input_ranges,
- allow_custom_ops=self.allow_custom_ops,
- quantize_weights=self.quantize_weights,
- dump_graphviz_dir=self.dump_graphviz_dir,
- dump_graphviz_video=self.dump_graphviz_video)
+ if self._has_valid_tensors():
+ result = _toco_convert_impl(
+ input_data=self._graph_def,
+ input_tensors=self._input_tensors,
+ output_tensors=self._output_tensors,
+ **converter_kwargs)
+ else:
+ result = _toco_convert_graph_def(
+ input_data=self._graph_def,
+ input_arrays_with_shape=self._input_arrays_with_shape,
+ output_arrays=self._output_arrays,
+ **converter_kwargs)
return result
def get_input_arrays(self):
@@ -386,7 +458,18 @@ class TocoConverter(object):
Returns:
List of strings.
"""
- return [_tensor_name(tensor) for tensor in self._input_tensors]
+ if self._has_valid_tensors():
+ return [_tensor_name(tensor) for tensor in self._input_tensors]
+ else:
+ return [name for name, _ in self._input_arrays_with_shape]
+
+ def _has_valid_tensors(self):
+ """Checks if the input and output tensors have been initialized.
+
+ Returns:
+ Bool.
+ """
+ return self._input_tensors and self._output_tensors
def _set_batch_size(self, batch_size):
"""Sets the first dimension of the input tensor to `batch_size`.
@@ -394,7 +477,14 @@ class TocoConverter(object):
Args:
batch_size: Batch size for the model. Replaces the first dimension of an
input size array if undefined. (default 1)
+
+ Raises:
+ ValueError: input_tensor is not defined.
"""
+ if not self._has_valid_tensors():
+ raise ValueError("The batch size cannot be set for this model. Please "
+ "use input_shapes parameter.")
+
for tensor in self._input_tensors:
shape = tensor.get_shape().as_list()
shape[0] = batch_size
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f13684228..3f8ea433ff 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -35,11 +35,51 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph
+class FromConstructor(test_util.TensorFlowTestCase):
+
+ # Tests invalid constructors using a dummy value for the GraphDef.
+ def testInvalidConstructor(self):
+ message = ('If input_tensors and output_tensors are None, both '
+ 'input_arrays_with_shape and output_arrays must be defined.')
+
+ # `output_arrays` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(
+ None, None, [], input_arrays_with_shape=[('input', [3, 9])])
+ self.assertEqual(message, str(error.exception))
+
+ # `input_arrays_with_shape` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(None, [], None, output_arrays=['output'])
+ self.assertEqual(message, str(error.exception))
+
+ # Tests valid constructors using a dummy value for the GraphDef.
+ def testValidConstructor(self):
+ converter = lite.TocoConverter(
+ None,
+ None,
+ None,
+ input_arrays_with_shape=[('input', [3, 9])],
+ output_arrays=['output'])
+ self.assertFalse(converter._has_valid_tensors())
+ self.assertEqual(converter.get_input_arrays(), ['input'])
+
+ with self.assertRaises(ValueError) as error:
+ converter._set_batch_size(1)
+ self.assertEqual(
+ 'The batch size cannot be set for this model. Please use '
+ 'input_shapes parameter.', str(error.exception))
+
+ converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ self.assertTrue(converter._has_valid_tensors())
+
+
class FromSessionTest(test_util.TensorFlowTestCase):
def testFloat(self):
@@ -279,6 +319,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -331,7 +372,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertTrue(output_details[0]['quantization'][0] > 0) # scale
- def testQuantizeWeights(self):
+ def testPostTrainingQuantize(self):
np.random.seed(0)
# We need the tensor to have more than 1024 elements for quantize_weights
# to kick in. Thus, the [33, 33] shape.
@@ -352,14 +393,14 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_weights_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TocoConverter.from_session(
sess, [in_tensor_1], [out_tensor])
- quantized_weights_converter.quantize_weights = True
- quantized_weights_tflite = quantized_weights_converter.convert()
- self.assertTrue(quantized_weights_tflite)
+ quantized_converter.post_training_quantize = True
+ quantized_tflite = quantized_converter.convert()
+ self.assertTrue(quantized_tflite)
# Ensure that the quantized weights tflite model is smaller.
- self.assertTrue(len(quantized_weights_tflite) < len(float_tflite))
+ self.assertTrue(len(quantized_tflite) < len(float_tflite))
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -373,6 +414,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
@@ -407,6 +449,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(
@@ -434,6 +477,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
# Ensure the graph with variables cannot be converted.
with self.assertRaises(ValueError) as error:
@@ -451,6 +495,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Write graph to file.
graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
write_graph(sess.graph_def, '', graph_def_file, True)
+ sess.close()
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
@@ -476,20 +521,104 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
self.assertEqual((0., 0.), output_details[0]['quantization'])
- def testInvalidFile(self):
+ def testInvalidFileNotFound(self):
+ with self.assertRaises(IOError) as error:
+ lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
+ self.assertEqual('File \'invalid_file\' does not exist.',
+ str(error.exception))
+
+ def testInvalidFileBadData(self):
graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
with gfile.Open(graph_def_file, 'wb') as temp_file:
temp_file.write('bad data')
temp_file.flush()
# Attempts to convert the invalid model.
- with self.assertRaises(ValueError) as error:
+ with self.assertRaises(IOError) as error:
lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
['add'])
self.assertEqual(
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
+ # TODO(nupurgarg): Test model loading in open source.
+ def _initObjectDetectionArgs(self):
+ # Initializes the arguments required for the object detection model.
+ self._graph_def_file = resource_loader.get_path_to_datafile(
+ 'testdata/tflite_graph.pb')
+ self._input_arrays = ['normalized_input_image_tensor']
+ self._output_arrays = [
+ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
+ 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
+ ]
+ self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
+
+ def testTFLiteGraphDef(self):
+ # Tests the object detection model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ converter = lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays,
+ self._input_shapes)
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(4, len(output_details))
+ self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('TFLite_Detection_PostProcess:1',
+ output_details[1]['name'])
+ self.assertTrue(([1, 10] == output_details[1]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:2',
+ output_details[2]['name'])
+ self.assertTrue(([1, 10] == output_details[2]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:3',
+ output_details[3]['name'])
+ self.assertTrue(([1] == output_details[3]['shape']).all())
+
+ def testTFLiteGraphDefMissingShape(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ # Missing `input_shapes`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays)
+ self.assertEqual('input_shapes must be defined for this model.',
+ str(error.exception))
+
+ def testTFLiteGraphDefInvalidShape(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ # `input_shapes` does not contain the names in `input_arrays`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file,
+ self._input_arrays,
+ self._output_arrays,
+ input_shapes={'invalid-value': [1, 19]})
+ self.assertEqual(
+ 'input_shapes must contain a value for each item in input_array.',
+ str(error.exception))
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
@@ -628,26 +757,27 @@ class FromKerasFile(test_util.TensorFlowTestCase):
keras.backend.clear_session()
def _getSequentialModel(self):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- model.predict(x)
-
- try:
- fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
- return keras_file
+ with session.Session().as_default():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
def testSequentialModel(self):
"""Test a Sequential tf.keras model with default inputs."""
@@ -752,25 +882,26 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalModel(self):
"""Test a Functional tf.keras model with default inputs."""
- inputs = keras.layers.Input(shape=(3,), name='input')
- x = keras.layers.Dense(2)(inputs)
- output = keras.layers.Dense(3)(x)
-
- model = keras.models.Model(inputs, output)
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy])
- x = np.random.random((1, 3))
- y = np.random.random((1, 3))
- model.train_on_batch(x, y)
-
- model.predict(x)
- fd, keras_file = tempfile.mkstemp('.h5')
- try:
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
+ with session.Session().as_default():
+ inputs = keras.layers.Input(shape=(3,), name='input')
+ x = keras.layers.Dense(2)(inputs)
+ output = keras.layers.Dense(3)(x)
+
+ model = keras.models.Model(inputs, output)
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy])
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3))
+ model.train_on_batch(x, y)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
@@ -809,36 +940,39 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalModelMultipleInputs(self):
"""Test a Functional tf.keras model with multiple inputs and outputs."""
- a = keras.layers.Input(shape=(3,), name='input_a')
- b = keras.layers.Input(shape=(3,), name='input_b')
- dense = keras.layers.Dense(4, name='dense')
- c = dense(a)
- d = dense(b)
- e = keras.layers.Dropout(0.5, name='dropout')(c)
-
- model = keras.models.Model([a, b], [d, e])
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.mae],
- loss_weights=[1., 0.5])
-
- input_a_np = np.random.random((10, 3))
- input_b_np = np.random.random((10, 3))
- output_d_np = np.random.random((10, 4))
- output_e_np = np.random.random((10, 4))
- model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
-
- model.predict([input_a_np, input_b_np], batch_size=5)
- fd, keras_file = tempfile.mkstemp('.h5')
- keras.models.save_model(model, keras_file)
+ with session.Session().as_default():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.mae],
+ loss_weights=[1., 0.5])
+
+ input_a_np = np.random.random((10, 3))
+ input_b_np = np.random.random((10, 3))
+ output_d_np = np.random.random((10, 4))
+ output_e_np = np.random.random((10, 4))
+ model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
+
+ model.predict([input_a_np, input_b_np], batch_size=5)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
- os.close(fd)
os.remove(keras_file)
# Check values from converted model.
@@ -871,28 +1005,29 @@ class FromKerasFile(test_util.TensorFlowTestCase):
def testFunctionalSequentialModel(self):
"""Test a Functional tf.keras model containing a Sequential model."""
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(2, input_shape=(3,)))
- model.add(keras.layers.RepeatVector(3))
- model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
- model = keras.models.Model(model.input, model.output)
-
- model.compile(
- loss=keras.losses.MSE,
- optimizer=keras.optimizers.RMSprop(),
- metrics=[keras.metrics.categorical_accuracy],
- sample_weight_mode='temporal')
- x = np.random.random((1, 3))
- y = np.random.random((1, 3, 3))
- model.train_on_batch(x, y)
- model.predict(x)
-
- model.predict(x)
- fd, keras_file = tempfile.mkstemp('.h5')
- try:
- keras.models.save_model(model, keras_file)
- finally:
- os.close(fd)
+ with session.Session().as_default():
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.RepeatVector(3))
+ model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
+ model = keras.models.Model(model.input, model.output)
+
+ model.compile(
+ loss=keras.losses.MSE,
+ optimizer=keras.optimizers.RMSprop(),
+ metrics=[keras.metrics.categorical_accuracy],
+ sample_weight_mode='temporal')
+ x = np.random.random((1, 3))
+ y = np.random.random((1, 3, 3))
+ model.train_on_batch(x, y)
+ model.predict(x)
+
+ model.predict(x)
+ fd, keras_file = tempfile.mkstemp('.h5')
+ try:
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
# Convert to TFLite model.
converter = lite.TocoConverter.from_keras_model_file(keras_file)
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 7d7a4ba94a..cc08ed3fe9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -109,8 +109,14 @@ def _convert_model(flags):
if flags.mean_values and flags.std_dev_values:
input_arrays = converter.get_input_arrays()
- std_dev_values = _parse_array(flags.std_dev_values, type_fn=int)
- mean_values = _parse_array(flags.mean_values, type_fn=int)
+ std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
+
+ # In quantized inference, mean_value has to be integer so that the real
+ # value 0.0 is exactly representable.
+ if flags.inference_type == lite_constants.QUANTIZED_UINT8:
+ mean_values = _parse_array(flags.mean_values, type_fn=int)
+ else:
+ mean_values = _parse_array(flags.mean_values, type_fn=float)
quant_stats = list(zip(mean_values, std_dev_values))
if ((not flags.input_arrays and len(input_arrays) > 1) or
(len(input_arrays) != len(quant_stats))):
@@ -132,14 +138,18 @@ def _convert_model(flags):
if flags.reorder_across_fake_quant:
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
if flags.change_concat_input_ranges:
- converter.change_concat_input_ranges = flags.change_concat_input_ranges
+ converter.change_concat_input_ranges = (
+ flags.change_concat_input_ranges == "TRUE")
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
- if flags.quantize_weights:
+
+ if flags.post_training_quantize:
+ converter.post_training_quantize = flags.post_training_quantize
if flags.inference_type == lite_constants.QUANTIZED_UINT8:
- raise ValueError("--quantized_weights is not supported with "
- "--inference_type=QUANTIZED_UINT8")
- converter.quantize_weights = flags.quantize_weights
+ print("--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
+ converter.inference_type = lite_constants.FLOAT
+
if flags.dump_graphviz_dir:
converter.dump_graphviz_dir = flags.dump_graphviz_dir
if flags.dump_graphviz_video:
@@ -292,12 +302,13 @@ def run_main(_):
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
- "comma-separated integers. Used for quantization. (default None)"))
+ "comma-separated floats. Used for quantized input tensors. "
+ "(default None)"))
parser.add_argument(
"--mean_values",
type=str,
help=("Mean of training data for each input tensor, comma-separated "
- "integers. Used for quantization. (default None)"))
+ "floats. Used for quantized input tensors. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=int,
@@ -310,12 +321,20 @@ def run_main(_):
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
+ # quantize_weights is DEPRECATED.
parser.add_argument(
"--quantize_weights",
- type=bool,
- help=("Store float weights as quantized weights followed by dequantize "
- "operations. Inference is still done in FLOAT, but reduces model "
- "size (at the cost of accuracy and latency)."))
+ dest="post_training_quantize",
+ action="store_true",
+ help=argparse.SUPPRESS)
+ parser.add_argument(
+ "--post_training_quantize",
+ dest="post_training_quantize",
+ action="store_true",
+ help=(
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy). (default False)"))
# Graph manipulation flags.
parser.add_argument(
@@ -333,9 +352,14 @@ def run_main(_):
"the graph. Results in a graph that differs from the quantized "
"training graph, potentially causing differing arithmetic "
"behavior. (default False)"))
+ # Usage for this flag is --change_concat_input_ranges=true or
+ # --change_concat_input_ranges=false in order to make it clear what the flag
+ # is set to. This keeps the usage consistent with other usages of the flag
+ # where the default is different. The default value here is False.
parser.add_argument(
"--change_concat_input_ranges",
- action="store_true",
+ type=str.upper,
+ choices=["TRUE", "FALSE"],
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
diff --git a/tensorflow/contrib/lite/schema/BUILD b/tensorflow/contrib/lite/schema/BUILD
index b616e449e6..55bf2c48b9 100644
--- a/tensorflow/contrib/lite/schema/BUILD
+++ b/tensorflow/contrib/lite/schema/BUILD
@@ -48,7 +48,7 @@ exports_files([
"schema_v3.fbs",
])
-load("//third_party/flatbuffers:build_defs.bzl", "flatbuffer_cc_library")
+load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library")
# Generic schema for inference on device.
flatbuffer_cc_library(
@@ -56,6 +56,20 @@ flatbuffer_cc_library(
srcs = ["schema.fbs"],
)
+# Generic schema for inference on device (but with reflections makes bigger).
+flatbuffer_cc_library(
+ name = "schema_fbs_with_reflection",
+ srcs = ["schema.fbs"],
+ flatc_args = [
+ "--reflect-types",
+ "--reflect-names",
+ "--no-union-value-namespacing",
+ "--gen-object-api",
+ ],
+ gen_reflections = True,
+ out_prefix = "reflection/",
+)
+
# Schema test to make sure we don't introduce backward incompatible changes
# to schemas.
cc_test(
diff --git a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
index cd46a06f7d..11057203a8 100644
--- a/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
+++ b/tensorflow/contrib/lite/schema/flatbuffer_compatibility_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <fstream>
#include <gtest/gtest.h>
-#include "flatbuffers/flatc.h"
+#include "flatbuffers/flatc.h" // flatbuffers
#include "tensorflow/core/platform/platform.h"
#ifdef PLATFORM_GOOGLE
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 58a94ff4a5..cf66403ec9 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -171,6 +171,8 @@ enum BuiltinOperator : byte {
LOGICAL_NOT = 87,
UNPACK = 88,
REDUCE_MIN = 89,
+ FLOOR_DIV = 90,
+ REDUCE_ANY = 91,
}
// Options for the builtin operators.
@@ -239,6 +241,7 @@ union BuiltinOptions {
LogicalAndOptions,
LogicalNotOptions,
UnpackOptions,
+ FloorDivOptions,
}
enum Padding : byte { SAME, VALID }
@@ -573,6 +576,9 @@ table UnpackOptions {
axis:int;
}
+table FloorDivOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
@@ -639,9 +645,9 @@ table SubGraph {
}
// Table of raw data buffers (used for constant tensors). Referenced by tensors
-// by index.
+// by index. The generous alignment accommodates mmap-friendly data structures.
table Buffer {
- data:[ubyte];
+ data:[ubyte] (force_align: 16);
}
table Model {
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index a2ea43f370..6d9630d75e 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -223,6 +223,9 @@ struct LogicalNotOptionsT;
struct UnpackOptions;
struct UnpackOptionsT;
+struct FloorDivOptions;
+struct FloorDivOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -378,11 +381,13 @@ enum BuiltinOperator {
BuiltinOperator_LOGICAL_NOT = 87,
BuiltinOperator_UNPACK = 88,
BuiltinOperator_REDUCE_MIN = 89,
+ BuiltinOperator_FLOOR_DIV = 90,
+ BuiltinOperator_REDUCE_ANY = 91,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_REDUCE_MIN
+ BuiltinOperator_MAX = BuiltinOperator_REDUCE_ANY
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[89] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[91] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -472,7 +477,9 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[89] {
BuiltinOperator_LOGICAL_AND,
BuiltinOperator_LOGICAL_NOT,
BuiltinOperator_UNPACK,
- BuiltinOperator_REDUCE_MIN
+ BuiltinOperator_REDUCE_MIN,
+ BuiltinOperator_FLOOR_DIV,
+ BuiltinOperator_REDUCE_ANY
};
return values;
}
@@ -569,6 +576,8 @@ inline const char **EnumNamesBuiltinOperator() {
"LOGICAL_NOT",
"UNPACK",
"REDUCE_MIN",
+ "FLOOR_DIV",
+ "REDUCE_ANY",
nullptr
};
return names;
@@ -645,11 +654,12 @@ enum BuiltinOptions {
BuiltinOptions_LogicalAndOptions = 62,
BuiltinOptions_LogicalNotOptions = 63,
BuiltinOptions_UnpackOptions = 64,
+ BuiltinOptions_FloorDivOptions = 65,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_UnpackOptions
+ BuiltinOptions_MAX = BuiltinOptions_FloorDivOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[65] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[66] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -715,7 +725,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[65] {
BuiltinOptions_OneHotOptions,
BuiltinOptions_LogicalAndOptions,
BuiltinOptions_LogicalNotOptions,
- BuiltinOptions_UnpackOptions
+ BuiltinOptions_UnpackOptions,
+ BuiltinOptions_FloorDivOptions
};
return values;
}
@@ -787,6 +798,7 @@ inline const char **EnumNamesBuiltinOptions() {
"LogicalAndOptions",
"LogicalNotOptions",
"UnpackOptions",
+ "FloorDivOptions",
nullptr
};
return names;
@@ -1057,6 +1069,10 @@ template<> struct BuiltinOptionsTraits<UnpackOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_UnpackOptions;
};
+template<> struct BuiltinOptionsTraits<FloorDivOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_FloorDivOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1600,6 +1616,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_UnpackOptions ?
reinterpret_cast<const UnpackOptionsT *>(value) : nullptr;
}
+ FloorDivOptionsT *AsFloorDivOptions() {
+ return type == BuiltinOptions_FloorDivOptions ?
+ reinterpret_cast<FloorDivOptionsT *>(value) : nullptr;
+ }
+ const FloorDivOptionsT *AsFloorDivOptions() const {
+ return type == BuiltinOptions_FloorDivOptions ?
+ reinterpret_cast<const FloorDivOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -5739,6 +5763,46 @@ inline flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(
flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnpackOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct FloorDivOptionsT : public flatbuffers::NativeTable {
+ typedef FloorDivOptions TableType;
+ FloorDivOptionsT() {
+ }
+};
+
+struct FloorDivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef FloorDivOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ FloorDivOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(FloorDivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<FloorDivOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct FloorDivOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit FloorDivOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ FloorDivOptionsBuilder &operator=(const FloorDivOptionsBuilder &);
+ flatbuffers::Offset<FloorDivOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<FloorDivOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ FloorDivOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -6064,6 +6128,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const UnpackOptions *builtin_options_as_UnpackOptions() const {
return builtin_options_type() == BuiltinOptions_UnpackOptions ? static_cast<const UnpackOptions *>(builtin_options()) : nullptr;
}
+ const FloorDivOptions *builtin_options_as_FloorDivOptions() const {
+ return builtin_options_type() == BuiltinOptions_FloorDivOptions ? static_cast<const FloorDivOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6351,6 +6418,10 @@ template<> inline const UnpackOptions *Operator::builtin_options_as<UnpackOption
return builtin_options_as_UnpackOptions();
}
+template<> inline const FloorDivOptions *Operator::builtin_options_as<FloorDivOptions>() const {
+ return builtin_options_as_FloorDivOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -8567,6 +8638,29 @@ inline flatbuffers::Offset<UnpackOptions> CreateUnpackOptions(flatbuffers::FlatB
_axis);
}
+inline FloorDivOptionsT *FloorDivOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new FloorDivOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void FloorDivOptions::UnPackTo(FloorDivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<FloorDivOptions> FloorDivOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateFloorDivOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<FloorDivOptions> CreateFloorDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const FloorDivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FloorDivOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateFloorDivOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -9012,6 +9106,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const UnpackOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9286,6 +9384,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const UnpackOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<const FloorDivOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9548,6 +9650,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const UnpackOptionsT *>(value);
return CreateUnpackOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<const FloorDivOptionsT *>(value);
+ return CreateFloorDivOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -9810,6 +9916,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new UnpackOptionsT(*reinterpret_cast<UnpackOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_FloorDivOptions: {
+ value = new FloorDivOptionsT(*reinterpret_cast<FloorDivOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10137,6 +10247,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_FloorDivOptions: {
+ auto ptr = reinterpret_cast<FloorDivOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/simple_memory_arena.h b/tensorflow/contrib/lite/simple_memory_arena.h
index f738315cf2..45d0d8735e 100644
--- a/tensorflow/contrib/lite/simple_memory_arena.h
+++ b/tensorflow/contrib/lite/simple_memory_arena.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <list>
#include <memory>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/stderr_reporter.cc
index 646913c026..e29a6345fd 100644
--- a/tensorflow/contrib/lite/error_reporter.cc
+++ b/tensorflow/contrib/lite/stderr_reporter.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/stderr_reporter.h"
#include <cstdarg>
#include <cstdio>
@@ -22,26 +22,6 @@ limitations under the License.
namespace tflite {
-ErrorReporter::~ErrorReporter() {}
-
-int ErrorReporter::Report(const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
-// TODO(aselle): Make the name of ReportError on context the same, so
-// we can use the ensure functions w/o a context and w/ a reporter.
-int ErrorReporter::ReportError(void*, const char* format, ...) {
- va_list args;
- va_start(args, format);
- int code = Report(format, args);
- va_end(args);
- return code;
-}
-
int StderrReporter::Report(const char* format, va_list args) {
#ifdef __ANDROID__
// On Android stderr is not captured for applications, only for code run from
diff --git a/tensorflow/contrib/lite/stderr_reporter.h b/tensorflow/contrib/lite/stderr_reporter.h
new file mode 100644
index 0000000000..c6f4ffbdff
--- /dev/null
+++ b/tensorflow/contrib/lite/stderr_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
+
+#include <cstdarg>
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+
+namespace tflite {
+
+// An error reporter that simplify writes the message to stderr.
+struct StderrReporter : public ErrorReporter {
+ int Report(const char* format, va_list args) override;
+};
+
+// Return the default error reporter (output to stderr).
+ErrorReporter* DefaultErrorReporter();
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_STDERR_REPORTER_H_
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index a316a40b62..b991e999b6 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <string.h>
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index 57f129bf5e..d24627b509 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -42,7 +42,7 @@ limitations under the License.
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/string_util_test.cc b/tensorflow/contrib/lite/string_util_test.cc
index d53fec7512..a583a9184b 100644
--- a/tensorflow/contrib/lite/string_util_test.cc
+++ b/tensorflow/contrib/lite/string_util_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/string_util.h"
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/testing/util.h"
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index 89912fd116..3a6c16cafc 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -36,7 +36,7 @@ load(
tags = [
"gen_zip_test",
"no_oss",
- "tflite_not_portable",
+ "tflite_not_portable_intentional",
],
test_name = test_name,
deps = [
@@ -214,6 +214,7 @@ cc_library(
deps = [
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string",
+ "//tensorflow/contrib/lite/core/api",
],
)
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 599c82940e..32f02a4f6c 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -780,10 +780,15 @@ def make_binary_op_tests(zip_path, binary_operator):
"input_shape_2": [[5]],
"activation": [False, True]
}, {
- "dtype": [tf.float32],
+ "dtype": [tf.float32, tf.int32],
"input_shape_1": [[1, 3, 4, 3]],
"input_shape_2": [[3]],
- "activation": [True]
+ "activation": [True, False]
+ }, {
+ "dtype": [tf.float32, tf.int32],
+ "input_shape_1": [[3]],
+ "input_shape_2": [[1, 3, 4, 3]],
+ "activation": [True, False]
}, {
"dtype": [tf.float32],
"input_shape_1": [[]],
@@ -821,13 +826,17 @@ def make_binary_op_tests(zip_path, binary_operator):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
-def make_reduce_tests(reduce_op, min_value=-10, max_value=10):
+def make_reduce_tests(reduce_op,
+ min_value=-10,
+ max_value=10,
+ boolean_tensor_only=False):
"""Make a set of tests to do reduce operation.
Args:
reduce_op: TensorFlow reduce operation to test, i.e. `tf.reduce_mean`.
min_value: min value for created tensor data.
max_value: max value for created tensor data.
+ boolean_tensor_only: If true, will only generate tensor with boolean value.
Returns:
a function representing the true generator with `reduce_op_in` curried.
@@ -867,10 +876,11 @@ def make_reduce_tests(reduce_op, min_value=-10, max_value=10):
def build_graph(parameters):
"""Build the mean op testing graph."""
+ dtype = parameters["input_dtype"]
+ if boolean_tensor_only:
+ dtype = tf.bool
input_tensor = tf.placeholder(
- dtype=parameters["input_dtype"],
- name="input",
- shape=parameters["input_shape"])
+ dtype=dtype, name="input", shape=parameters["input_shape"])
# Get axis as either a placeholder or constants.
if parameters["const_axis"]:
@@ -889,9 +899,12 @@ def make_reduce_tests(reduce_op, min_value=-10, max_value=10):
return input_tensors, [out]
def build_inputs(parameters, sess, inputs, outputs):
+ dtype = parameters["input_dtype"]
+ if boolean_tensor_only:
+ dtype = tf.bool
values = [
create_tensor_data(
- parameters["input_dtype"],
+ dtype,
parameters["input_shape"],
min_value=min_value,
max_value=max_value)
@@ -931,6 +944,11 @@ def make_reduce_min_tests(zip_path):
return make_reduce_tests(tf.reduce_min)(zip_path)
+def make_reduce_any_tests(zip_path):
+ """Make a set of tests to do any."""
+ return make_reduce_tests(tf.reduce_any, boolean_tensor_only=True)(zip_path)
+
+
def make_exp_tests(zip_path):
"""Make a set of tests to do exp."""
@@ -1085,6 +1103,10 @@ def make_pow_tests(zip_path):
make_binary_op_tests(zip_path, tf.pow)
+def make_floor_div_tests(zip_path):
+ make_binary_op_tests(zip_path, tf.floor_div)
+
+
def make_gather_tests(zip_path):
"""Make a set of tests to do gather."""
@@ -1657,6 +1679,7 @@ def make_pad_tests(zip_path):
# TODO(nupurgarg): Add test for tf.uint8.
test_parameters = [
+ # 4D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1664,13 +1687,20 @@ def make_pad_tests(zip_path):
[0, 0], [2, 3]]],
"constant_paddings": [True, False],
},
- # Non-4D use case.
+ # 2D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
- "input_shape": [[1, 2], [0, 1, 2]],
+ "input_shape": [[1, 2]],
"paddings": [[[0, 1], [2, 3]]],
"constant_paddings": [True, False],
},
+ # 1D:
+ {
+ "dtype": [tf.int32],
+ "input_shape": [[1]],
+ "paddings": [[[1, 2]]],
+ "constant_paddings": [False],
+ },
]
def build_graph(parameters):
@@ -1708,6 +1738,7 @@ def make_padv2_tests(zip_path):
# TODO(nupurgarg): Add test for tf.uint8.
test_parameters = [
+ # 4D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
"input_shape": [[1, 1, 2, 1], [2, 1, 1, 1]],
@@ -1716,14 +1747,22 @@ def make_padv2_tests(zip_path):
"constant_paddings": [True, False],
"constant_values": [0, 2],
},
- # Non-4D use case.
+ # 2D:
{
"dtype": [tf.int32, tf.int64, tf.float32],
- "input_shape": [[1, 2], [0, 1, 2]],
+ "input_shape": [[1, 2]],
"paddings": [[[0, 1], [2, 3]]],
"constant_paddings": [True, False],
"constant_values": [0, 2],
},
+ # 1D:
+ {
+ "dtype": [tf.int32],
+ "input_shape": [[1]],
+ "paddings": [[[0, 1]]],
+ "constant_paddings": [False],
+ "constant_values": [0, 2],
+ },
]
def build_graph(parameters):
@@ -2378,7 +2417,7 @@ def make_lstm_tests(zip_path):
"time_step_size": [1],
"input_vec_size": [3],
"num_cells": [4],
- "split_tflite_lstm_inputs": [True, False],
+ "split_tflite_lstm_inputs": [False],
},
]
@@ -3149,6 +3188,36 @@ def make_pack_tests(zip_path):
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_unpack_tests(zip_path):
+ """Make a set of tests to do unstack."""
+
+ test_parameters = [{
+ "base_shape": [[3, 4, 3], [3, 4], [5, 6, 7, 8]],
+ "axis": [0, 1, 2, 3],
+ }]
+
+ def get_valid_axis(parameters):
+ """Return a tweaked version of 'axis'."""
+ axis = parameters["axis"]
+ shape = parameters["base_shape"][:]
+ while axis > len(shape) - 1:
+ axis -= 1
+ return axis
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=tf.float32, name=("input"), shape=parameters["base_shape"])
+ outs = tf.unstack(input_tensor, axis=get_valid_axis(parameters))
+ return [input_tensor], outs
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_value = create_tensor_data(np.float32, shape=parameters["base_shape"])
+ return [input_value], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_value])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def _make_logical_tests(op):
"""Make a set of tests to do logical operations."""
diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
index e67fee2a1c..349aa5a3b4 100644
--- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
+++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
@@ -58,12 +58,6 @@ tensorflow::Env* env = tensorflow::Env::Default();
// Key is a substring of the test name and value is a bug number.
// TODO(ahentz): make sure we clean this list up frequently.
std::map<string, string> kBrokenTests = {
- // Pad and PadV2 only supports 4D tensors.
- {R"(^\/pad.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
- "70527055"},
- {R"(^\/padv2.*,input_shape=\[.,.\],paddings=\[\[.,.\],\[.,.\]\])",
- "70527055"},
-
// L2Norm only supports tensors with 4D or fewer.
{R"(^\/l2norm_dim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"},
@@ -101,6 +95,15 @@ std::map<string, string> kBrokenTests = {
"77546240"},
{R"(^\/arg_min_max.*axis_is_last_dim=False.*input_shape=\[.,.\])",
"77546240"},
+
+ // No Support for float.
+ {R"(^\/floor_div.*dtype=tf\.float32)", "112859002"},
+
+ // Relu does not support int32.
+ // These test cases appends a Relu after the tested ops when
+ // activation=True. The tests are failing since Relu doesn't support int32.
+ {R"(^\/div.*activation=True.*dtype=tf\.int32)", "112968789"},
+ {R"(^\/floor_div.*activation=True.*dtype=tf\.int32)", "112968789"},
};
// Allows test data to be unarchived into a temporary directory and makes
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4dacf9c84b..1836eb53b9 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() {
void TfLiteDriver::ResetLSTMStateTensors() {
interpreter_->ResetVariableTensorsToZero();
-
- // Below is a workaround for initializing state tensors for LSTM.
- // TODO(ycling): Remove the code below after nobody is using the 18-inputs
- // definition.
- for (auto node_index : interpreter_->execution_plan()) {
- const auto& node_and_reg = interpreter_->node_and_registration(node_index);
- const auto& node = node_and_reg->first;
- const auto& registration = node_and_reg->second;
-
- if (registration.builtin_code == tflite::BuiltinOperator_LSTM) {
- const auto* params =
- reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
- if (params->kernel_type == kTfLiteLSTMFullKernel &&
- node.inputs->size == 18 && node.outputs->size >= 2) {
- // The first 2 outputs of LSTM are state tensors.
- for (int i = 0; i < 2; ++i) {
- int node_index = node.outputs->data[i];
- ResetTensor(node_index);
- }
- }
- }
- }
}
} // namespace testing
diff --git a/tensorflow/contrib/lite/testing/util.h b/tensorflow/contrib/lite/testing/util.h
index 8aa639157b..925791d390 100644
--- a/tensorflow/contrib/lite/testing/util.h
+++ b/tensorflow/contrib/lite/testing/util.h
@@ -17,7 +17,7 @@ limitations under the License.
#include <cstdio>
-#include "tensorflow/contrib/lite/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
#include "tensorflow/contrib/lite/string.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 02d0890a7a..bea90f1ce8 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -213,7 +213,6 @@ cc_library(
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
"graph_transformations/quantize.cc",
- "graph_transformations/quantize_weights.cc",
"graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc",
"graph_transformations/remove_final_dequantize_op.cc",
"graph_transformations/remove_tensorflow_assert.cc",
@@ -373,6 +372,7 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:types",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@com_googlesource_code_re2//:re2",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index aef35ad490..f14dbc258b 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -236,8 +236,9 @@ struct ParsedTocoFlags {
Arg<bool> drop_fake_quant = Arg<bool>(false);
Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
Arg<bool> allow_custom_ops = Arg<bool>(false);
- Arg<bool> quantize_weights = Arg<bool>(false);
+ Arg<bool> post_training_quantize = Arg<bool>(false);
// Deprecated flags
+ Arg<bool> quantize_weights = Arg<bool>(false);
Arg<string> input_type;
Arg<string> input_types;
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
@@ -246,6 +247,10 @@ struct ParsedTocoFlags {
Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
Arg<int64> dedupe_array_min_size_bytes = Arg<int64>(64);
Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> allow_eager_ops = Arg<bool>(false);
+ // WARNING: Experimental interface, subject to change
+ Arg<bool> force_eager_ops = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc
index f489c5ac65..b52a79282c 100644
--- a/tensorflow/contrib/lite/toco/export_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc
@@ -1701,9 +1701,11 @@ void ConvertReduceOperator(const Model& model, const T& src_op,
*new_op->add_input() = src_op.inputs[0];
*new_op->add_input() = src_op.inputs[1];
- const tensorflow::DataType params_type =
- GetTensorFlowDataType(model, src_op.inputs[0]);
- (*new_op->mutable_attr())["T"].set_type(params_type);
+ if (src_op.type != OperatorType::kAny) {
+ const tensorflow::DataType params_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*new_op->mutable_attr())["T"].set_type(params_type);
+ }
const tensorflow::DataType indices_type =
GetTensorFlowDataType(model, src_op.inputs[1]);
(*new_op->mutable_attr())["Tidx"].set_type(indices_type);
@@ -1900,21 +1902,6 @@ void ConvertPowOperator(const Model& model, const PowOperator& src_op,
(*pow_op->mutable_attr())["T"].set_type(data_type);
}
-void ConvertAnyOperator(const Model& model, const AnyOperator& src_op,
- GraphDef* tensorflow_graph) {
- tensorflow::NodeDef* any_op = tensorflow_graph->add_node();
- any_op->set_op("Any");
- any_op->set_name(src_op.outputs[0]);
- CHECK_EQ(src_op.inputs.size(), 2);
- for (int i = 0; i < 2; ++i) {
- *any_op->add_input() = src_op.inputs[i];
- }
- const tensorflow::DataType data_type =
- GetTensorFlowDataType(model, src_op.inputs[1]);
- (*any_op->mutable_attr())["Tidx"].set_type(data_type);
- (*any_op->mutable_attr())["keep_dims"].set_b(src_op.keep_dims);
-}
-
void ConvertLogicalAndOperator(const Model& model,
const LogicalAndOperator& src_op,
GraphDef* tensorflow_graph) {
@@ -1967,6 +1954,20 @@ void ConvertCTCBeamSearchDecoderOperator(
(*op->mutable_attr())["merge_repeated"].set_b(src_op.merge_repeated);
}
+void ConvertUnpackOperator(const Model& model, const UnpackOperator& src_op,
+ const char* op_name, GraphDef* tensorflow_graph) {
+ tensorflow::NodeDef* unpack_op = tensorflow_graph->add_node();
+ unpack_op->set_op(op_name);
+ unpack_op->set_name(src_op.outputs[0]);
+ CHECK_EQ(src_op.inputs.size(), 2);
+ *unpack_op->add_input() = src_op.inputs[0];
+ const tensorflow::DataType data_type =
+ GetTensorFlowDataType(model, src_op.inputs[0]);
+ (*unpack_op->mutable_attr())["T"].set_type(data_type);
+ (*unpack_op->mutable_attr())["num"].set_i(src_op.num);
+ (*unpack_op->mutable_attr())["axis"].set_i(src_op.axis);
+}
+
void ConvertOperator(const Model& model, const Operator& src_op,
GraphDef* tensorflow_graph) {
if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -2207,8 +2208,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertPowOperator(model, static_cast<const PowOperator&>(src_op), "Pow",
tensorflow_graph);
} else if (src_op.type == OperatorType::kAny) {
- ConvertAnyOperator(model, static_cast<const AnyOperator&>(src_op),
- tensorflow_graph);
+ ConvertReduceOperator(model,
+ static_cast<const TensorFlowAnyOperator&>(src_op),
+ tensorflow_graph, "Any");
} else if (src_op.type == OperatorType::kLogicalAnd) {
ConvertLogicalAndOperator(model,
static_cast<const LogicalAndOperator&>(src_op),
@@ -2228,6 +2230,9 @@ void ConvertOperator(const Model& model, const Operator& src_op,
ConvertCTCBeamSearchDecoderOperator(
model, static_cast<const CTCBeamSearchDecoderOperator&>(src_op),
"CTCBeamSearchDecoder", tensorflow_graph);
+ } else if (src_op.type == OperatorType::kUnpack) {
+ ConvertUnpackOperator(model, static_cast<const UnpackOperator&>(src_op),
+ "Unpack", tensorflow_graph);
} else {
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(src_op.type);
}
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
index 4bf47aa3c4..84680b968e 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md
@@ -24,8 +24,8 @@ Table of contents:
* [Multiple output arrays](#multiple-output-arrays)
* [Specifying subgraphs](#specifying-subgraphs)
* [Graph visualizations](#graph-visualizations)
- * [Using --output_format=GRAPHVIZ_DOT](#using-output-formatgraphviz-dot)
- * [Using --dump_graphviz](#using-dump-graphviz)
+ * [Using --output_format=GRAPHVIZ_DOT](#using-output-format-graphviz-dot)
+ * [Using --dump_graphviz_dir](#using-dump-graphviz-dir)
* [Graph "video" logging](#graph-video-logging)
* [Legend for the graph visualizations](#graphviz-legend)
@@ -247,17 +247,17 @@ function tends to get fused).
## Graph visualizations
-TOCO can export a graph to the GraphViz Dot format for easy visualization via
+TOCO can export a graph to the Graphviz Dot format for easy visualization via
either the `--output_format` flag or the `--dump_graphviz_dir` flag. The
subsections below outline the use cases for each.
-### Using `--output_format=GRAPHVIZ_DOT`
+### Using `--output_format=GRAPHVIZ_DOT` <a name="using-output-format-graphviz-dot"></a>
-The first way to get a graphviz rendering is to pass `GRAPHVIZ_DOT` into
+The first way to get a Graphviz rendering is to pass `GRAPHVIZ_DOT` into
`--output_format`. This results in a plausible visualization of the graph. This
-reduces the requirements that exist during conversion between other input and
-output formats. This may be useful if conversion from TENSORFLOW_GRAPHDEF to
-TFLITE is failing.
+reduces the requirements that exist during conversion from a TensorFlow GraphDef
+to a TensorFlow Lite FlatBuffer. This may be useful if the conversion to TFLite
+is failing.
```
curl https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.50_128_frozen.tgz \
@@ -287,10 +287,10 @@ google-chrome /tmp/foo.dot.pdf
Example PDF files are viewable online in the next section.
-### Using `--dump_graphviz`
+### Using `--dump_graphviz_dir`
-The second way to get a graphviz rendering is to pass the `--dump_graphviz_dir`
-flag, specifying a destination directory to dump GraphViz rendering to. Unlike
+The second way to get a Graphviz rendering is to pass the `--dump_graphviz_dir`
+flag, specifying a destination directory to dump Graphviz rendering to. Unlike
the previous approach, this one retains the original output format. This
provides a visualization of the actual graph resulting from a specific
conversion process.
diff --git a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
index decc8a45a4..00bc8d4ccb 100644
--- a/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
+++ b/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md
@@ -38,7 +38,7 @@ files. The flag `--output_file` is always required. Additionally, either
of TFLite specific transformations. Therefore, the resulting
visualization may not reflect the final set of graph
transformations. To get a final visualization with all graph
- transformations use `--dump_graphviz` instead.
+ transformations use `--dump_graphviz_dir` instead.
The following flags specify optional parameters when using SavedModels.
@@ -67,21 +67,22 @@ based on index.
* `--input_shapes`. Type: colon-separated list of comma-separated lists of
integers. Each comma-separated list of integers gives the shape of one of
- the input arrays specified in [TensorFlow
- convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
+ the input arrays specified in
+ [TensorFlow convention](https://www.tensorflow.org/versions/r1.2/programmers_guide/dims_types#shape).
* Example: `--input_shapes=1,60,80,3` for a typical vision model means a
batch size of 1, an input image height of 60, an input image width of
80, and an input image depth of 3 (representing RGB channels).
* Example: `--input_arrays=foo,bar --input_shapes=2,3:4,5,6` means "foo"
has a shape of [2, 3] and "bar" has a shape of [4, 5, 6].
-* `--std_dev_values`, `--mean_values`. Type: comma-separated list of integers.
+* `--std_dev_values`, `--mean_values`. Type: comma-separated list of floats.
These specify the (de-)quantization parameters of the input array, when it
- is quantized.
+ is quantized. This is only needed if `inference_input_type` is
+ `QUANTIZED_UINT8`.
* The meaning of `mean_values` and `std_dev_values` is as follows: each
quantized value in the quantized input array will be interpreted as a
mathematical real number (i.e. as an input activation value) according
to the following formula:
- * `real_value = (quantized_input_value - mean_value) / std_value`.
+ * `real_value = (quantized_input_value - mean_value) / std_dev_value`.
* When performing float inference (`--inference_type=FLOAT`) on a
quantized input, the quantized input would be immediately dequantized by
the inference code according to the above formula, before proceeding
@@ -91,7 +92,8 @@ based on index.
the inference code. However, the quantization parameters of all arrays,
including those of the input arrays as specified by `mean_value` and
`std_dev_value`, determine the fixed-point multipliers used in the
- quantized inference code.
+ quantized inference code. `mean_value` must be an integer when
+ performing quantized inference.
## Transformation flags
@@ -147,10 +149,10 @@ have.
true, custom ops are created for any op that is unknown. The developer will
need to provide these to the TensorFlow Lite runtime with a custom resolver.
-* `--quantize_weights`. Type: boolean. Default: False. Indicates whether to
- store weights as quantized weights followed by dequantize operations.
- Computation is still done in float, but reduces model size (at the cost of
- accuracy and latency).
+* `--post_training_quantize`. Type: boolean. Default: False. Boolean
+ indicating whether to quantize the weights of the converted float model.
+ Model size will be reduced and there will be latency improvements (at the
+ cost of accuracy).
## Logging flags
diff --git a/tensorflow/contrib/lite/toco/g3doc/python_api.md b/tensorflow/contrib/lite/toco/g3doc/python_api.md
index 3799eac0a1..51f808d4f0 100644
--- a/tensorflow/contrib/lite/toco/g3doc/python_api.md
+++ b/tensorflow/contrib/lite/toco/g3doc/python_api.md
@@ -70,6 +70,7 @@ val = img + var
out = tf.identity(val, name="out")
with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
converter = tf.contrib.lite.TocoConverter.from_session(sess, [img], [out])
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
diff --git a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
index 262e13a591..335debde57 100644
--- a/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
+++ b/tensorflow/contrib/lite/toco/g3doc/toco_landscape.svg
@@ -1 +1 @@
-<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 -0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 -0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.27695 264.03653q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -3.4375l-5.062496 0l0 3.4375q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.296875l5.062496 0l0 -3.296875q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.375 -0.140625zm3.0648193 8.515625q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm6.5711823 0.90625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm9.0746765 -5.359375q0.8125 0 1.40625 0.34375q0.609375 0.328125 0.9375 0.9375q0.328125 0.59375 0.328125 1.390625q0 0.78125 -0.359375 1.40625q-0.359375 0.625 -1.0 0.96875q-0.640625 0.328125 -1.484375 0.328125q-0.734375 0 -1.453125 -0.25q-0.703125 -0.265625 -1.1875 -0.734375q-0.203125 -0.171875 -0.203125 -0.40625q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.234375 -0.125q0.171875 0 0.34375 0.140625q0.515625 0.4375 1.046875 0.640625q0.53125 0.203125 1.109375 0.203125q0.890625 0 1.390625 -0.5q0.5 -0.5 0.5 -1.359375q0 -0.84375 -0.5 -1.359375q-0.5 -0.515625 -1.359375 -0.515625q-1.09375 0 -1.78125 0.84375q-0.15625 0.171875 -0.40625 0.171875q-0.15625 0 -0.28125 -0.09375q-0.109375 -0.109375 -0.109375 -0.296875l0 -4.125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125l4.21875 0q0.21875 0 0.34375 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.125 0.109375 -0.34375 0.109375l-3.734375 0l0 3.015625q0.34375 -0.328125 0.78125 -0.5q0.453125 -0.171875 0.984375 -0.171875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m58.725647 87.669235q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.9706573 -6.984375q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm1.8266602 7.75q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm8.498016 -0.8125q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.792801 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.20152 88.37367q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.484375 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.015625 0l0 2.9375l3.78125 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.78125 0l0 3.078125l4.015625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.484375 0zm8.31218 0.078125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.4787903 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm1.8769073 0.765625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125zm6.0990753 0q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.8144073 0.78125q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1287994 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m269.00754 88.46742q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm5.0446777 -0.03125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm2.784027 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m297.8283 154.87688q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm7.358429 -6.078125q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm8.37854 4.625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.308441 5.3125q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm7.998047 -0.84375q0.203125 0.171875 0.203125 0.375q0 0.1875 -0.125 0.328125q-0.125 0.125 -0.3125 0.125q-0.15625 0 -0.328125 -0.140625l-3.125 -2.703125l0 2.359375q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 4.875l2.859375 -2.625q0.15625 -0.140625 0.328125 -0.140625q0.1875 0 0.3125 0.140625q0.140625 0.125 0.140625 0.296875q0 0.203125 -0.171875 0.359375l-2.375 2.109375l2.59375 2.265625zm4.2812805 -5.21875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm6.67157 0.796875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm4.722534 0.78125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.5660706 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.361267 0.78125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m233.89502 131.35573l-1.124588 -1.124588l1.124588 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m326.24146 131.35573l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.47955 247.1612q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm8.868103 0q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm12.917175 7.953125q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m589.5417 213.87056q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7480469 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7479858 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m75.62294 283.52823l0 17.950958l100.62993 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62295 283.52823l0 17.950928l100.62992 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.25287 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85654 283.52823l0 17.950958l-100.62991 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85654 283.52823l0 17.950928l-100.62991 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.22662 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 0.06298828l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 0.06298828l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 334.95734l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -41.858246l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -41.858246l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 293.0361l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.1246033z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -83.74802l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -83.74802l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 251.14633l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -125.60629l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -125.60629l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 209.28806l-1.1245728 1.124588l3.0897827 -1.124588l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.88803 165.75919l0 17.70752l42.960632 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.88805 165.75919l0 17.70752l42.960617 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.84866 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.86572 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#d9ead3" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m146.9475 272.6459q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm6.9353027 -6.078125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm8.578796 -4.96875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-2.34375 5.046875q-0.0625 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-2.328125 -5.046875q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm6.480545 4.296875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.589676 -3.28125q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m176.23885 99.34974l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23885 99.34974l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.23885 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m176.23975 283.52823l0 17.950958l0.06298828 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23975 283.52823l0 17.950928l0.06298828 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.30273 316.00665l-1.1245728 -1.1246033l1.1245728 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m75.62205 249.1182l-1.1245804 -1.124588l1.1245804 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
+<svg version="1.1" viewBox="0.0 0.0 720.0 540.0" fill="none" stroke="none" stroke-linecap="square" stroke-miterlimit="10" xmlns:xlink="http://www.w3.org/1999/xlink" xmlns="http://www.w3.org/2000/svg"><clipPath id="p.0"><path d="m0 0l720.0 0l0 540.0l-720.0 0l0 -540.0z" clip-rule="nonzero"/></clipPath><g clip-path="url(#p.0)"><path fill="#000000" fill-opacity="0.0" d="m0 0l720.0 0l0 540.0l-720.0 0z" fill-rule="evenodd"/><path fill="#f3f3f3" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m19.375328 28.750656l361.6378 0l0 358.01575l-361.6378 0z" fill-rule="evenodd"/><path fill="#434343" d="m338.49512 374.66016q-0.609375 0 -1.171875 -0.140625q-0.546875 -0.15625 -0.96875 -0.421875q-0.25 -0.15625 -0.359375 -0.296875q-0.09375 -0.140625 -0.09375 -0.34375q0 -0.171875 0.09375 -0.28125q0.109375 -0.109375 0.265625 -0.109375q0.171875 0 0.46875 0.1875q0.40625 0.25 0.796875 0.390625q0.390625 0.140625 0.984375 0.140625q0.71875 0 1.109375 -0.25q0.40625 -0.265625 0.40625 -0.734375q0 -0.296875 -0.15625 -0.46875q-0.140625 -0.1875 -0.5 -0.328125q-0.359375 -0.140625 -1.046875 -0.296875q-1.171875 -0.25 -1.6875 -0.671875q-0.5 -0.421875 -0.5 -1.15625q0 -0.578125 0.3125 -1.015625q0.328125 -0.4375 0.890625 -0.6875q0.5625 -0.265625 1.28125 -0.265625q0.53125 0 1.015625 0.140625q0.484375 0.140625 0.859375 0.390625q0.453125 0.328125 0.453125 0.671875q0 0.171875 -0.109375 0.296875q-0.109375 0.125 -0.25 0.125q-0.15625 0 -0.484375 -0.234375q-0.375 -0.234375 -0.703125 -0.359375q-0.328125 -0.140625 -0.828125 -0.140625q-0.625 0 -1.015625 0.28125q-0.375 0.265625 -0.375 0.734375q0 0.296875 0.140625 0.484375q0.140625 0.171875 0.46875 0.3125q0.328125 0.140625 0.9375 0.28125q0.90625 0.1875 1.40625 0.4375q0.5 0.234375 0.703125 0.578125q0.21875 0.34375 0.21875 0.890625q0 0.828125 -0.703125 1.34375q-0.703125 0.515625 -1.859375 0.515625zm9.241241 -1.59375q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551147 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625zm6.157959 0.328125q0.15625 -0.3125 0.46875 -0.3125q0.203125 0 0.359375 0.140625q0.15625 0.125 0.15625 0.328125q0 0.109375 -0.046875 0.203125l-2.59375 5.609375q-0.078125 0.171875 -0.25 0.28125q-0.15625 0.09375 -0.34375 0.09375q-0.171875 0 -0.328125 -0.09375q-0.15625 -0.109375 -0.25 -0.28125l-2.59375 -5.609375q-0.046875 -0.09375 -0.046875 -0.1875q0 -0.203125 0.171875 -0.34375q0.1875 -0.15625 0.390625 -0.15625q0.140625 0 0.265625 0.078125q0.125 0.078125 0.1875 0.234375l2.234375 5.0l2.21875 -4.984375zm7.2099915 4.796875q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.5551453 -0.8125q0.546875 -0.03125 0.546875 0.453125q0 0.21875 -0.125 0.34375q-0.109375 0.125 -0.40625 0.15625l-0.390625 0.03125q-0.890625 0.078125 -1.328125 0.640625q-0.4375 0.546875 -0.4375 1.296875l0 3.234375q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.359375 0.140625q0.140625 0.140625 0.140625 0.375l0 0.75q0.28125 -0.578125 0.796875 -0.890625q0.515625 -0.3125 1.1875 -0.359375l0.1875 -0.015625z" fill-rule="nonzero"/><path fill="#d9d9d9" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" stroke-dasharray="4.0,3.0" d="m25.624672 36.249344l301.88977 0l0 69.98425l-301.88977 0z" fill-rule="evenodd"/><path fill="#434343" d="m134.36497 56.831844q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm9.004181 -1.421875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.839676 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm5.84729 6.0625q-0.56248474 0 -1.0624847 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.87498474 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0624847 -0.234375 -1.5156097 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.1562347 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.56248474 0 -0.90623474 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84373474 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.2131653 0q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1288147 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm1.970398 6.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.721527 0.015625q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm12.222534 -4.9375q0.125 -0.28125 0.390625 -0.28125q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.078125 -0.03125 0.171875l-1.984375 5.046875q-0.078125 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.296875 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.65625 -4.21875l-1.640625 4.21875q-0.0625 0.15625 -0.203125 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-1.984375 -5.03125q-0.046875 -0.09375 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.171875 -0.140625 0.359375 -0.140625q0.296875 0 0.40625 0.296875l1.65625 4.421875l1.6875 -4.390625q0.078125 -0.15625 0.203125 -0.234375q0.125 -0.09375 0.265625 -0.09375q0.15625 0 0.28125 0.09375q0.125 0.078125 0.1875 0.234375l1.6875 4.375l1.65625 -4.40625zm12.637604 5.09375q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm4.4157715 0.015625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f3f3f3" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path stroke="#cccccc" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m396.75067 183.75066l249.00787 0l0 203.02364l-249.00787 0z" fill-rule="evenodd"/><path fill="#434343" d="m409.42255 374.66803q-0.90625 0 -1.609375 -0.40625q-0.6875 -0.421875 -1.078125 -1.171875q-0.375 -0.765625 -0.375 -1.765625q0 -1.0 0.390625 -1.765625q0.40625 -0.78125 1.109375 -1.203125q0.703125 -0.4375 1.625 -0.4375q0.5 0 1.0 0.140625q0.5 0.140625 0.875 0.40625q0.234375 0.171875 0.328125 0.328125q0.109375 0.140625 0.109375 0.328125q0 0.1875 -0.109375 0.3125q-0.09375 0.109375 -0.25 0.109375q-0.09375 0 -0.203125 -0.046875q-0.09375 -0.046875 -0.171875 -0.09375q-0.078125 -0.0625 -0.09375 -0.078125q-0.359375 -0.234375 -0.671875 -0.359375q-0.3125 -0.140625 -0.765625 -0.140625q-0.96875 0 -1.515625 0.671875q-0.53125 0.65625 -0.53125 1.828125q0 1.171875 0.53125 1.8125q0.546875 0.640625 1.515625 0.640625q0.453125 0 0.78125 -0.125q0.328125 -0.140625 0.65625 -0.375q0.15625 -0.09375 0.28125 -0.15625q0.140625 -0.0625 0.234375 -0.0625q0.140625 0 0.234375 0.125q0.109375 0.109375 0.109375 0.296875q0 0.171875 -0.09375 0.3125q-0.09375 0.140625 -0.34375 0.3125q-0.375 0.25 -0.90625 0.40625q-0.515625 0.15625 -1.0625 0.15625zm4.2591553 -0.03125q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -8.46875q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.21875 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 8.46875q0 0.25 -0.15625 0.390625q-0.15625 0.140625 -0.375 0.140625zm3.092102 0q-0.234375 0 -0.390625 -0.140625q-0.15625 -0.140625 -0.15625 -0.390625l0 -5.625q0 -0.25 0.15625 -0.390625q0.15625 -0.140625 0.390625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 5.625q0 0.265625 -0.15625 0.40625q-0.140625 0.125 -0.375 0.125zm0 -8.09375q-0.3125 0 -0.515625 -0.171875q-0.203125 -0.1875 -0.203125 -0.5q0 -0.296875 0.203125 -0.484375q0.203125 -0.1875 0.515625 -0.1875q0.328125 0 0.515625 0.1875q0.203125 0.1875 0.203125 0.484375q0 0.3125 -0.203125 0.5q-0.1875 0.171875 -0.515625 0.171875zm7.5765076 6.53125q0.140625 0 0.25 0.125q0.109375 0.109375 0.109375 0.296875q0 0.328125 -0.46875 0.609375q-0.484375 0.28125 -1.015625 0.421875q-0.53125 0.140625 -1.046875 0.140625q-1.5 0 -2.375 -0.890625q-0.875 -0.890625 -0.875 -2.46875q0 -1.0 0.390625 -1.765625q0.390625 -0.765625 1.078125 -1.1875q0.703125 -0.4375 1.59375 -0.4375q1.265625 0 2.015625 0.828125q0.75 0.828125 0.75 2.25q0 0.265625 -0.109375 0.390625q-0.109375 0.109375 -0.34375 0.109375l-4.296875 0q0.125 2.296875 2.171875 2.296875q0.53125 0 0.890625 -0.140625q0.375 -0.140625 0.8125 -0.390625q0.34375 -0.1875 0.46875 -0.1875zm-2.34375 -4.3125q-0.84375 0 -1.359375 0.53125q-0.515625 0.53125 -0.609375 1.515625l3.765625 0q-0.015625 -1.0 -0.484375 -1.515625q-0.46875 -0.53125 -1.3125 -0.53125zm7.6020203 -0.84375q2.328125 0 2.328125 2.578125l0 3.609375q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -3.546875q0 -0.90625 -0.359375 -1.3125q-0.34375 -0.421875 -1.125 -0.421875q-0.890625 0 -1.421875 0.546875q-0.53125 0.546875 -0.53125 1.484375l0 3.25q0 0.25 -0.140625 0.390625q-0.140625 0.140625 -0.390625 0.140625q-0.25 0 -0.40625 -0.140625q-0.140625 -0.140625 -0.140625 -0.390625l0 -5.625q0 -0.234375 0.140625 -0.375q0.15625 -0.15625 0.40625 -0.15625q0.234375 0 0.375 0.15625q0.140625 0.140625 0.140625 0.359375l0 0.6875q0.328125 -0.609375 0.890625 -0.921875q0.578125 -0.3125 1.3125 -0.3125zm7.304718 5.875q0.46875 0.03125 0.46875 0.421875q0 0.21875 -0.171875 0.34375q-0.171875 0.109375 -0.5 0.078125l-0.359375 -0.015625q-1.0625 -0.09375 -1.578125 -0.640625q-0.5 -0.5625 -0.5 -1.703125l0 -3.34375l-0.890625 0q-0.234375 0 -0.359375 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.203125 0.125 -0.3125q0.125 -0.125 0.359375 -0.125l0.890625 0l0 -1.515625q0 -0.25 0.140625 -0.390625q0.15625 -0.140625 0.40625 -0.140625q0.234375 0 0.375 0.140625q0.15625 0.140625 0.15625 0.390625l0 1.515625l1.484375 0q0.203125 0 0.328125 0.125q0.140625 0.109375 0.140625 0.3125q0 0.1875 -0.140625 0.296875q-0.125 0.109375 -0.328125 0.109375l-1.484375 0l0 3.40625q0 0.734375 0.296875 1.0625q0.296875 0.3125 0.90625 0.359375l0.359375 0.03125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m206.61942 201.17455l140.47244 0l0 30.992126l-140.47244 0z" fill-rule="evenodd"/><path fill="#000000" d="m237.0857 213.5031q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.417801 3.875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.199051 4.46875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm3.3865662 5.875q-0.171875 0 -0.28125 -0.09375q-0.109375 -0.09375 -0.109375 -0.21875q0 -0.140625 0.109375 -0.234375q0.109375 -0.09375 0.28125 -0.09375l5.21875 0q0.171875 0 0.28125 0.09375q0.109375 0.09375 0.109375 0.234375q0 0.125 -0.109375 0.21875q-0.109375 0.09375 -0.28125 0.09375l-5.21875 0zm11.2500305 -6.609375q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 5.09375q0 1.296875 -0.671875 1.96875q-0.671875 0.671875 -1.984375 0.671875q-1.28125 0 -2.140625 -0.515625q-0.421875 -0.234375 -0.421875 -0.546875q0 -0.171875 0.078125 -0.28125q0.09375 -0.109375 0.234375 -0.109375q0.125 0 0.4375 0.171875q0.421875 0.21875 0.828125 0.34375q0.40625 0.140625 0.96875 0.140625q0.859375 0 1.28125 -0.453125q0.4375 -0.453125 0.4375 -1.3125l0 -1.03125q-0.25 0.5625 -0.78125 0.859375q-0.515625 0.296875 -1.21875 0.296875q-0.765625 0 -1.359375 -0.359375q-0.59375 -0.359375 -0.9375 -1.015625q-0.328125 -0.65625 -0.328125 -1.515625q0 -0.875 0.328125 -1.53125q0.34375 -0.65625 0.9375 -1.015625q0.59375 -0.359375 1.359375 -0.359375q0.6875 0 1.203125 0.296875q0.515625 0.296875 0.78125 0.84375l0 -0.640625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625zm-2.28125 4.984375q0.84375 0 1.3125 -0.546875q0.484375 -0.5625 0.484375 -1.546875q0 -0.984375 -0.46875 -1.53125q-0.46875 -0.5625 -1.328125 -0.5625q-0.84375 0 -1.34375 0.5625q-0.484375 0.546875 -0.484375 1.53125q0 0.984375 0.484375 1.546875q0.5 0.546875 1.34375 0.546875zm7.4695435 -4.984375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.20282 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.331665 6.046875q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm5.2167664 -6.046875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.45282 -4.9375q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 319.42978l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m163.01448 339.50836q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.160431 0.03125q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625zm9.214935 0.84375q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm8.077179 0q-1.171875 0 -2.046875 -0.515625q-0.859375 -0.53125 -1.328125 -1.5q-0.46875 -0.984375 -0.46875 -2.296875q0 -1.34375 0.453125 -2.3125q0.46875 -0.984375 1.328125 -1.5q0.875 -0.53125 2.0625 -0.53125q1.1875 0 2.0625 0.53125q0.875 0.515625 1.328125 1.5q0.46875 0.96875 0.46875 2.296875q0 1.3125 -0.46875 2.296875q-0.46875 0.984375 -1.34375 1.515625q-0.859375 0.515625 -2.046875 0.515625zm0 -0.84375q1.34375 0 2.09375 -0.90625q0.75 -0.90625 0.75 -2.578125q0 -1.6875 -0.75 -2.578125q-0.734375 -0.90625 -2.09375 -0.90625q-1.34375 0 -2.09375 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.09375 0.90625z" fill-rule="nonzero"/><path fill="#d9ead3" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m284.12296 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m314.7006 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m303.37402 346.47687q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.5434265 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm4.674652 -6.046875q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.3300476 -5.28125q0.765625 0 1.34375 0.375q0.59375 0.359375 0.921875 1.046875q0.328125 0.6875 0.328125 1.59375q0 0.90625 -0.328125 1.59375q-0.328125 0.6875 -0.921875 1.078125q-0.578125 0.375 -1.34375 0.375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 0.640625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.203125q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.59375q0.46875 -0.59375 0.46875 -1.65625q0 -1.046875 -0.46875 -1.625q-0.46875 -0.578125 -1.328125 -0.578125q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.687164 -5.25q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.8726807 -1.71875q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm3.9360352 0q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm5.873535 6.328125q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 319.3983l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m443.6039 332.47687q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm5.113556 0q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm6.6840515 -0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -7.5625q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.171875l3.875 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-4.375 0zm6.3394165 0.0625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm4.987152 6.515625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.908142 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#000000" d="m429.9527 346.47687q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.56604 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm4.282898 -0.015625q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.14032 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.5896606 4.53125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.9081726 -0.65625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.7927856 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m371.61902 334.89435l41.417297 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m371.61902 334.89435l37.990234 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m409.60925 334.89435l-1.1245728 1.1246033l3.0897522 -1.1246033l-3.0897522 -1.1245728z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 277.52954l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m587.0588 293.13934q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm2.8911743 4.46875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 319.3983l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m584.63763 339.50812q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm5.0302734 -0.03125q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm7.7869263 4.375q-1.65625 0 -2.515625 -0.859375q-0.84375 -0.859375 -0.84375 -2.546875l0 -4.703125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.78125q0 1.25 0.609375 1.875q0.609375 0.609375 1.78125 0.609375q1.171875 0 1.765625 -0.609375q0.609375 -0.625 0.609375 -1.875l0 -4.78125q0 -0.234375 0.140625 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 4.703125q0 1.671875 -0.859375 2.546875q-0.859375 0.859375 -2.5 0.859375z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m219.98688 334.92584l64.12598 -0.03149414" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m219.98688 334.92584l60.698914 -0.029815674" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m280.68576 334.89603l-1.1240234 1.1251526l3.0892334 -1.1260986l-3.090332 -1.1230774z" fill-rule="evenodd"/><path fill="#d9ead3" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m413.02625 141.28871l20.53543 0l0 20.53543l-20.53543 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m437.52493 135.68242l73.763794 0l0 31.748032l-73.763794 0z" fill-rule="evenodd"/><path fill="#000000" d="m448.0718 156.20241q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm8.3211975 -5.140625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.767517 -5.28125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm10.15921 0.75q-0.234375 0 -0.375 -0.140625q-0.140625 -0.140625 -0.140625 -0.359375l0 -7.1875l-2.578125 0q-0.21875 0 -0.34375 -0.109375q-0.109375 -0.109375 -0.109375 -0.3125q0 -0.203125 0.109375 -0.296875q0.125 -0.109375 0.34375 -0.109375l6.15625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.109375 0.109375 -0.328125 0.109375l-2.578125 0l0 7.1875q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625zm8.691681 -5.71875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-3.375 7.28125q-0.0625 0.125 -0.171875 0.1875q-0.109375 0.078125 -0.234375 0.078125q-0.1875 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.296875q0 -0.09375 0.046875 -0.1875l0.84375 -1.8125l-2.375 -5.140625q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm4.902405 -0.328125q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm8.76532 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#f4cccc" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m519.9029 141.28871l20.5354 0l0 20.53543l-20.5354 0z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m544.40155 135.68242l100.0 0l0 31.748032l-100.0 0z" fill-rule="evenodd"/><path fill="#000000" d="m554.9328 156.26491q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm5.3845215 -6.046875q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.456726 -1.703125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm4.248535 1.71875q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.47876 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm4.283142 -5.265625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.782898 0q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm4.7008057 6.046875q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm6.029297 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm5.830017 -5.265625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 5.078125q0 0.203125 -0.125 0.34375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.328125l0 -0.609375q-0.28125 0.53125 -0.78125 0.8125q-0.5 0.265625 -1.125 0.265625q-1.03125 0 -1.5625 -0.578125q-0.53125 -0.578125 -0.53125 -1.71875l0 -3.265625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.34375l0 3.234375q0 0.78125 0.3125 1.15625q0.3125 0.359375 0.984375 0.359375q0.765625 0 1.234375 -0.5q0.46875 -0.5 0.46875 -1.3125l0 -2.9375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625zm5.1851807 0q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm5.861023 4.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375z" fill-rule="nonzero"/><path fill="#d9ead3" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874912 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m67.27695 264.03653q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -3.4375l-5.062496 0l0 3.4375q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.234375 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.296875l5.062496 0l0 -3.296875q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.375 -0.140625zm3.0648193 8.515625q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm6.5711823 0.90625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm9.0746765 -5.359375q0.8125 0 1.40625 0.34375q0.609375 0.328125 0.9375 0.9375q0.328125 0.59375 0.328125 1.390625q0 0.78125 -0.359375 1.40625q-0.359375 0.625 -1.0 0.96875q-0.640625 0.328125 -1.484375 0.328125q-0.734375 0 -1.453125 -0.25q-0.703125 -0.265625 -1.1875 -0.734375q-0.203125 -0.171875 -0.203125 -0.40625q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.234375 -0.125q0.171875 0 0.34375 0.140625q0.515625 0.4375 1.046875 0.640625q0.53125 0.203125 1.109375 0.203125q0.890625 0 1.390625 -0.5q0.5 -0.5 0.5 -1.359375q0 -0.84375 -0.5 -1.359375q-0.5 -0.515625 -1.359375 -0.515625q-1.09375 0 -1.78125 0.84375q-0.15625 0.171875 -0.40625 0.171875q-0.15625 0 -0.28125 -0.09375q-0.109375 -0.109375 -0.109375 -0.296875l0 -4.125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125l4.21875 0q0.21875 0 0.34375 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.125 0.109375 -0.34375 0.109375l-3.734375 0l0 3.015625q0.34375 -0.328125 0.78125 -0.5q0.453125 -0.171875 0.984375 -0.171875z" fill-rule="nonzero"/><path fill="#d9ead3" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m190.14 134.76706l87.49608 0l0 30.992126l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m215.10997 150.37688q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.375 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84375 0 1.5625 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.15625 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.109375 0 2.03125 -0.328125l0 -2.578125l-1.75 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.234375 0zm5.1568146 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2028046 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035553 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461807 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480301 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1085 252.53609l87.49608 0l0 30.992142l-87.49608 0z" fill-rule="evenodd"/><path fill="#000000" d="m260.00964 265.61465q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.375 0q0.203125 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.203125 -0.125 0.3125q-0.125 0.109375 -0.328125 0.109375l-3.90625 0l0 2.90625l3.65625 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.3125q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.65625 0l0 3.453125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625zm8.9496765 -6.03125q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.767273 6.046875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.535065 -0.046875q0.203125 0 0.296875 0.109375q0.109375 0.09375 0.109375 0.265625q0 0.1875 -0.109375 0.296875q-0.09375 0.09375 -0.296875 0.09375l-4.203125 0q-0.203125 0 -0.34375 -0.125q-0.125 -0.125 -0.125 -0.3125q0 -0.1875 0.140625 -0.359375l3.546875 -4.28125l-3.28125 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l4.0625 0q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.3125q0 0.1875 -0.140625 0.359375l-3.5625 4.28125l3.421875 0zm6.2547913 -0.59375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.8396606 -0.75q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125z" fill-rule="nonzero"/><path fill="#000000" d="m258.07846 275.1459q0.1875 0 0.296875 0.109375q0.109375 0.109375 0.109375 0.296875l0 2.984375q0 0.296875 -0.09375 0.4375q-0.078125 0.140625 -0.328125 0.234375q-0.46875 0.203125 -1.15625 0.328125q-0.6875 0.109375 -1.3749847 0.109375q-1.25 0 -2.171875 -0.515625q-0.90625 -0.515625 -1.390625 -1.484375q-0.484375 -0.96875 -0.484375 -2.328125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.375 -1.5q0.90625 -0.53125 2.125 -0.53125q0.84373474 0 1.5624847 0.265625q0.71875 0.25 1.203125 0.734375q0.21875 0.203125 0.21875 0.421875q0 0.171875 -0.109375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.140625 0 -0.328125 -0.140625q-0.625 -0.484375 -1.140625 -0.671875q-0.5 -0.1875 -1.1562347 -0.1875q-1.4375 0 -2.203125 0.90625q-0.75 0.890625 -0.75 2.578125q0 1.71875 0.765625 2.609375q0.78125 0.890625 2.28125 0.890625q1.1093597 0 2.0312347 -0.328125l0 -2.578125l-1.7499847 0q-0.203125 0 -0.328125 -0.109375q-0.125 -0.109375 -0.125 -0.265625q0 -0.1875 0.125 -0.28125q0.125 -0.109375 0.328125 -0.109375l2.2343597 0zm5.15683 -1.5625q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.3131714 -5.296875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm7.2027893 -5.265625q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm4.5035706 5.984375q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.34375 0q2.03125 0 3.140625 1.09375q1.109375 1.09375 1.109375 3.125q0 2.03125 -1.125 3.140625q-1.109375 1.09375 -3.125 1.09375l-2.34375 0zm2.28125 -0.84375q3.28125 0 3.28125 -3.390625q0 -3.390625 -3.28125 -3.390625l-1.796875 0l0 6.78125l1.796875 0zm10.461792 -0.515625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.480316 -2.453125q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 232.16667l0 20.377945" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 232.16667l0 16.950867" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.85565 249.11754l-1.1246033 -1.124588l1.1246033 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#f4cccc" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m31.874016 68.3563l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m58.725647 87.669235q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.9706573 -6.984375q-0.640625 0.046875 -0.96875 0.40625q-0.3125 0.34375 -0.3125 1.046875l0 0.390625l1.328125 0q0.203125 0 0.3125 0.109375q0.109375 0.109375 0.109375 0.28125q0 0.1875 -0.109375 0.28125q-0.109375 0.09375 -0.3125 0.09375l-1.328125 0l0 4.65625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -4.65625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -0.21875q0 -1.078125 0.53125 -1.6875q0.546875 -0.625 1.5625 -0.703125l0.3125 -0.015625q0.3125 -0.03125 0.453125 0.0625q0.140625 0.078125 0.140625 0.296875q0 0.34375 -0.421875 0.390625l-0.3125 0.03125zm1.8266602 7.75q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm8.498016 -0.8125q0.171875 0.15625 0.171875 0.359375q0 0.15625 -0.140625 0.296875q-0.140625 0.140625 -0.3125 0.140625q-0.15625 0 -0.328125 -0.140625l-4.484375 -3.921875l0 3.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 3.4375l4.28125 -3.796875q0.125 -0.140625 0.3125 -0.140625q0.171875 0 0.296875 0.140625q0.140625 0.140625 0.140625 0.3125q0 0.171875 -0.15625 0.328125l-3.875 3.421875l4.09375 3.5625zm5.8329315 -0.609375q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.792801 -0.734375q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625zm3.720398 -0.015625q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm6.3444214 0.765625q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#f4cccc" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49081 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m152.20152 88.37367q-0.234375 0 -0.375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -7.5q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l4.484375 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.015625 0l0 2.9375l3.78125 0q0.21875 0 0.328125 0.109375q0.125 0.109375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-3.78125 0l0 3.078125l4.015625 0q0.21875 0 0.328125 0.109375q0.125 0.09375 0.125 0.296875q0 0.1875 -0.125 0.296875q-0.109375 0.109375 -0.328125 0.109375l-4.484375 0zm8.31218 0.078125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125zm6.4787903 -0.78125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm1.8769073 0.765625q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125zm6.0990753 0q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm7.0631714 -0.015625q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.8144073 0.78125q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm7.1287994 -5.25q0.5 -0.03125 0.5 0.40625q0 0.203125 -0.109375 0.3125q-0.109375 0.109375 -0.375 0.140625l-0.359375 0.03125q-0.796875 0.078125 -1.1875 0.578125q-0.390625 0.484375 -0.390625 1.15625l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.140625 -0.359375q0.140625 -0.125 0.34375 -0.125q0.1875 0 0.3125 0.125q0.140625 0.125 0.140625 0.34375l0 0.671875q0.25 -0.53125 0.71875 -0.796875q0.46875 -0.28125 1.0625 -0.328125l0.171875 -0.015625z" fill-rule="nonzero"/><path fill="#f4cccc" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.1076 68.35761l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m269.00754 88.46742q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm5.0446777 -0.03125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125zm2.784027 0q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm8.799652 1.234375q1.9375 0 1.9375 2.3125l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.328125 0.125q-0.21875 0 -0.359375 -0.125q-0.140625 -0.125 -0.140625 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.890625 -0.359375q-0.734375 0 -1.15625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.8125 -0.296875 -1.171875q-0.28125 -0.359375 -0.90625 -0.359375q-0.71875 0 -1.140625 0.5q-0.421875 0.484375 -0.421875 1.328125l0 2.921875q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.359375 -0.140625q0.203125 0 0.328125 0.125q0.140625 0.125 0.140625 0.34375l0 0.578125q0.265625 -0.515625 0.734375 -0.78125q0.46875 -0.28125 1.078125 -0.28125q1.375 0 1.78125 1.140625q0.265625 -0.515625 0.78125 -0.828125q0.515625 -0.3125 1.171875 -0.3125z" fill-rule="nonzero"/><path fill="#d9ead3" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m282.5035 134.76706l87.49606 0l0 30.992126l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m297.8283 154.87688q-1.1875 0 -2.0625 -0.515625q-0.875 -0.53125 -1.359375 -1.5q-0.46875 -0.984375 -0.46875 -2.3125q0 -1.328125 0.46875 -2.296875q0.484375 -0.984375 1.359375 -1.5q0.875 -0.53125 2.0625 -0.53125q0.8125 0 1.515625 0.265625q0.71875 0.25 1.25 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.21875 0.125q-0.15625 0 -0.359375 -0.140625q-0.609375 -0.46875 -1.109375 -0.65625q-0.5 -0.203125 -1.140625 -0.203125q-1.390625 0 -2.140625 0.90625q-0.75 0.90625 -0.75 2.578125q0 1.671875 0.75 2.578125q0.75 0.90625 2.140625 0.90625q0.640625 0 1.140625 -0.1875q0.5 -0.1875 1.109375 -0.671875q0.203125 -0.125 0.359375 -0.125q0.125 0 0.21875 0.125q0.09375 0.109375 0.09375 0.296875q0 0.234375 -0.1875 0.40625q-0.53125 0.484375 -1.25 0.75q-0.703125 0.25 -1.515625 0.25zm7.358429 -6.078125q1.03125 0 1.546875 0.578125q0.53125 0.578125 0.53125 1.734375l0 3.25q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.21875q0 -0.78125 -0.328125 -1.15625q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.203125 0.125 -0.328125q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.125q0.125 0.125 0.125 0.34375l0 3.140625q0.28125 -0.53125 0.796875 -0.796875q0.515625 -0.28125 1.1875 -0.28125zm8.37854 4.625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm6.308441 5.3125q-0.8125 0 -1.453125 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.6875 -0.34375 -1.578125q0 -0.90625 0.359375 -1.59375q0.359375 -0.703125 0.984375 -1.078125q0.640625 -0.390625 1.46875 -0.390625q0.453125 0 0.90625 0.125q0.453125 0.125 0.78125 0.359375q0.21875 0.140625 0.3125 0.28125q0.09375 0.140625 0.09375 0.3125q0 0.171875 -0.09375 0.28125q-0.09375 0.09375 -0.234375 0.09375q-0.078125 0 -0.1875 -0.046875q-0.09375 -0.046875 -0.15625 -0.09375q-0.0625 -0.046875 -0.09375 -0.0625q-0.3125 -0.203125 -0.59375 -0.3125q-0.28125 -0.125 -0.6875 -0.125q-0.875 0 -1.359375 0.59375q-0.484375 0.59375 -0.484375 1.65625q0 1.046875 0.484375 1.625q0.484375 0.578125 1.359375 0.578125q0.40625 0 0.703125 -0.109375q0.296875 -0.125 0.59375 -0.328125q0.140625 -0.09375 0.25 -0.15625q0.125 -0.0625 0.203125 -0.0625q0.140625 0 0.21875 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.15625 -0.09375 0.28125q-0.078125 0.125 -0.296875 0.28125q-0.34375 0.234375 -0.8125 0.375q-0.46875 0.125 -0.953125 0.125zm7.998047 -0.84375q0.203125 0.171875 0.203125 0.375q0 0.1875 -0.125 0.328125q-0.125 0.125 -0.3125 0.125q-0.15625 0 -0.328125 -0.140625l-3.125 -2.703125l0 2.359375q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 4.875l2.859375 -2.625q0.15625 -0.140625 0.328125 -0.140625q0.1875 0 0.3125 0.140625q0.140625 0.125 0.140625 0.296875q0 0.203125 -0.171875 0.359375l-2.375 2.109375l2.59375 2.265625zm4.2812805 -5.21875q0.765625 0 1.34375 0.390625q0.59375 0.375 0.921875 1.0625q0.328125 0.6875 0.328125 1.609375q0 0.90625 -0.328125 1.59375q-0.328125 0.671875 -0.90625 1.046875q-0.578125 0.359375 -1.359375 0.359375q-0.6875 0 -1.203125 -0.296875q-0.5 -0.296875 -0.765625 -0.84375l0 2.8125q0 0.21875 -0.125 0.34375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.140625q-0.125 -0.125 -0.125 -0.328125l0 -7.234375q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.125 0.125 0.125 0.34375l0 0.640625q0.265625 -0.546875 0.765625 -0.84375q0.515625 -0.296875 1.203125 -0.296875zm-0.203125 5.265625q0.859375 0 1.328125 -0.578125q0.46875 -0.578125 0.46875 -1.625q0 -1.0625 -0.46875 -1.65625q-0.46875 -0.59375 -1.328125 -0.59375q-0.84375 0 -1.3125 0.578125q-0.453125 0.578125 -0.453125 1.640625q0 1.0625 0.453125 1.65625q0.46875 0.578125 1.3125 0.578125zm6.67157 0.796875q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm4.722534 0.78125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.234375 0.125 -0.359375q0.140625 -0.125 0.359375 -0.125q0.21875 0 0.34375 0.125q0.140625 0.125 0.140625 0.359375l0 5.0625q0 0.234375 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125zm0 -7.28125q-0.296875 0 -0.484375 -0.171875q-0.171875 -0.171875 -0.171875 -0.453125q0 -0.25 0.171875 -0.421875q0.1875 -0.171875 0.484375 -0.171875q0.28125 0 0.453125 0.171875q0.1875 0.171875 0.1875 0.421875q0 0.28125 -0.1875 0.453125q-0.171875 0.171875 -0.453125 0.171875zm5.237152 1.234375q2.09375 0 2.09375 2.3125l0 3.25q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -3.1875q0 -0.8125 -0.328125 -1.1875q-0.3125 -0.375 -1.0 -0.375q-0.8125 0 -1.296875 0.5q-0.46875 0.484375 -0.46875 1.328125l0 2.921875q0 0.234375 -0.125 0.359375q-0.125 0.125 -0.359375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -5.0625q0 -0.21875 0.125 -0.34375q0.125 -0.140625 0.359375 -0.140625q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.328125l0 0.609375q0.28125 -0.53125 0.796875 -0.8125q0.53125 -0.28125 1.1875 -0.28125zm6.5660706 5.28125q0.421875 0.03125 0.421875 0.375q0 0.203125 -0.15625 0.3125q-0.140625 0.09375 -0.4375 0.078125l-0.328125 -0.03125q-0.953125 -0.0625 -1.421875 -0.5625q-0.453125 -0.515625 -0.453125 -1.53125l0 -3.015625l-0.796875 0q-0.203125 0 -0.328125 -0.09375q-0.109375 -0.109375 -0.109375 -0.28125q0 -0.171875 0.109375 -0.28125q0.125 -0.109375 0.328125 -0.109375l0.796875 0l0 -1.359375q0 -0.21875 0.125 -0.34375q0.140625 -0.140625 0.375 -0.140625q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.34375l0 1.359375l1.328125 0q0.1875 0 0.296875 0.109375q0.125 0.109375 0.125 0.28125q0 0.171875 -0.125 0.28125q-0.109375 0.09375 -0.296875 0.09375l-1.328125 0l0 3.0625q0 0.65625 0.265625 0.953125q0.265625 0.296875 0.8125 0.328125l0.3125 0.03125zm3.361267 0.78125q-0.5625 0 -1.0625 -0.125q-0.5 -0.140625 -0.875 -0.375q-0.21875 -0.140625 -0.3125 -0.265625q-0.078125 -0.125 -0.078125 -0.3125q0 -0.15625 0.078125 -0.25q0.09375 -0.109375 0.234375 -0.109375q0.15625 0 0.421875 0.1875q0.359375 0.21875 0.71875 0.34375q0.359375 0.125 0.875 0.125q0.65625 0 1.015625 -0.21875q0.359375 -0.234375 0.359375 -0.671875q0 -0.265625 -0.140625 -0.421875q-0.125 -0.171875 -0.453125 -0.296875q-0.3125 -0.125 -0.9375 -0.25q-1.0625 -0.234375 -1.515625 -0.609375q-0.453125 -0.390625 -0.453125 -1.046875q0 -0.515625 0.28125 -0.90625q0.28125 -0.40625 0.796875 -0.625q0.515625 -0.234375 1.15625 -0.234375q0.46875 0 0.90625 0.125q0.4375 0.125 0.78125 0.34375q0.40625 0.296875 0.40625 0.609375q0 0.15625 -0.09375 0.265625q-0.09375 0.109375 -0.234375 0.109375q-0.140625 0 -0.4375 -0.203125q-0.328125 -0.21875 -0.625 -0.34375q-0.296875 -0.125 -0.75 -0.125q-0.5625 0 -0.90625 0.265625q-0.34375 0.25 -0.34375 0.671875q0 0.25 0.125 0.421875q0.125 0.15625 0.421875 0.28125q0.296875 0.125 0.84375 0.25q0.828125 0.1875 1.265625 0.40625q0.453125 0.203125 0.640625 0.515625q0.203125 0.3125 0.203125 0.796875q0 0.75 -0.640625 1.21875q-0.640625 0.453125 -1.671875 0.453125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l-42.960632 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m233.89502 131.35573l-1.124588 -1.124588l1.124588 3.0897675l1.1245728 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 17.724327" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85565 99.34974l0 17.70874l49.385803 0l0 14.297249" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m326.24146 131.35573l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#c9daf8" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 235.66077l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m579.47955 247.1612q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm8.868103 0q0.203125 0 0.328125 0.140625q0.125 0.125 0.125 0.359375l0 7.578125q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.359375 0.140625q-0.234375 0 -0.390625 -0.203125l-4.984375 -6.65625l0 6.359375q0 0.21875 -0.125 0.359375q-0.125 0.140625 -0.34375 0.140625q-0.21875 0 -0.34375 -0.140625q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.359375 -0.140625q0.234375 0 0.40625 0.203125l4.96875 6.65625l0 -6.359375q0 -0.234375 0.125 -0.359375q0.125 -0.140625 0.34375 -0.140625zm12.917175 7.953125q0.046875 0.09375 0.046875 0.203125q0 0.171875 -0.140625 0.296875q-0.140625 0.125 -0.328125 0.125q-0.296875 0 -0.421875 -0.296875l-0.84375 -1.9375l-4.53125 0l-0.859375 1.9375q-0.125 0.296875 -0.421875 0.296875q-0.1875 0 -0.34375 -0.125q-0.140625 -0.125 -0.140625 -0.3125q0 -0.09375 0.046875 -0.1875l3.4375 -7.640625q0.078125 -0.15625 0.21875 -0.234375q0.140625 -0.09375 0.3125 -0.09375q0.171875 0 0.3125 0.09375q0.15625 0.078125 0.21875 0.234375l3.4375 7.640625zm-5.859375 -2.421875l3.8125 0l-1.90625 -4.3125l-1.90625 4.3125zm7.78656 3.046875q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.546875q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.375 -0.125l2.84375 0q1.328125 0 2.0625 0.65625q0.75 0.640625 0.75 1.828125q0 1.1875 -0.75 1.84375q-0.734375 0.65625 -2.0625 0.65625l-2.359375 0l0 3.03125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625zm2.765625 -4.34375q1.9375 0 1.9375 -1.6875q0 -1.671875 -1.9375 -1.671875l-2.265625 0l0 3.359375l2.265625 0zm4.9744263 4.34375q-0.21875 0 -0.359375 -0.140625q-0.125 -0.140625 -0.125 -0.359375l0 -7.578125q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.234375 0 0.359375 0.140625q0.140625 0.125 0.140625 0.359375l0 7.578125q0 0.21875 -0.140625 0.359375q-0.125 0.140625 -0.359375 0.140625z" fill-rule="nonzero"/><path fill="#c9daf8" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m548.5407 193.79199l87.49603 0l0 30.992126l-87.49603 0z" fill-rule="evenodd"/><path fill="#000000" d="m589.5417 213.87056q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7480469 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875zm2.7479858 0q-0.28125 0 -0.484375 -0.1875q-0.1875 -0.1875 -0.1875 -0.484375q0 -0.296875 0.1875 -0.484375q0.203125 -0.203125 0.484375 -0.203125q0.28125 0 0.46875 0.203125q0.1875 0.1875 0.1875 0.484375q0 0.296875 -0.1875 0.484375q-0.1875 0.1875 -0.46875 0.1875z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m75.62294 283.52823l0 17.950958l100.62993 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62295 283.52823l0 17.950928l100.62992 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.25287 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m276.85654 283.52823l0 17.950958l-100.62991 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m276.85654 283.52823l0 17.950928l-100.62991 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.22662 316.00665l-1.124588 -1.1246033l1.124588 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 0.06298828l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 0.06298828l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 334.95734l-1.1245728 1.1246033l3.0897827 -1.1246033l-3.0897827 -1.1245728z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -41.858246l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -41.858246l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 293.0361l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.1246033z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -83.74802l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -83.74802l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 251.14633l-1.1245728 1.1245728l3.0897827 -1.1245728l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m500.5223 334.89435l24.009003 0l0 -125.60629l24.022522 0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m500.5223 334.89435l24.009003 0l0 -125.60629l20.595398 0" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m545.1267 209.28806l-1.1245728 1.124588l3.0897827 -1.124588l-3.0897827 -1.124588z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m233.88803 165.75919l0 17.70752l42.960632 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m233.88805 165.75919l0 17.70752l42.960617 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.84866 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 17.694061" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m326.25156 165.75919l0 17.70752l-49.385834 0l0 14.266968" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m276.86572 197.73367l-1.1245728 -1.124588l1.1245728 3.0897675l1.1246033 -3.0897675z" fill-rule="evenodd"/><path fill="#d9ead3" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m132.49171 252.53609l87.49606 0l0 30.992142l-87.49606 0z" fill-rule="evenodd"/><path fill="#000000" d="m146.9475 272.6459q-0.90625 0 -1.734375 -0.265625q-0.8125 -0.265625 -1.3125 -0.734375q-0.171875 -0.15625 -0.171875 -0.40625q0 -0.171875 0.09375 -0.296875q0.09375 -0.125 0.234375 -0.125q0.15625 0 0.328125 0.125q1.109375 0.859375 2.546875 0.859375q1.03125 0 1.578125 -0.390625q0.5625 -0.390625 0.5625 -1.125q0 -0.421875 -0.265625 -0.671875q-0.265625 -0.265625 -0.703125 -0.421875q-0.4375 -0.15625 -1.15625 -0.328125q-0.984375 -0.21875 -1.625 -0.46875q-0.625 -0.265625 -1.015625 -0.734375q-0.390625 -0.46875 -0.390625 -1.21875q0 -0.71875 0.390625 -1.265625q0.390625 -0.5625 1.09375 -0.875q0.703125 -0.3125 1.59375 -0.3125q0.84375 0 1.5625 0.265625q0.734375 0.25 1.234375 0.734375q0.1875 0.1875 0.1875 0.421875q0 0.171875 -0.09375 0.296875q-0.09375 0.125 -0.234375 0.125q-0.125 0 -0.34375 -0.140625q-0.59375 -0.46875 -1.09375 -0.65625q-0.5 -0.203125 -1.21875 -0.203125q-0.984375 0 -1.546875 0.421875q-0.546875 0.40625 -0.546875 1.15625q0 0.625 0.484375 0.953125q0.484375 0.3125 1.5 0.5625q1.09375 0.25 1.71875 0.484375q0.625 0.21875 1.03125 0.671875q0.421875 0.4375 0.421875 1.171875q0 0.71875 -0.390625 1.265625q-0.390625 0.53125 -1.109375 0.828125q-0.703125 0.296875 -1.609375 0.296875zm6.9353027 -6.078125q2.203125 0 2.203125 2.296875l0 3.265625q0 0.21875 -0.125 0.359375q-0.125 0.125 -0.34375 0.125q-0.21875 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.578125q-0.21875 0.515625 -0.6875 0.796875q-0.46875 0.28125 -1.078125 0.28125q-0.5625 0 -1.046875 -0.21875q-0.46875 -0.234375 -0.75 -0.640625q-0.265625 -0.40625 -0.265625 -0.90625q0 -0.65625 0.328125 -1.015625q0.34375 -0.375 1.109375 -0.53125q0.765625 -0.15625 2.125 -0.15625l0.265625 0l0 -0.40625q0 -0.71875 -0.296875 -1.046875q-0.28125 -0.34375 -0.953125 -0.34375q-0.8125 0 -1.65625 0.453125q-0.3125 0.203125 -0.453125 0.203125q-0.140625 0 -0.234375 -0.109375q-0.09375 -0.109375 -0.09375 -0.28125q0 -0.171875 0.09375 -0.296875q0.109375 -0.125 0.328125 -0.25q0.421875 -0.25 0.953125 -0.375q0.546875 -0.140625 1.0625 -0.140625zm-0.390625 5.296875q0.71875 0 1.171875 -0.484375q0.46875 -0.484375 0.46875 -1.25l0 -0.34375l-0.21875 0q-1.046875 0 -1.609375 0.09375q-0.546875 0.078125 -0.78125 0.296875q-0.234375 0.203125 -0.234375 0.609375q0 0.46875 0.34375 0.78125q0.34375 0.296875 0.859375 0.296875zm8.578796 -4.96875q0.140625 -0.296875 0.421875 -0.296875q0.1875 0 0.328125 0.125q0.140625 0.109375 0.140625 0.296875q0 0.109375 -0.046875 0.1875l-2.34375 5.046875q-0.0625 0.15625 -0.21875 0.25q-0.140625 0.078125 -0.3125 0.078125q-0.15625 0 -0.296875 -0.078125q-0.140625 -0.09375 -0.21875 -0.25l-2.328125 -5.046875q-0.046875 -0.078125 -0.046875 -0.171875q0 -0.1875 0.15625 -0.3125q0.15625 -0.140625 0.359375 -0.140625q0.109375 0 0.21875 0.078125q0.125 0.078125 0.1875 0.203125l2.0 4.5l2.0 -4.46875zm6.480545 4.296875q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm8.589676 -3.28125q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm12.202805 -7.796875q0.21875 0 0.34375 0.140625q0.125 0.125 0.125 0.359375l0 7.59375q0 0.21875 -0.125 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.328125 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -6.125l-2.59375 4.984375q-0.171875 0.34375 -0.5 0.34375q-0.3125 0 -0.484375 -0.34375l-2.625 -4.921875l0 6.0625q0 0.21875 -0.109375 0.359375q-0.109375 0.125 -0.328125 0.125q-0.21875 0 -0.34375 -0.125q-0.109375 -0.140625 -0.109375 -0.359375l0 -7.59375q0 -0.234375 0.125 -0.359375q0.140625 -0.140625 0.359375 -0.140625q0.3125 0 0.484375 0.34375l3.046875 5.84375l3.015625 -5.84375q0.09375 -0.1875 0.203125 -0.265625q0.125 -0.078125 0.28125 -0.078125zm4.8576965 8.59375q-0.828125 0 -1.46875 -0.359375q-0.625 -0.375 -0.96875 -1.0625q-0.34375 -0.703125 -0.34375 -1.609375q0 -0.90625 0.34375 -1.59375q0.34375 -0.703125 0.96875 -1.0625q0.640625 -0.375 1.46875 -0.375q0.828125 0 1.453125 0.375q0.640625 0.359375 0.984375 1.0625q0.34375 0.6875 0.34375 1.59375q0 0.90625 -0.34375 1.609375q-0.34375 0.6875 -0.984375 1.0625q-0.625 0.359375 -1.453125 0.359375zm0 -0.796875q0.859375 0 1.3125 -0.5625q0.46875 -0.578125 0.46875 -1.671875q0 -1.0625 -0.46875 -1.640625q-0.46875 -0.59375 -1.3125 -0.59375q-0.859375 0 -1.328125 0.59375q-0.46875 0.578125 -0.46875 1.640625q0 1.078125 0.453125 1.65625q0.46875 0.578125 1.34375 0.578125zm8.925674 -7.796875q0.21875 0 0.34375 0.140625q0.140625 0.125 0.140625 0.328125l0 7.625q0 0.21875 -0.140625 0.359375q-0.125 0.125 -0.34375 0.125q-0.234375 0 -0.359375 -0.125q-0.125 -0.140625 -0.125 -0.359375l0 -0.640625q-0.265625 0.546875 -0.78125 0.84375q-0.5 0.296875 -1.1875 0.296875q-0.765625 0 -1.359375 -0.375q-0.578125 -0.390625 -0.90625 -1.078125q-0.328125 -0.6875 -0.328125 -1.59375q0 -0.90625 0.328125 -1.59375q0.328125 -0.6875 0.90625 -1.046875q0.59375 -0.375 1.359375 -0.375q0.6875 0 1.1875 0.296875q0.515625 0.296875 0.78125 0.84375l0 -3.203125q0 -0.21875 0.125 -0.34375q0.125 -0.125 0.359375 -0.125zm-2.25 7.796875q0.84375 0 1.296875 -0.578125q0.46875 -0.59375 0.46875 -1.65625q0 -1.0625 -0.46875 -1.640625q-0.453125 -0.578125 -1.296875 -0.578125q-0.859375 0 -1.34375 0.578125q-0.46875 0.578125 -0.46875 1.625q0 1.0625 0.46875 1.65625q0.484375 0.59375 1.34375 0.59375zm9.06218 -0.640625q0.140625 0 0.234375 0.109375q0.09375 0.109375 0.09375 0.28125q0 0.296875 -0.421875 0.546875q-0.4375 0.25 -0.921875 0.375q-0.46875 0.125 -0.921875 0.125q-1.359375 0 -2.15625 -0.796875q-0.78125 -0.8125 -0.78125 -2.21875q0 -0.90625 0.34375 -1.59375q0.359375 -0.6875 0.984375 -1.0625q0.640625 -0.390625 1.4375 -0.390625q1.140625 0 1.8125 0.75q0.671875 0.734375 0.671875 2.0q0 0.25 -0.09375 0.359375q-0.09375 0.109375 -0.3125 0.109375l-3.859375 0q0.09375 2.0625 1.953125 2.0625q0.46875 0 0.796875 -0.125q0.34375 -0.125 0.71875 -0.34375q0.3125 -0.1875 0.421875 -0.1875zm-2.09375 -3.875q-0.765625 0 -1.234375 0.484375q-0.46875 0.484375 -0.546875 1.359375l3.390625 0q-0.015625 -0.890625 -0.4375 -1.359375q-0.421875 -0.484375 -1.171875 -0.484375zm4.386551 5.296875q-0.21875 0 -0.359375 -0.125q-0.125 -0.125 -0.125 -0.359375l0 -7.625q0 -0.21875 0.125 -0.34375q0.140625 -0.125 0.359375 -0.125q0.203125 0 0.34375 0.125q0.140625 0.125 0.140625 0.34375l0 7.625q0 0.234375 -0.140625 0.359375q-0.140625 0.125 -0.34375 0.125z" fill-rule="nonzero"/><path fill="#000000" fill-opacity="0.0" d="m176.23885 99.34974l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23885 99.34974l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.23885 249.1195l-1.124588 -1.124588l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m176.23975 283.52823l0 17.950958l0.06298828 0l0 17.954529" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m176.23975 283.52823l0 17.950928l0.06298828 0l0 14.527496" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m176.30273 316.00665l-1.1245728 -1.1246033l1.1245728 3.0897827l1.124588 -3.0897827z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m75.62205 99.34843l0 153.19684" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m75.62205 99.34843l0 149.76978" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m75.62205 249.1182l-1.1245804 -1.124588l1.1245804 3.0897675l1.1245804 -3.0897675z" fill-rule="evenodd"/><path fill="#000000" fill-opacity="0.0" d="m99.50131 100.0l0 76.0l54.992126 0l0 76.0" fill-rule="evenodd"/><path stroke="#000000" stroke-width="1.0" stroke-linejoin="round" stroke-linecap="butt" d="m99.50131 100.0l0 76.0l54.992126 0l0 72.57292" fill-rule="evenodd"/><path fill="#000000" stroke="#000000" stroke-width="1.0" stroke-linecap="butt" d="m154.49344 248.5729l-1.124588 -1.1245728l1.124588 3.0897675l1.124588 -3.0897675z" fill-rule="evenodd"/></g></svg> \ No newline at end of file
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 99f4a7d8f6..fdd0632451 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -142,7 +142,6 @@ DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
-DECLARE_GRAPH_TRANSFORMATION(QuantizeWeights)
DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
@@ -178,9 +177,10 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveSpaceToBatchNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveBatchToSpaceNDAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolvePadV2Attributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
-DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveReduceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveReshapeAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantPack)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform)
@@ -217,12 +217,6 @@ class PropagateDefaultMinMax : public GraphTransformation {
std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
};
-class ResolveReshapeAttributes : public GraphTransformation {
- public:
- bool Run(Model* model, std::size_t op_index) override;
- const char* Name() const override { return "ResolveReshapeAttributes"; }
-};
-
class RemoveTrivialReshape : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
index 502de88f7c..3114fa93e8 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -63,6 +63,25 @@ bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
return true;
}
+bool HardcodeInputMinMaxFromOutput(Model* model, Operator* op) {
+ auto& input = model->GetArray(op->inputs[0]);
+ if (input.minmax) {
+ const auto* minmax = input.minmax.get();
+ if (minmax) {
+ return false;
+ }
+ }
+ auto& output = model->GetArray(op->outputs[0]);
+ if (output.minmax) {
+ const auto* minmax = model->GetArray(op->outputs[0]).minmax.get();
+ if (minmax) {
+ input.GetOrCreateMinMax() = *minmax;
+ return true;
+ }
+ }
+ return false;
+}
+
bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
// Do not early return if the output already has min/max:
// we may still need to adjust the inputs min/max.
@@ -366,6 +385,16 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForL2Normalization(model, op);
break;
+ case OperatorType::kRelu:
+ // For any normalization other than batch norm, the quantizations ranges
+ // before and after relu are expected to be known. Having a quantization
+ // op before relu would reduce the number of bits of precision for the
+ // activation in half. So we deduce the range before relu from that after
+ // the relu. This would eliminate the need for two fake quantization nodes
+ // and would not reduce the bits of precision available for activation.
+ changed = HardcodeInputMinMaxFromOutput(model, op);
+ break;
+
case OperatorType::kConcatenation:
changed = HardcodeMinMaxForConcatenation(model, op);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
index c8310161cb..323eefcd3a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -227,6 +227,15 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
ArrayDataType::kFloat;
break;
}
+ case OperatorType::kUnpack: {
+ CHECK_EQ(op->inputs.size(), 1);
+ const int output_size = op->outputs.size();
+ for (int i = 0; i < output_size; ++i) {
+ model->GetArray(op->outputs[i]).data_type =
+ model->GetArray(op->inputs[0]).data_type;
+ }
+ break;
+ }
default: {
// These operators produce outputs with the same type as their 1st input
CHECK_GT(op->inputs.size(), 0);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index 91e290439a..f103bb94ae 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -539,6 +539,8 @@ bool KeepDims(const Operator& op) {
return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
case OperatorType::kMean:
return static_cast<const MeanOperator&>(op).keep_dims;
+ case OperatorType::kAny:
+ return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
default:
LOG(FATAL) << "Not a reduction operator!";
return false;
@@ -559,26 +561,38 @@ void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
const bool keep_dims = KeepDims(*op);
if (op->inputs.size() == 2) {
// There is a reduction_indices input.
- const auto& reduction_array = model->GetArray(op->inputs[1]);
- if (!reduction_array.buffer) {
+ const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
+ if (!reduction_indices_array.buffer) {
return;
}
- CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
- const auto& reduction_array_vals =
- reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
- auto& output_dims = *output_array.mutable_shape()->mutable_dims();
- output_dims.clear();
- for (int i = 0; i < input_shape.dimensions_count(); i++) {
- bool is_reduction_dim = false;
- for (int r : reduction_array_vals) {
- if (i == r) {
- is_reduction_dim = true;
- }
+ CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
+
+ int input_rank = input_shape.dimensions_count();
+ std::set<int32> true_indices;
+ const auto& reduction_indices =
+ reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < reduction_indices.size(); ++i) {
+ const int32 reduction_index = reduction_indices[i];
+ if (reduction_index < -input_rank || reduction_index >= input_rank) {
+ CHECK(false) << "Invalid reduction dimension " << reduction_index
+ << " for input with " << input_rank << " dimensions";
}
- if (!is_reduction_dim) {
- output_dims.push_back(input_shape.dims(i));
- } else if (keep_dims) {
- output_dims.push_back(1);
+ int32 wrapped_index = reduction_index;
+ if (wrapped_index < 0) {
+ wrapped_index += input_rank;
+ }
+ true_indices.insert(wrapped_index);
+ }
+
+ auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
+ mutable_dims->clear();
+ for (int i = 0; i < input_rank; ++i) {
+ if (true_indices.count(i) > 0) {
+ if (keep_dims) {
+ mutable_dims->emplace_back(1);
+ }
+ } else {
+ mutable_dims->emplace_back(input_shape.dims(i));
}
}
} else {
@@ -1300,12 +1314,16 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
// Compute output shape
for (int axis = 0; axis < num_input_axes; ++axis) {
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op->begin_mask, op->end_mask, op->shrink_axis_mask,
+ op->start_indices, op->stop_indices, op->strides);
int start_index = tflite::strided_slice::StartForAxis(
- op->begin_mask, op->start_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis, start_index);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
+
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);
@@ -1515,65 +1533,6 @@ void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
}
}
-void ProcessAnyOperator(Model* model, AnyOperator* op) {
- CHECK_EQ(op->inputs.size(), 2);
- CHECK_EQ(op->outputs.size(), 1);
-
- auto& output_array = model->GetArray(op->outputs[0]);
- if (output_array.has_shape()) {
- // We have already run.
- return;
- }
-
- const auto& input_array = model->GetArray(op->inputs[0]);
- if (!input_array.has_shape()) {
- // Yield until input dims have been resolved.
- return;
- }
- const auto& input_shape = input_array.shape();
-
- auto& reduction_indices_array = model->GetArray(op->inputs[1]);
- if (!reduction_indices_array.has_shape()) {
- // Yield until reduction indices shape been resolved.
- return;
- }
- if (!reduction_indices_array.buffer) {
- // Yield until the reduction indices are constant.
- return;
- }
- CHECK(reduction_indices_array.data_type == ArrayDataType::kInt32)
- << "Any reduction input must be int32";
-
- int input_rank = input_shape.dimensions_count();
- std::set<int32> true_indices;
- const auto& reduction_indices =
- reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
- for (int i = 0; i < reduction_indices.size(); ++i) {
- const int32 reduction_index = reduction_indices[i];
- if (reduction_index < -input_rank || reduction_index >= input_rank) {
- CHECK(false) << "Invalid reduction dimension " << reduction_index
- << " for input with " << input_rank << " dimensions";
- }
- int32 wrapped_index = reduction_index;
- if (wrapped_index < 0) {
- wrapped_index += input_rank;
- }
- true_indices.insert(wrapped_index);
- }
-
- auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
- mutable_dims->clear();
- for (int i = 0; i < input_rank; ++i) {
- if (true_indices.count(i) > 0) {
- if (op->keep_dims) {
- mutable_dims->emplace_back(1);
- }
- } else {
- mutable_dims->emplace_back(input_shape.dims(i));
- }
- }
-}
-
void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
CHECK_EQ(op->inputs.size(), 4);
CHECK_EQ(op->outputs.size(), 1);
@@ -1629,6 +1588,32 @@ void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
}
}
+void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ output_dims.reserve(input_dims.size() - 1);
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (i != op->axis) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
+ for (const string& output_name : op->outputs) {
+ auto& output_array = model->GetArray(output_name);
+ if (output_array.has_shape()) {
+ return;
+ }
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+ }
+}
+
} // namespace
bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
@@ -1743,6 +1728,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kSum:
case OperatorType::kReduceProd:
case OperatorType::kMean:
+ case OperatorType::kAny:
ProcessTensorFlowReductionOperator(model, op);
break;
case OperatorType::kSelect:
@@ -1874,12 +1860,13 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
case OperatorType::kTile:
ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
break;
- case OperatorType::kAny:
- ProcessAnyOperator(model, static_cast<AnyOperator*>(op));
break;
case OperatorType::kOneHot:
ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
break;
+ case OperatorType::kUnpack:
+ ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
+ break;
default:
// Unimplemented, another graph transformation should drop it.
LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 8d22ae2eb1..1bc366f555 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -62,7 +62,8 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
type == OperatorType::kArgMax || type == OperatorType::kRelu ||
type == OperatorType::kRelu1 || type == OperatorType::kRelu6 ||
- type == OperatorType::kShape || type == OperatorType::kExpandDims;
+ type == OperatorType::kShape || type == OperatorType::kExpandDims ||
+ type == OperatorType::kPack || type == OperatorType::kTopK_V2;
}
// The quantized op allows output arrays of type float using
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
deleted file mode 100644
index 7a8515f6d1..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize_weights.cc
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <iterator>
-#include <string>
-#include <vector>
-
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-namespace {
-
-// The minimum number of elements a weights array must have to be quantized
-// by this transformation.
-// TODO(suharshs): Make this minimum size configurable.
-const int kWeightsMinSize = 1024;
-
-// Gets the quantization params from the float array.
-void GetQuantizationParamsFromArray(const Array& array,
- QuantizationParams* params) {
- const std::vector<float>& float_vals =
- array.GetBuffer<ArrayDataType::kFloat>().data;
- auto minmax = std::minmax_element(float_vals.begin(), float_vals.end());
- *params = tflite::ChooseQuantizationParams<uint8>(
- *minmax.first, *minmax.second, array.narrow_range);
-}
-
-} // namespace
-
-bool QuantizeWeights::Run(Model* model, std::size_t op_index) {
- const auto op_it = model->operators.begin() + op_index;
- Operator* op = op_it->get();
-
- // Get the weights tensor, if the current operator has one.
- int weights_index;
- if (op->type == OperatorType::kConv ||
- op->type == OperatorType::kDepthwiseConv ||
- op->type == OperatorType::kFullyConnected) {
- weights_index = 1;
- } else if (op->type == OperatorType::kLstmCell) {
- weights_index = LstmCellOperator::WEIGHTS_INPUT;
- } else {
- return false;
- }
-
- // Return early if the array isn't a constant param, this can happen in early
- // transformation passes until transpose operations following the weight array
- // are resolved.
- const string weights = op->inputs[weights_index];
- if (!IsConstantParameterArray(*model, weights)) {
- return false;
- }
-
- // Return early if the weight tensor is not type float.
- Array& weights_array = model->GetArray(weights);
- if (weights_array.data_type != ArrayDataType::kFloat) {
- return false;
- }
-
- // Return early if the tensor is too small. Small tensors don't take up too
- // much space and can result in bad quantization results.
- if (weights_array.GetBuffer<ArrayDataType::kFloat>().data.size() <
- kWeightsMinSize) {
- return false;
- }
-
- // Quantize the weight tensor to type kUint8.
- QuantizationParams params;
- GetQuantizationParamsFromArray(weights_array, &params);
- QuantizeArray(this, model, weights, ArrayDataType::kUint8, params);
-
- // Insert a Dequantize operation after the quantized weights tensor.
- auto* dequantize_op = new DequantizeOperator;
- model->operators.emplace(op_it, dequantize_op);
-
- // Create a new intermediate tensor to connect the Dequantize op to the
- // original op.
- const string dequantized_output =
- AvailableArrayName(*model, weights + "_dequantized");
- Array& dequantized_output_array = model->GetOrCreateArray(dequantized_output);
- dequantized_output_array.data_type = ArrayDataType::kFloat;
-
- // Connect up the new Dequantize op with the weights and original op.
- op->inputs[weights_index] = dequantized_output;
- dequantize_op->inputs = {weights};
- dequantize_op->outputs = {dequantized_output};
-
- return true;
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4fc39..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(num_input_axes);
std::vector<int> stop_for_axis(num_input_axes);
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+ op.stop_indices, op.strides);
+
for (int axis = 0; axis < num_input_axes; axis++) {
- int start = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
- axis);
- src_coord[axis] = start;
+ int start_index = tflite::strided_slice::StartForAxis(
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+ src_coord[axis] = start_index;
stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
- op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
- input_shape.dims().data(), axis, start);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides,
- input_shape.dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_shape), axis);
carry = true;
} else {
carry = false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 475415e481..c698a9567a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -51,6 +51,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
// Test for unary ops of types that we know how to resolve.
switch (unary_op->type) {
case OperatorType::kCast:
+ case OperatorType::kExp:
case OperatorType::kLog:
case OperatorType::kNeg:
case OperatorType::kRsqrt:
@@ -218,7 +219,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
max = std::max(max, (*input_float_data)[i]);
}
output_float_data[0] = max;
- } else if (unary_op->type == OperatorType::kNeg ||
+ } else if (unary_op->type == OperatorType::kExp ||
+ unary_op->type == OperatorType::kNeg ||
unary_op->type == OperatorType::kLog ||
unary_op->type == OperatorType::kRsqrt ||
unary_op->type == OperatorType::kSqrt ||
@@ -231,7 +233,9 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
for (int i = 0; i < output_buffer_size; i++) {
const float val = (*input_float_data)[i];
float outval = 0.f;
- if (unary_op->type == OperatorType::kNeg) {
+ if (unary_op->type == OperatorType::kExp) {
+ outval = std::exp(val);
+ } else if (unary_op->type == OperatorType::kNeg) {
outval = -val;
} else if (unary_op->type == OperatorType::kLog) {
outval = std::log(val);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
index 7d456af2fb..73198ac7c0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reduce_attributes.cc
@@ -52,6 +52,8 @@ bool ResolveReduceAttributes::Run(Model* model, std::size_t op_index) {
return ResolveAttributes(model, static_cast<TensorFlowMinOperator*>(op));
case OperatorType::kReduceMax:
return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
+ case OperatorType::kAny:
+ return ResolveAttributes(model, static_cast<TensorFlowMaxOperator*>(op));
default:
return false;
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index e163fc9ae1..acf1e3ede5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -20,19 +20,6 @@ tf_cc_test(
)
tf_cc_test(
- name = "quantize_weights_test",
- srcs = ["quantize_weights_test.cc"],
- tags = ["no_oss"],
- deps = [
- "//tensorflow/contrib/lite/toco:graph_transformations",
- "//tensorflow/contrib/lite/toco:model",
- "//tensorflow/contrib/lite/toco:tooling_util",
- "@com_google_absl//absl/memory",
- "@com_google_googletest//:gtest_main",
- ],
-)
-
-tf_cc_test(
name = "resolve_constant_concatenation_test",
srcs = ["resolve_constant_concatenation_test.cc"],
tags = ["no_oss"],
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
deleted file mode 100644
index c05eb0929f..0000000000
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/quantize_weights_test.cc
+++ /dev/null
@@ -1,167 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <math.h>
-#include <string>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "absl/memory/memory.h"
-#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
-#include "tensorflow/contrib/lite/toco/model.h"
-#include "tensorflow/contrib/lite/toco/tooling_util.h"
-
-namespace toco {
-
-class QuantizeWeightsTest : public ::testing::Test {
- protected:
- QuantizeWeightsTest() {}
-
- // The name of the weights input array.
- const string kWeightsName = "weights";
- // The zero_point of the values in the input array.
- const int kZeroPoint = 128;
-
- // Prepare a hypothetical TOCO model of a quantizable fully connected float
- // layer.
- void PrepareModel(Model* model, int elements_per_dim) {
- std::vector<string> fc_input_names = {"inputs", kWeightsName};
-
- const int kDim = 4;
- const int buf_size = std::pow(elements_per_dim, static_cast<double>(kDim));
- auto in_buf = absl::make_unique<float[]>(buf_size);
- // Initialize the array with values from -128.0 to 127.0, since these values
- // should be exactly representable by quantization.
- for (int i = 0; i < buf_size; i++) {
- in_buf[i] = static_cast<float>(i % 256 - kZeroPoint);
- }
-
- for (const string& fc_input_name : fc_input_names) {
- Array& in_array = model->GetOrCreateArray(fc_input_name);
- in_array.data_type = ArrayDataType::kFloat;
-
- // Initialize shape for the input array.
- Shape* in_array_shape = in_array.mutable_shape();
- std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
- in_array_shape_dim->resize(kDim, elements_per_dim);
- auto& in_array_buffer =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>();
- in_array_buffer.data.resize(buf_size);
- float* buf_ptr =
- in_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
- std::copy(in_buf.get(), in_buf.get() + buf_size, buf_ptr);
- }
-
- auto* fc_op = new FullyConnectedOperator;
- fc_op->inputs = fc_input_names;
- fc_op->outputs = {"fc_op_outputs"};
- Array& out_array = model->GetOrCreateArray(fc_op->outputs[0]);
- out_array.data_type = ArrayDataType::kFloat;
- Shape* out_array_shape = out_array.mutable_shape();
- std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
- out_array_shape_dim->resize(kDim, elements_per_dim);
- model->operators.push_back(std::unique_ptr<Operator>(fc_op));
- }
-};
-
-TEST_F(QuantizeWeightsTest, QuantizedFullyConnected) {
- // Test that weight arrays that are large enough are quantized.
- Model model;
- // 6 elements per dim gives us 1296 elements, which is sufficient to be
- // quantized.
- PrepareModel(&model, 6);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (const auto& element : float_array_map) {
- EXPECT_EQ(element.second->data_type, ArrayDataType::kFloat);
- }
- const std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& quantized_array_map = model.GetArrayMap();
- EXPECT_EQ(quantized_array_map.size(), 4);
- // After the transformation, three arrays should be type float and one array
- // should be uint8.
- int num_float = 0;
- int num_uint8 = 0;
- for (const auto& element : quantized_array_map) {
- if (element.second->data_type == ArrayDataType::kFloat) {
- num_float++;
- } else if (element.second->data_type == ArrayDataType::kUint8) {
- num_uint8++;
- } else {
- FAIL() << "Unexpected array type.";
- }
- }
- EXPECT_EQ(num_float, 3);
- EXPECT_EQ(num_uint8, 1);
- // Ensure that the values were quantized correctly.
- const std::vector<uint8>& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kUint8>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i] + kZeroPoint);
- }
-
- // Ensure that a Dequantize operator has been inserted before the
- // FullyConnectedLayer.
- EXPECT_EQ(model.operators[0]->type, OperatorType::kDequantize);
-}
-
-TEST_F(QuantizeWeightsTest, NotQuantizedFullyConnected) {
- // Test that weight arrays that are too small are left untouched.
- Model model;
- // 5 elements per dim gives us 625 elements, which is NOT sufficient to be
- // quantized.
- PrepareModel(&model, 5);
-
- // Check the state of the graph before the transformation.
- const auto& float_array_map = model.GetArrayMap();
- EXPECT_EQ(float_array_map.size(), 3);
- // Before the transformation, all arrays should be type float.
- for (auto it = float_array_map.begin(); it != float_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- std::vector<float> float_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
-
- // Invoke the transformation.
- GraphTransformationsSet graph_transformation_set;
- graph_transformation_set.Add(new toco::QuantizeWeights);
- (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
-
- // Check the state of the graph after the transformation.
- const auto& post_array_map = model.GetArrayMap();
- EXPECT_EQ(post_array_map.size(), 3);
- for (auto it = post_array_map.begin(); it != post_array_map.end(); it++) {
- EXPECT_EQ(it->second->data_type, ArrayDataType::kFloat);
- }
- // Ensure that the values remain unchanged.
- std::vector<float> const& quantized_weight_vals =
- model.GetArray(kWeightsName).GetBuffer<ArrayDataType::kFloat>().data;
- for (int i = 0; i < quantized_weight_vals.size(); i++) {
- EXPECT_EQ(quantized_weight_vals[i], float_weight_vals[i]);
- }
-}
-
-} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index b7fffbce22..9bc23c4b3c 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1576,6 +1576,26 @@ tensorflow::Status ConvertPackOperator(
return tensorflow::Status::OK();
}
+tensorflow::Status ConvertUnpackOperator(
+ const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
+ Model* model) {
+ CHECK_EQ(node.op(), "Unpack");
+ auto op = absl::make_unique<UnpackOperator>();
+ const int num_inputs = GetInputsCount(node, tf_import_flags);
+ QCHECK_EQ(num_inputs, 1);
+ op->inputs.push_back(node.input(0));
+ op->num = GetIntAttr(node, "num");
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
+ op->dtype = ConvertDataType(toco::GetDataTypeAttr(node, "T"));
+
+ op->outputs.push_back(node.name()); // Implicit :0.
+ for (int i = 1; i < op->num; ++i) {
+ op->outputs.push_back(node.name() + ":" + std::to_string(i));
+ }
+ model->operators.emplace_back(std::move(op));
+ return tensorflow::Status::OK();
+}
+
// Some TensorFlow ops only occur in graph cycles, representing
// control flow. We do not currently support control flow, so we wouldn't
// be able to fully support such graphs, including performing inference,
@@ -1618,24 +1638,6 @@ tensorflow::Status ConvertShapeOperator(
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertAnyOperator(
- const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
- Model* model) {
- CHECK_EQ(node.op(), "Any");
- TF_QCHECK_OK(CheckInputsCount(node, tf_import_flags, 2));
- const auto idx_type =
- HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
- CHECK(idx_type == DT_INT32);
- auto op = absl::make_unique<AnyOperator>();
- op->inputs.push_back(node.input(0));
- op->inputs.push_back(node.input(1));
- op->outputs.push_back(node.name());
- op->keep_dims =
- HasAttr(node, "keep_dims") ? GetBoolAttr(node, "keep_dims") : false;
- model->operators.push_back(std::move(op));
- return tensorflow::Status::OK();
-}
-
void StripCaretFromArrayNames(Model* model) {
for (auto& op : model->operators) {
for (auto& input : op->inputs) {
@@ -1917,7 +1919,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"Add", ConvertSimpleOperator<AddOperator, 2>},
{"AddN", ConvertSimpleOperator<AddNOperator>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator>},
- {"Any", ConvertAnyOperator},
+ {"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
{"ArgMax", ConvertArgMaxOperator},
{"ArgMin", ConvertArgMinOperator},
{"Assert", ConvertSimpleOperator<TensorFlowAssertOperator>},
@@ -2020,6 +2022,7 @@ ConverterMapType GetTensorFlowNodeConverterMap() {
{"TopK", ConvertTopKV2Operator},
{"TopKV2", ConvertTopKV2Operator},
{"Transpose", ConvertSimpleOperator<TransposeOperator, 2>},
+ {"Unpack", ConvertUnpackOperator},
});
}
@@ -2058,8 +2061,14 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
}
Model* model = new Model;
- const internal::ConverterMapType& converter_map =
- internal::GetTensorFlowNodeConverterMap();
+ internal::ConverterMapType converter_map;
+
+ // This is used for the TFLite "Full Eager Mode" conversion. All the ops are
+ // imported as `TensorFlowUnsupportedOperator`, and later all these ops are
+ // converted to TFLite Eager ops.
+ if (!tf_import_flags.import_all_ops_as_unsupported) {
+ converter_map = internal::GetTensorFlowNodeConverterMap();
+ }
for (auto node : inlined_graph.node()) {
StripZeroOutputIndexFromInputs(&node);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.h b/tensorflow/contrib/lite/toco/import_tensorflow.h
index 2177872334..7db23f2d44 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.h
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.h
@@ -27,6 +27,11 @@ struct TensorFlowImportFlags {
// If true, control dependencies will be dropped immediately
// during the import of the TensorFlow GraphDef.
bool drop_control_dependency = false;
+
+ // Do not recognize any op and import all ops as
+ // `TensorFlowUnsupportedOperator`. This is used to populated with the
+ // `force_eager_ops` flag.
+ bool import_all_ops_as_unsupported = false;
};
std::unique_ptr<Model> ImportTensorFlowGraphDef(
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 412e14c4ad..2e100e37f6 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -149,6 +149,7 @@ enum class OperatorType : uint8 {
kLogicalNot,
kLogicalOr,
kCTCBeamSearchDecoder,
+ kUnpack,
};
// Helper to deal with TensorFlow arrays using a different ordering of
@@ -1770,8 +1771,9 @@ struct PowOperator : Operator {
// Inputs[1]: required: reduction_indices.
//
// TensorFlow equivalent: tf.reduce_any.
-struct AnyOperator : Operator {
- AnyOperator() : Operator(OperatorType::kAny) {}
+struct TensorFlowAnyOperator : Operator {
+ TensorFlowAnyOperator() : Operator(OperatorType::kAny) {}
+ std::vector<int> axis;
bool keep_dims = false;
};
@@ -1828,6 +1830,20 @@ struct LogicalOrOperator : Operator {
LogicalOrOperator() : Operator(OperatorType::kLogicalOr) {}
};
+// Unpack operator:
+//
+// Inputs:
+// Inputs[0]: required: A boolean input tensor.
+// Inputs[1]: required: reduction_indices.
+//
+// TensorFlow equivalent: tf.unstack.
+struct UnpackOperator : Operator {
+ UnpackOperator() : Operator(OperatorType::kUnpack) {}
+ int num;
+ int axis;
+ ArrayDataType dtype = ArrayDataType::kNone;
+};
+
// Alloc's are used for transient arrays only. An Alloc specifies which interval
// of the "transient_data" workspace buffer passed to inference functions, is to
// be used for the transient array at hand. The 'start' and 'end' values are
diff --git a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
index 3761e0095e..75c1c8970c 100644
--- a/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
+++ b/tensorflow/contrib/lite/toco/python/toco_from_protos_test.py
@@ -50,7 +50,7 @@ class TocoFromProtosTest(googletest.TestCase):
toco_flags.output_format = toco_flags_pb2.TFLITE
toco_flags.inference_input_type = types_pb2.FLOAT
toco_flags.inference_type = types_pb2.FLOAT
- toco_flags.allow_custom_ops = True;
+ toco_flags.allow_custom_ops = True
model_flags = model_flags_pb2.ModelFlags()
input_array = model_flags.input_arrays.add()
input_array.name = TensorName(in_tensor)
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index 709c53606b..71cdb7703e 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -91,6 +91,7 @@ cc_library(
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/contrib/lite/toco:tooling_util",
+ "//tensorflow/contrib/lite/tools/optimize:quantize_weights",
"@com_google_absl//absl/strings",
"@flatbuffers",
],
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 5ad307af14..fee10b1dff 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -16,10 +16,12 @@ limitations under the License.
#include "flatbuffers/flexbuffers.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
#include "tensorflow/contrib/lite/version.h"
namespace toco {
@@ -47,12 +49,21 @@ namespace {
details::OperatorKey GetOperatorKey(
const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
string custom_code;
if (op.type == OperatorType::kUnsupported) {
const TensorFlowUnsupportedOperator& unsupported_op =
static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_eager_ops` is on, for now there's no way
+ // to populate a regular custom op. We need to find a way to fix this.
+ if (allow_eager_ops) {
+ custom_code = string(::tflite::kEagerCustomCodePrefix) +
+ unsupported_op.tensorflow_op;
+ } else {
+ custom_code = unsupported_op.tensorflow_op;
+ }
}
int version = 1;
if (ops_by_type.count(op.type) != 0) {
@@ -61,6 +72,13 @@ details::OperatorKey GetOperatorKey(
return details::OperatorKey(op.type, custom_code, version);
}
+void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
+ string* file_contents) {
+ const uint8_t* buffer = builder.GetBufferPointer();
+ int size = builder.GetSize();
+ *file_contents = string(reinterpret_cast<const char*>(buffer), size);
+}
+
} // Anonymous namespace.
namespace details {
@@ -82,11 +100,12 @@ void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops) {
// First find a list of unique operator types.
std::set<OperatorKey> keys;
for (const auto& op : model.operators) {
- keys.insert(GetOperatorKey(*op, ops_by_type));
+ keys.insert(GetOperatorKey(*op, ops_by_type, allow_eager_ops));
}
// Now assign indices to them and fill in the map.
int index = 0;
@@ -180,7 +199,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* error_summary) {
+ std::set<string>* error_summary, const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -196,7 +215,8 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
std::map<int, Offset<OperatorCode>> ordered_opcodes;
for (const auto& op : model.operators) {
- const details::OperatorKey operator_key = GetOperatorKey(*op, ops_by_type);
+ const details::OperatorKey operator_key =
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops);
int op_index = operators_map.at(operator_key);
int op_version = operator_key.version;
@@ -243,7 +263,7 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map,
const details::TensorsMap& tensors_map, FlatBufferBuilder* builder,
- std::set<int32_t>* variable_tensor_indices) {
+ std::set<int32_t>* variable_tensor_indices, const ExportParams& params) {
variable_tensor_indices->clear();
// The operators are in execution order, so we just follow tf.mini order.
@@ -260,7 +280,8 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(GetOperatorKey(*op, ops_by_type));
+ int op_index = operators_map.at(
+ GetOperatorKey(*op, ops_by_type, params.allow_eager_ops));
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -311,14 +332,15 @@ Offset<Vector<Offset<Buffer>>> ExportBuffers(
return builder->CreateVector(buffer_vector);
}
-void Export(const Model& model, bool allow_custom_ops,
- string* output_file_contents) {
- const auto ops_by_type = BuildOperatorByTypeMap();
- Export(model, allow_custom_ops, output_file_contents, ops_by_type);
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params) {
+ const auto ops_by_type = BuildOperatorByTypeMap(params.allow_eager_ops);
+ Export(model, output_file_contents, params, ops_by_type);
}
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
flatbuffers::FlatBufferBuilder builder(/*initial_size=*/10240);
@@ -326,7 +348,8 @@ void Export(
details::LoadTensorsMap(model, &tensors_map);
details::OperatorsMap operators_map;
- details::LoadOperatorsMap(model, &operators_map, ops_by_type);
+ details::LoadOperatorsMap(model, &operators_map, ops_by_type,
+ params.allow_eager_ops);
std::vector<const Array*> buffers_to_write;
Array empty_array;
@@ -334,7 +357,7 @@ void Export(
std::set<string> error_summary;
auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &error_summary);
+ &builder, &error_summary, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -344,7 +367,7 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!allow_custom_ops && !error_summary.empty()) {
+ if (!params.allow_custom_ops && !error_summary.empty()) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
// However, if an op is unimplemented earlier in the model, the graph
@@ -365,14 +388,14 @@ void Export(
"the standard TensorFlow Lite runtime. If you have a custom "
"implementation for them you can disable this error with "
"--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.toco_convert(). Here is a list "
+ "when calling tf.contrib.lite.TocoConverter(). Here is a list "
"of operators for which you will need custom implementations: "
<< absl::StrJoin(error_summary_final, ", ") << ".";
}
std::set<int32_t> variable_tensor_indices;
auto ops = ExportOperators(model, ops_by_type, operators_map, tensors_map,
- &builder, &variable_tensor_indices);
+ &builder, &variable_tensor_indices, params);
auto tensors = ExportTensors(model, tensors_map, &builder, &buffers_to_write,
variable_tensor_indices);
@@ -390,9 +413,24 @@ void Export(
CreateModel(builder, TFLITE_SCHEMA_VERSION, op_codes,
builder.CreateVector(subgraphs), description, buffers);
::tflite::FinishModelBuffer(builder, new_model_location);
- const uint8_t* buffer = builder.GetBufferPointer();
- int size = builder.GetSize();
- *output_file_contents = string(reinterpret_cast<const char*>(buffer), size);
+
+ if (params.quantize_weights) {
+ // Call the quantize_weights tool.
+ LOG(INFO) << "Quantizing TFLite model after conversion to flatbuffer. "
+ "dump_graphviz will only output the model before this "
+ "transformation. To visualize the output graph use "
+ "lite/tools/optimize.py.";
+ flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240);
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const ::tflite::Model* input_model = ::tflite::GetModel(buffer);
+ if (::tflite::optimize::QuantizeWeights(&q_builder, input_model) !=
+ kTfLiteOk) {
+ LOG(QFATAL) << "Quantize weights transformation failed.";
+ }
+ WriteModelToString(q_builder, output_file_contents);
+ } else {
+ WriteModelToString(builder, output_file_contents);
+ }
}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 58ea5c725c..b070a38768 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -23,22 +23,55 @@ namespace toco {
namespace tflite {
+// The parameters for exporting a TFLite model.
+struct ExportParams {
+ bool allow_custom_ops = false;
+ bool allow_eager_ops = false;
+ bool quantize_weights = false;
+};
+
// Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
// result in the given string.
-void Export(const Model& model, bool allow_custom_ops,
- string* output_file_contents);
-
-// This if backward-compatibility.
-// TODO(ycling): Remove the deprecated entry functions.
-inline void Export(const Model& model, string* output_file_contents) {
- Export(model, true, output_file_contents);
-}
+void Export(const Model& model, string* output_file_contents,
+ const ExportParams& params);
// Export API with custom TFLite operator mapping.
void Export(
- const Model& model, bool allow_custom_ops, string* output_file_contents,
+ const Model& model, string* output_file_contents,
+ const ExportParams& params,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, bool allow_custom_ops,
+ bool quantize_weights, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(
+ const Model& model, bool allow_custom_ops, bool quantize_weights,
+ string* output_file_contents,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
+ ExportParams params;
+ params.allow_custom_ops = allow_custom_ops;
+ params.quantize_weights = quantize_weights;
+ Export(model, output_file_contents, params, ops_by_type);
+}
+
+// This is for backward-compatibility.
+// TODO(ycling): Remove the deprecated entry functions.
+inline void Export(const Model& model, string* output_file_contents) {
+ ExportParams params;
+ params.allow_custom_ops = true;
+ Export(model, output_file_contents, params);
+ Export(model, true, false, output_file_contents);
+}
+
namespace details {
// A maps from tensor name to its final position in the TF Lite buffer.
@@ -87,7 +120,8 @@ using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
void LoadOperatorsMap(
const Model& model, OperatorsMap* operators_map,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_eager_ops);
} // namespace details
} // namespace tflite
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index a95937ba0f..8d4d197c46 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -52,6 +52,42 @@ class ExportTest : public ::testing::Test {
input_model_.operators.emplace_back(new SubOperator);
}
+ void BuildQuantizableTestModel() {
+ input_model_.GetOrCreateArray("inputs");
+ Array& weight_array = input_model_.GetOrCreateArray("weights");
+
+ // Make the buffer large enough for QuantizeWeights transformation to take
+ // effect.
+ int buf_size = 1296;
+ auto weight_buf = absl::make_unique<float[]>(buf_size);
+ for (int i = 0; i < buf_size; i++) {
+ // Fill the array with some garbage values.
+ weight_buf[i] = static_cast<float>(i % 128);
+ }
+
+ weight_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* weight_array_shape = weight_array.mutable_shape();
+ std::vector<int>* weight_array_shape_dim =
+ weight_array_shape->mutable_dims();
+ weight_array_shape_dim->resize(4, 6);
+ auto& weight_array_buffer =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ weight_array_buffer.data.resize(buf_size);
+ float* buf_ptr =
+ weight_array.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ std::copy(weight_buf.get(), weight_buf.get() + buf_size, buf_ptr);
+
+ {
+ auto* op = new ConvOperator;
+ op->padding.type = PaddingType::kSame;
+ op->inputs = {"inputs", "weights"};
+ input_model_.operators.emplace_back(op);
+ }
+ input_model_.operators.emplace_back(new AddOperator);
+ }
+
Model input_model_;
};
@@ -69,7 +105,8 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ // TODO(ycling): Add a test for allow_eager_ops.
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
@@ -81,7 +118,7 @@ TEST_F(ExportTest, Export) {
BuildTestModel();
string result;
- Export(input_model_, true, &result);
+ Export(input_model_, true, false, &result);
auto* model = ::tflite::GetModel(result.data());
@@ -108,6 +145,20 @@ TEST_F(ExportTest, Export) {
EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
}
+TEST_F(ExportTest, QuantizeWeights) {
+ // Sanity check for quantize_weights parameter.
+ BuildQuantizableTestModel();
+ string unquantized_result;
+ Export(input_model_, true, /*quantize_weights*/ false, &unquantized_result);
+
+ BuildQuantizableTestModel();
+ string quantized_result;
+ Export(input_model_, true, /*quantize_weights*/ true, &quantized_result);
+
+ // The quantized models should be smaller.
+ EXPECT_LT(quantized_result.size(), unquantized_result.size());
+}
+
// This test is based on a hypothetical scenario that dilation is supported
// only in Conv version 2. So Toco populates version=1 when dialation
// parameters are all 1, and version=2 otehrwise.
@@ -203,7 +254,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -214,7 +265,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
@@ -226,7 +277,7 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::OperatorsMap operators;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- details::LoadOperatorsMap(input_model_, &operators, ops_by_type);
+ details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
@@ -239,7 +290,7 @@ TEST_F(VersionedOpExportTest, Export) {
string result;
const auto ops_by_type = BuildFakeOperatorByTypeMap();
- Export(input_model_, true, &result, ops_by_type);
+ Export(input_model_, true, false, &result, ops_by_type);
auto* model = ::tflite::GetModel(result.data());
auto operator_codes = model->operator_codes();
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index dcb5fff39f..eb0f7c443a 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -769,7 +769,7 @@ class Sum
};
class ReduceMax
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowMaxOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -788,7 +788,7 @@ class ReduceMax
};
class ReduceMin
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowMinOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -807,7 +807,26 @@ class ReduceMin
};
class ReduceProd
- : public BuiltinOperator<TensorFlowSumOperator, ::tflite::ReducerOptions,
+ : public BuiltinOperator<TensorFlowProdOperator, ::tflite::ReducerOptions,
+ ::tflite::BuiltinOptions_ReducerOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateReducerOptions(*builder, op.keep_dims);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->keep_dims = options.keep_dims();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
+class ReduceAny
+ : public BuiltinOperator<TensorFlowAnyOperator, ::tflite::ReducerOptions,
::tflite::BuiltinOptions_ReducerOptions> {
public:
using BuiltinOperator::BuiltinOperator;
@@ -1110,9 +1129,29 @@ class CTCBeamSearchDecoder
int GetVersion(const Operator& op) const override { return 1; }
};
+class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
+ ::tflite::BuiltinOptions_UnpackOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateUnpackOptions(*builder, op.num, op.axis);
+ }
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->num = options.num();
+ op->axis = options.axis();
+ }
+
+ int GetVersion(const Operator& op) const override { return 1; }
+};
+
class TensorFlowUnsupported : public BaseOperator {
public:
- using BaseOperator::BaseOperator;
+ TensorFlowUnsupported(const string& name, OperatorType type,
+ bool allow_eager_ops)
+ : BaseOperator(name, type), allow_eager_ops_(allow_eager_ops) {}
Options Serialize(const Operator& op,
flatbuffers::FlatBufferBuilder* builder) const override {
@@ -1128,6 +1167,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<Operator> Deserialize(
const BuiltinOptions* builtin_options,
const CustomOptions* custom_options) const override {
+ // Deserializing Eager ops doesn't work now.
+ // TODO(ycling): Revisit and decide if we should fix the flow for importing
+ // TFLite models with Eager ops.
auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
if (custom_options) {
auto flexbuffer_map =
@@ -1148,6 +1190,16 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
+ if (allow_eager_ops_) {
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(op.tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing eager op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+ }
+
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
@@ -1248,11 +1300,15 @@ class TensorFlowUnsupported : public BaseOperator {
// custom ops.
return 1;
}
+
+ private:
+ const bool allow_eager_ops_;
};
namespace {
// Build a vector containing all the known operators.
-std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
+std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList(
+ bool allow_eager_ops = false) {
std::vector<std::unique_ptr<BaseOperator>> ops;
using tensorflow::MakeUnique;
// Builtin Operators.
@@ -1318,6 +1374,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
OperatorType::kReduceMax));
ops.push_back(MakeUnique<ReduceMin>(::tflite::BuiltinOperator_REDUCE_MIN,
OperatorType::kReduceMin));
+ ops.push_back(MakeUnique<ReduceAny>(::tflite::BuiltinOperator_REDUCE_ANY,
+ OperatorType::kAny));
ops.push_back(
MakeUnique<ResizeBilinear>(::tflite::BuiltinOperator_RESIZE_BILINEAR,
OperatorType::kResizeBilinear));
@@ -1353,14 +1411,16 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack));
ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT,
OperatorType::kOneHot));
+ ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK,
+ OperatorType::kUnpack));
// Custom Operators.
ops.push_back(
MakeUnique<DepthToSpace>("DEPTH_TO_SPACE", OperatorType::kDepthToSpace));
ops.push_back(MakeUnique<CTCBeamSearchDecoder>(
"CTC_BEAM_SEARCH_DECODER", OperatorType::kCTCBeamSearchDecoder));
- ops.push_back(MakeUnique<TensorFlowUnsupported>("TENSORFLOW_UNSUPPORTED",
- OperatorType::kUnsupported));
+ ops.push_back(MakeUnique<TensorFlowUnsupported>(
+ "TENSORFLOW_UNSUPPORTED", OperatorType::kUnsupported, allow_eager_ops));
// There operators are supported by Toco, but not by TF Lite, and has no
// attributes.
@@ -1417,6 +1477,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
"LOGICAL_AND", OperatorType::kLogicalAnd));
ops.emplace_back(new SimpleOperator<LogicalNotOperator>(
"LOGICAL_NOT", OperatorType::kLogicalNot));
+ ops.emplace_back(new SimpleOperator<FloorDivOperator>(
+ "FLOOR_DIV", OperatorType::kFloorDiv));
// Element-wise operator
ops.push_back(
MakeUnique<SimpleOperator<SinOperator>>("SIN", OperatorType::kSin));
@@ -1431,10 +1493,12 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
}
} // namespace
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops) {
std::map<OperatorType, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->type()] = std::move(op);
}
@@ -1442,10 +1506,12 @@ std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap() {
return result;
}
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap() {
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops) {
std::map<string, std::unique_ptr<BaseOperator>> result;
- std::vector<std::unique_ptr<BaseOperator>> ops = BuildOperatorList();
+ std::vector<std::unique_ptr<BaseOperator>> ops =
+ BuildOperatorList(allow_eager_ops);
for (auto& op : ops) {
result[op->name()] = std::move(op);
}
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index d9ea23edf2..702fb28ea6 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -26,11 +26,15 @@ namespace tflite {
class BaseOperator;
// Return a map contained all know TF Lite Operators, keyed by their names.
-std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap();
+// TODO(ycling): The pattern to propagate parameters (e.g. allow_eager_ops)
+// is ugly here. Consider refactoring.
+std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
+ bool allow_eager_ops = false);
// Return a map contained all know TF Lite Operators, keyed by the type of
// their tf.mini counterparts.
-std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap();
+std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
+ bool allow_eager_ops = false);
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
index fc854461b4..519a3a4e01 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc
@@ -97,6 +97,16 @@ class OperatorTest : public ::testing::Test {
ASSERT_NE(nullptr, output_toco_op.get());
}
+
+ template <typename T>
+ void CheckReducerOperator(const string& name, OperatorType type) {
+ T op;
+
+ op.keep_dims = false;
+
+ auto output_toco_op = SerializeAndDeserialize(GetOperator(name, type), op);
+ EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
+ }
};
TEST_F(OperatorTest, SimpleOperators) {
@@ -133,6 +143,7 @@ TEST_F(OperatorTest, SimpleOperators) {
OperatorType::kLogicalAnd);
CheckSimpleOperator<LogicalNotOperator>("LOGICAL_NOT",
OperatorType::kLogicalNot);
+ CheckSimpleOperator<FloorDivOperator>("FLOOR_DIV", OperatorType::kFloorDiv);
}
TEST_F(OperatorTest, BuiltinAdd) {
@@ -144,13 +155,16 @@ TEST_F(OperatorTest, BuiltinAdd) {
output_toco_op->fused_activation_function);
}
-TEST_F(OperatorTest, BuiltinMean) {
- MeanOperator op;
- op.keep_dims = false;
-
- auto output_toco_op =
- SerializeAndDeserialize(GetOperator("MEAN", OperatorType::kMean), op);
- EXPECT_EQ(op.keep_dims, output_toco_op->keep_dims);
+TEST_F(OperatorTest, BuiltinReducerOps) {
+ CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
+ CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
+ CheckReducerOperator<TensorFlowProdOperator>("REDUCE_PROD",
+ OperatorType::kReduceProd);
+ CheckReducerOperator<TensorFlowMaxOperator>("REDUCE_MAX",
+ OperatorType::kReduceMax);
+ CheckReducerOperator<TensorFlowMinOperator>("REDUCE_MIN",
+ OperatorType::kReduceMin);
+ CheckReducerOperator<TensorFlowAnyOperator>("REDUCE_ANY", OperatorType::kAny);
}
TEST_F(OperatorTest, BuiltinCast) {
@@ -476,6 +490,16 @@ TEST_F(OperatorTest, BuiltinOneHot) {
EXPECT_EQ(op.axis, output_toco_op->axis);
}
+TEST_F(OperatorTest, BuiltinUnpack) {
+ UnpackOperator op;
+ op.num = 5;
+ op.axis = 2;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("UNPACK", OperatorType::kUnpack), op);
+ EXPECT_EQ(op.num, output_toco_op->num);
+ EXPECT_EQ(op.axis, output_toco_op->axis);
+}
+
TEST_F(OperatorTest, CustomCTCBeamSearchDecoder) {
CTCBeamSearchDecoderOperator op;
op.beam_width = 3;
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index c6d0a03452..b6aebc0470 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -160,10 +160,18 @@ bool ParseTocoFlagsFromCommandLineFlags(
"Ignored if the output format is not TFLite."),
Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
parsed_flags.quantize_weights.default_value(),
- "Store weights as quantized weights followed by dequantize "
- "operations. Computation is still done in float, but reduces model "
- "size (at the cost of accuracy and latency)."),
- };
+ "Deprecated. Please use --post_training_quantize instead."),
+ Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
+ parsed_flags.post_training_quantize.default_value(),
+ "Boolean indicating whether to quantize the weights of the "
+ "converted float model. Model size will be reduced and there will "
+ "be latency improvements (at the cost of accuracy)."),
+ // WARNING: Experimental interface, subject to change
+ Flag("allow_eager_ops", parsed_flags.allow_eager_ops.bind(),
+ parsed_flags.allow_eager_ops.default_value(), ""),
+ // WARNING: Experimental interface, subject to change
+ Flag("force_eager_ops", parsed_flags.force_eager_ops.bind(),
+ parsed_flags.force_eager_ops.default_value(), "")};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
if (asked_for_help) {
@@ -257,6 +265,17 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
+ READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
+ READ_TOCO_FLAG(allow_eager_ops, FlagRequirement::kNone);
+ READ_TOCO_FLAG(force_eager_ops, FlagRequirement::kNone);
+
+ if (parsed_toco_flags.force_eager_ops.value() &&
+ !parsed_toco_flags.allow_eager_ops.value()) {
+ // TODO(ycling): Consider to enforce `allow_eager_ops` when
+ // `force_eager_ops` is true.
+ LOG(WARNING) << "--force_eager_ops should always be used with "
+ "--allow_eager_ops.";
+ }
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
@@ -291,9 +310,19 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
toco_flags->set_inference_input_type(input_type);
}
if (parsed_toco_flags.quantize_weights.value()) {
- QCHECK_NE(toco_flags->inference_type(), IODataType::QUANTIZED_UINT8)
- << "quantize_weights is not supported with inference_type "
- "QUANTIZED_UINT8.";
+ LOG(WARNING)
+ << "--quantize_weights is deprecated. Falling back to "
+ "--post_training_quantize. Please switch --post_training_quantize.";
+ toco_flags->set_post_training_quantize(
+ parsed_toco_flags.quantize_weights.value());
+ }
+ if (parsed_toco_flags.quantize_weights.value()) {
+ if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
+ LOG(WARNING)
+ << "--post_training_quantize quantizes a graph of inference_type "
+ "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
+ toco_flags->set_inference_type(IODataType::FLOAT);
+ }
}
#undef READ_TOCO_FLAG
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index b4a9870d58..53d60fed05 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 26.
+// Next ID to use: 29.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -173,6 +173,7 @@ message TocoFlags {
// Store weights as quantized weights followed by dequantize operations.
// Computation is still done in float, but reduces model size (at the cost of
// accuracy and latency).
+ // DEPRECATED: Please use post_training_quantize instead.
optional bool quantize_weights = 20 [default = false];
// Full filepath of folder to dump the graphs at various stages of processing
@@ -183,4 +184,22 @@ message TocoFlags {
// Boolean indicating whether to dump the graph after every graph
// transformation.
optional bool dump_graphviz_include_video = 25;
+
+ // Boolean indicating whether to quantize the weights of the converted float
+ // model. Model size will be reduced and there will be latency improvements
+ // (at the cost of accuracy).
+ optional bool post_training_quantize = 26 [default = false];
+
+ // When enabled, unsupported ops will be converted to TFLite Eager ops.
+ // TODO(ycling): Consider to rename the following 2 flags and don't call it
+ // "Eager".
+ // `allow_eager_ops` should always be used with `allow_custom_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool allow_eager_ops = 27 [default = false];
+
+ // When enabled, all TensorFlow ops will be converted to TFLite Eager
+ // ops directly. This will force `allow_eager_ops` to true.
+ // `force_eager_ops` should always be used with `allow_eager_ops`.
+ // WARNING: Experimental interface, subject to change
+ optional bool force_eager_ops = 28 [default = false];
}
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 34130a02b0..a7c17156b1 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -197,6 +197,10 @@ std::unique_ptr<Model> Import(const TocoFlags& toco_flags,
toco_flags.has_drop_control_dependency()
? toco_flags.drop_control_dependency()
: (toco_flags.output_format() != TENSORFLOW_GRAPHDEF);
+
+ tf_import_flags.import_all_ops_as_unsupported =
+ toco_flags.force_eager_ops();
+
model = ImportTensorFlowGraphDef(model_flags, tf_import_flags,
input_file_contents);
break;
@@ -281,12 +285,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
- if (toco_flags.quantize_weights()) {
- // Run the quantize weights transformation after batchnorms have been
- // folded into the weights.
- RunGraphTransformations(model, "quantize weights transformation",
- {new QuantizeWeights});
- }
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
@@ -403,9 +401,21 @@ void Export(const TocoFlags& toco_flags, const Model& model,
case TENSORFLOW_GRAPHDEF:
ExportTensorFlowGraphDef(model, output_file_contents);
break;
- case TFLITE:
- toco::tflite::Export(model, allow_custom_ops, output_file_contents);
- break;
+ case TFLITE: {
+ toco::tflite::ExportParams params;
+
+ // Always allow custom ops when eager ops are allowed.
+ if (toco_flags.force_eager_ops() || toco_flags.allow_eager_ops()) {
+ params.allow_eager_ops = true;
+ params.allow_custom_ops = true;
+ } else if (allow_custom_ops) {
+ params.allow_custom_ops = true;
+ }
+
+ params.quantize_weights = toco_flags.post_training_quantize();
+
+ toco::tflite::Export(model, output_file_contents, params);
+ } break;
case GRAPHVIZ_DOT:
DumpGraphviz(model, output_file_contents);
break;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 3a4542f522..6ab93d9316 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -405,6 +405,7 @@ const char* OperatorTypeName(OperatorType type) {
HANDLE_OPERATORTYPENAME_CASE(LogicalNot)
HANDLE_OPERATORTYPENAME_CASE(LogicalOr)
HANDLE_OPERATORTYPENAME_CASE(CTCBeamSearchDecoder)
+ HANDLE_OPERATORTYPENAME_CASE(Unpack)
default:
LOG(FATAL) << "Unhandled op type";
#undef HANDLE_OPERATORTYPENAME_CASE
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index bdeb203024..5f4b8cb66a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,6 +28,7 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/include/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -139,6 +140,10 @@ bool ShapesAgreeUpToBroadcasting(const Shape& shape0, const Shape& shape1);
// - For the remaining indices [0..i0), d0[i0] == 1.
bool ShapesAgreeUpToExtending(const Shape& shape0, const Shape& shape1);
+inline ::tflite::RuntimeShape ToRuntimeShape(const Shape& shape) {
+ return ::tflite::RuntimeShape(shape.dimensions_count(), shape.dims().data());
+}
+
bool IsArrayFullyConnectedWeights(const Model& model, const string& name);
// If there is a wildcard dimension (-1), this may return a negative value.
diff --git a/tensorflow/contrib/lite/tools/accuracy/BUILD b/tensorflow/contrib/lite/tools/accuracy/BUILD
new file mode 100644
index 0000000000..1b60d6a60d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/BUILD
@@ -0,0 +1,328 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
+
+common_linkopts = tflite_linkopts() + select({
+ "//conditions:default": [],
+ "//tensorflow:android": [
+ "-pie",
+ "-llog",
+ ],
+})
+
+cc_library(
+ name = "utils",
+ srcs = ["utils.cc"],
+ hdrs = ["utils.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "utils_test",
+ srcs = ["utils_test.cc"],
+ args = [
+ "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)",
+ ],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":utils",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "run_tflite_model_op",
+ srcs = ["run_tflite_model_op.cc"],
+ copts = tflite_copts(),
+ deps = [
+ ":utils",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ ],
+ },
+ ),
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "android_required_build_flags",
+ srcs = ["android_required_build_flags.cc"],
+ copts = tflite_copts(),
+)
+
+tf_cc_test(
+ name = "run_tflite_model_op_test",
+ srcs = ["run_tflite_model_op_test.cc"],
+ args = [
+ "--test_model_file=$(location //tensorflow/contrib/lite:testdata/multi_add.bin)",
+ ],
+ data = ["//tensorflow/contrib/lite:testdata/multi_add.bin"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = [
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ":run_tflite_model_op",
+ ":android_required_build_flags",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "stage",
+ hdrs = ["stage.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/cc:scope",
+ ],
+)
+
+cc_library(
+ name = "file_reader_stage",
+ srcs = ["file_reader_stage.cc"],
+ hdrs = ["file_reader_stage.h"],
+ deps = [
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ],
+)
+
+tf_cc_test(
+ name = "file_reader_stage_test",
+ srcs = ["file_reader_stage_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":file_reader_stage",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_whole_file_read_ops",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "run_tflite_model_stage",
+ srcs = ["run_tflite_model_stage.cc"],
+ hdrs = ["run_tflite_model_stage.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":run_tflite_model_op",
+ ":stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ],
+)
+
+cc_library(
+ name = "accuracy_eval_stage",
+ hdrs = ["accuracy_eval_stage.h"],
+ copts = tflite_copts(),
+ deps = [
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "eval_pipeline",
+ srcs = ["eval_pipeline.cc"],
+ hdrs = ["eval_pipeline.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":accuracy_eval_stage",
+ ":stage",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "eval_pipeline_test",
+ srcs = ["eval_pipeline_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":eval_pipeline",
+ "//tensorflow/cc:cc_ops",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "eval_pipeline_builder",
+ srcs = ["eval_pipeline_builder.cc"],
+ hdrs = ["eval_pipeline_builder.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":eval_pipeline",
+ ":accuracy_eval_stage",
+ ":stage",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/cc:cc_ops",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "eval_pipeline_builder_test",
+ srcs = ["eval_pipeline_builder_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":eval_pipeline_builder",
+ "//tensorflow/cc:cc_ops",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:ops",
+ "//tensorflow/core:tensorflow",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "csv_writer",
+ hdrs = ["csv_writer.h"],
+ copts = tflite_copts(),
+ deps = select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/accuracy/README.md b/tensorflow/contrib/lite/tools/accuracy/README.md
new file mode 100644
index 0000000000..8100cd1e8c
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/README.md
@@ -0,0 +1,38 @@
+## TFLite accuracy library.
+
+This library provides evaluation pipelines that can be used to evaluate
+accuracy and other metrics of a model. The resulting binary can be run on
+a desktop or on a mobile device.
+
+## Usage
+The tool provides an evaluation pipeline with different stages. Each
+stage outputs a Tensorflow graph.
+A sample usage is shown below.
+
+```C++
+// First build the pipeline.
+EvalPipelineBuilder builder;
+std::unique_ptr<EvalPipeline> eval_pipeline;
+auto status = builder.WithInput("pipeline_input", DT_FLOAT)
+ .WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+TF_CHECK_OK(status);
+
+// Now run the pipeline with inputs and outputs.
+std::unique_ptr<Session> session(NewSession(SessionOptions()));
+TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+Tensor input = ... read input for the model ...
+Tensor ground_truth = ... read ground truth for the model ...
+TF_CHECK_OK(eval_pipeline.Run(input1, ground_truth1));
+```
+For further examples, check the usage in [imagenet accuracy evaluation binary](ilsvrc/imagenet_model_evaluator.cc)
+
+## Measuring accuracy of published models.
+
+### ILSVRC (Imagenet Large Scale Visual Recognition Contest) classification task
+For measuring accuracy for [ILSVRC 2012 image classification task](http://www.image-net.org/challenges/LSVRC/2012/), the binary can be built
+using these
+[instructions.](ilsvrc/)
diff --git a/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h
new file mode 100644
index 0000000000..9cb843729a
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Base class for evaluation stage that evaluates the accuracy of the model.
+// This stage calculates the accuracy metrics given the model outputs and
+// expected ground truth.
+class AccuracyEval {
+ public:
+ AccuracyEval() = default;
+ AccuracyEval(const AccuracyEval&) = delete;
+ AccuracyEval& operator=(const AccuracyEval&) = delete;
+
+ AccuracyEval(const AccuracyEval&&) = delete;
+ AccuracyEval& operator=(const AccuracyEval&&) = delete;
+
+ virtual ~AccuracyEval() = default;
+
+ // Evaluates the accuracy of the model for given `model_outputs` and the
+ // `ground truth`.
+ // Derived classes can do additional book keeping, calculate aggregrate
+ // statistics etc for the given model.
+ virtual Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) = 0;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ACCURACY_EVAL_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc
new file mode 100644
index 0000000000..7fa8986716
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/android_required_build_flags.cc
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tensorflow on Android requires selective registration to be enabled in order
+// for certain types (e.g. DT_UINT8) to work.
+// Checks below ensure that for Android build, the right flags are passed to
+// the compiler.
+
+#if defined(__ANDROID__) && (!defined(__ANDROID_TYPES_FULL__) || \
+ !defined(SUPPORT_SELECTIVE_REGISTRATION))
+#error \
+ "Binary needs custom kernel support. For enabling custom kernels on " \
+ "Android, please pass -D__ANDROID_TYPES_FULL__ && " \
+ "-DSUPPORT_SELECTIVE_REGISTRATION for including the kernel in the binary."
+#endif
diff --git a/tensorflow/contrib/lite/tools/accuracy/csv_writer.h b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h
new file mode 100644
index 0000000000..806b0d9418
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/csv_writer.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
+
+#include <fstream>
+#include <vector>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+// A simple CSV writer that writes values of same type for fixed number of
+// columns. This supports a very limited set of CSV spec and doesn't do any
+// escaping.
+// Usage:
+// std::ofstream * output_stream = ...
+// CSVWriter writer({"column1", "column2"}, output_stream);
+// writer.WriteRow({4, 5});
+// writer.Flush(); // flush results immediately.
+class CSVWriter {
+ public:
+ CSVWriter(const std::vector<string>& columns, std::ofstream* output_stream)
+ : num_columns_(columns.size()), output_stream_(output_stream) {
+ TF_CHECK_OK(WriteRow(columns, output_stream_));
+ }
+
+ template <typename T>
+ Status WriteRow(const std::vector<T>& values) {
+ if (values.size() != num_columns_) {
+ return errors::InvalidArgument("Invalid size for row:", values.size(),
+ " expected: ", num_columns_);
+ }
+ return WriteRow(values, output_stream_);
+ }
+
+ void Flush() { output_stream_->flush(); }
+
+ ~CSVWriter() { output_stream_->flush(); }
+
+ private:
+ template <typename T>
+ static Status WriteRow(const std::vector<T>& values,
+ std::ofstream* output_stream) {
+ bool first = true;
+ for (const auto& v : values) {
+ if (!first) {
+ (*output_stream) << ", ";
+ } else {
+ first = false;
+ }
+ (*output_stream) << v;
+ }
+ (*output_stream) << "\n";
+ if (!output_stream->good()) {
+ return errors::Internal("Writing to stream failed.");
+ }
+ return Status::OK();
+ }
+ const size_t num_columns_;
+ std::ofstream* output_stream_;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_CSV_WRITER_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc
new file mode 100644
index 0000000000..a03aba6a26
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+
+namespace tensorflow {
+namespace metrics {
+
+Status EvalPipeline::AttachSession(std::unique_ptr<Session> session) {
+ session_ = std::move(session);
+ TF_RETURN_IF_ERROR(session_->Create(model_graph_));
+ return Status::OK();
+}
+
+Status EvalPipeline::Run(const Tensor& input, const Tensor& ground_truth) {
+ if (session_ == nullptr) {
+ return errors::Internal("No session is associated with the graph.");
+ }
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(session_->Run({{params_.model_input_node_name, input}},
+ {params_.model_output_node_name}, {},
+ &outputs));
+ TF_RETURN_IF_ERROR(eval_->ComputeEval(outputs, ground_truth));
+ return Status::OK();
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h
new file mode 100644
index 0000000000..c9cfc86613
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h
@@ -0,0 +1,87 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Pipeline for evaluating a model.
+// Runs the graph and passes the output of graph to
+// the provided instance of AccuracyEval.
+// Example usage:
+// AccuracyEval *eval;
+// GraphDef graph_def;
+// ... populate graph_def...
+//
+// EvalPipeline eval_pipeline(&graph_def,
+// {.model_input_node_name = "model_input",
+// .model_output_node_name = "model_output"},
+// eval);
+// std::unique_ptr<Session> session(NewSession(SessionOptions()));
+// TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+// Tensor input = ... read input for the model ...
+// Tensor ground_truth = ... read ground truth for the model ...
+// TF_CHECK_OK(eval_pipeline.Run(input, ground_truth));
+//
+class EvalPipeline {
+ public:
+ struct Params {
+ string model_input_node_name;
+ string model_output_node_name;
+ };
+
+ // Creates a new `EvalPipeline` object. The ownership of the `accuracy_eval`
+ // is retained by the caller. Lifetime of `accuracy_eval` instance should
+ // be longer than the lifetime of this instance of pipeline.
+ EvalPipeline(const GraphDef& graph, const Params& params,
+ AccuracyEval* accuracy_eval)
+ : model_graph_(graph),
+ params_(params),
+ eval_(accuracy_eval),
+ session_(nullptr) {}
+
+ EvalPipeline(const EvalPipeline&) = delete;
+ EvalPipeline& operator=(const EvalPipeline&) = delete;
+
+ EvalPipeline(const EvalPipeline&&) = delete;
+ EvalPipeline& operator=(const EvalPipeline&&) = delete;
+
+ // Attaches the given session to this instance of pipeline.
+ // The provided session object will be reused for subsequent calls to
+ // EvalPipeline::Run.
+ Status AttachSession(std::unique_ptr<Session> session);
+
+ // Runs the model by feeding `input` and then passes the output of the model
+ // along with provided `ground_truth` to the AccuracyEval instance by calling
+ // AccuracyEval::ComputeEval.
+ Status Run(const Tensor& input, const Tensor& ground_truth);
+
+ private:
+ GraphDef model_graph_;
+ Params params_;
+ AccuracyEval* eval_;
+ std::unique_ptr<Session> session_;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc
new file mode 100644
index 0000000000..2e16437e15
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.cc
@@ -0,0 +1,100 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+
+#include "absl/memory/memory.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithInputStage(Stage* input_stage) {
+ input_stage_ = input_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithPreprocessingStage(
+ Stage* preprocessing_stage) {
+ preprocessing_stage_ = preprocessing_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithRunModelStage(
+ Stage* run_model_stage) {
+ run_model_stage_ = run_model_stage;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithAccuracyEval(
+ AccuracyEval* accuracy_eval) {
+ accuracy_eval_ = accuracy_eval;
+ return *this;
+}
+
+EvalPipelineBuilder& EvalPipelineBuilder::WithInput(const string& input_name,
+ DataType input_type) {
+ input_name_ = input_name;
+ input_type_ = input_type;
+ return *this;
+}
+
+Status EvalPipelineBuilder::Build(
+ const Scope& scope, std::unique_ptr<EvalPipeline>* eval_pipeline) {
+ if (input_stage_ == nullptr) {
+ return errors::InvalidArgument("Input stage is null.");
+ }
+ if (preprocessing_stage_ == nullptr) {
+ return errors::InvalidArgument("Preprocessing stage is null.");
+ }
+ if (run_model_stage_ == nullptr) {
+ return errors::InvalidArgument("Run model stage is null.");
+ }
+ if (accuracy_eval_ == nullptr) {
+ return errors::InvalidArgument("accuracy_eval is null.");
+ }
+ if (input_name_.empty()) {
+ return errors::InvalidArgument("input name is not set.");
+ }
+ if (input_type_ == DT_INVALID) {
+ return errors::InvalidArgument("input type is not set.");
+ }
+
+ auto input_placeholder =
+ ops::Placeholder(scope.WithOpName(input_name_), input_type_);
+ TF_RETURN_IF_ERROR(scope.status());
+
+ input_stage_->AddToGraph(scope, input_placeholder);
+ TF_RETURN_IF_ERROR(scope.status());
+
+ preprocessing_stage_->AddToGraph(scope, input_stage_->Output());
+ TF_RETURN_IF_ERROR(scope.status());
+
+ run_model_stage_->AddToGraph(scope, preprocessing_stage_->Output());
+ TF_RETURN_IF_ERROR(scope.status());
+
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = input_name_;
+ params.model_output_node_name = run_model_stage_->output_name();
+ *eval_pipeline =
+ absl::make_unique<EvalPipeline>(graph_def, params, accuracy_eval_);
+
+ return Status::OK();
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h
new file mode 100644
index 0000000000..692db022f8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A builder to simplify construction of an `EvalPipeline` instance.
+// The `Build` method creates an |EvalPipeline| with the following structure:
+// |input| -> |input_stage|
+// |--> |preprocessing_stage|
+// |--> |run_model_stage| -> |accuracy_eval_stage|.
+// The stages are chained in the order shown above. Any missing stage results in
+// an error. The ownership of the stage object is retained by the caller. Stage
+// objects need to exist until the |Build| method is called.
+//
+// Currently only single inputs are supported.
+//
+// Example Usage:
+// EvalPipelineBuilder builder;
+// std::unique_ptr<EvalPipeline> eval_pipeline;
+// auto status = builder.WithInput("pipeline_input", DT_FLOAT)
+// .WithInputStage(&input_stage)
+// .WithRunModelStage(&run_model_stage)
+// .WithPreprocessingStage(&preprocess_stage)
+// .WithAccuracyEval(&eval)
+// .Build(scope, &eval_pipeline);
+// TF_CHECK_OK(status);
+class EvalPipelineBuilder {
+ public:
+ EvalPipelineBuilder() = default;
+ EvalPipelineBuilder(const EvalPipelineBuilder&) = delete;
+ EvalPipeline& operator=(const EvalPipelineBuilder&) = delete;
+
+ EvalPipelineBuilder(const EvalPipelineBuilder&&) = delete;
+ EvalPipeline& operator=(const EvalPipelineBuilder&&) = delete;
+
+ // Sets the input stage for the pipeline.
+ // Input stage converts the input, say filename into appropriate format
+ // that can be consumed by the preprocessing stage.
+ EvalPipelineBuilder& WithInputStage(Stage* input_stage);
+
+ // Sets the preprocessing stage for the pipeline.
+ // Preprocessing stage converts the input into a format that can be used to
+ // run the model.
+ EvalPipelineBuilder& WithPreprocessingStage(Stage* preprocessing_stage);
+
+ // Sets the run model stage for the pipeline.
+ // This stage receives the preprocessing input and output of this stage is
+ // fed to the accuracy eval stage.
+ EvalPipelineBuilder& WithRunModelStage(Stage* run_model_stage);
+
+ // Sets the accuracy eval for the pipeline.
+ // Results of evaluating the pipeline are fed to the `accuracy_eval` instance.
+ EvalPipelineBuilder& WithAccuracyEval(AccuracyEval* accuracy_eval);
+
+ // Sets the name and type of input for the pipeline.
+ // TODO(shashishekhar): Support multiple inputs for the pipeline, use a vector
+ // here.
+ EvalPipelineBuilder& WithInput(const string& input_name, DataType input_type);
+
+ // Builds the pipeline and assigns the pipeline to `eval_pipeline`.
+ // If the pipeline creation fails `eval_pipeline` is untouched.
+ Status Build(const Scope& scope,
+ std::unique_ptr<EvalPipeline>* eval_pipeline);
+
+ private:
+ Stage* input_stage_ = nullptr;
+ Stage* preprocessing_stage_ = nullptr;
+ Stage* run_model_stage_ = nullptr;
+ AccuracyEval* accuracy_eval_ = nullptr;
+ string input_name_;
+ DataType input_type_ = DT_INVALID;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_EVAL_PIPELINE_BUILDER_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc
new file mode 100644
index 0000000000..2d41929b79
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder_test.cc
@@ -0,0 +1,229 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+class IdentityStage : public Stage {
+ public:
+ IdentityStage(const string& name, const string& output)
+ : name_(name), output_(output) {}
+
+ void AddToGraph(const Scope& scope, const Input& input) override {
+ called_count_++;
+ inputs_.push_back(input.node()->name());
+ stage_output_ = ops::Identity(scope.WithOpName(output_), input);
+ }
+
+ string name() const override { return name_; }
+ string output_name() const override { return output_; }
+
+ int times_called() const { return called_count_; }
+
+ const std::vector<string> input_params() { return inputs_; }
+
+ private:
+ string name_;
+ string output_;
+ int called_count_ = 0;
+ std::vector<string> inputs_;
+};
+
+class FailingStage : public Stage {
+ public:
+ FailingStage(const string& name, const string& output)
+ : name_(name), output_(output) {}
+
+ void AddToGraph(const Scope& scope, const Input& input) override {
+ called_count_++;
+ scope.UpdateStatus(errors::Internal("Stage failed:", name_));
+ }
+
+ string name() const override { return name_; }
+ string output_name() const override { return output_; }
+
+ int times_called() const { return called_count_; }
+
+ private:
+ string name_;
+ string output_;
+ int called_count_ = 0;
+};
+
+class SimpleAccuracyEval : public AccuracyEval {
+ public:
+ SimpleAccuracyEval() {}
+
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override {
+ return Status::OK();
+ }
+};
+
+TEST(EvalPipelineBuilder, MissingPipelineStages) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status =
+ builder.WithInputStage(&input_stage).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status =
+ builder.WithRunModelStage(&run_model_stage).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status = builder.WithPreprocessingStage(&preprocess_stage)
+ .Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status =
+ builder.WithInput(pipeline_input, DT_FLOAT).Build(scope, &eval_pipeline);
+ EXPECT_FALSE(status.ok());
+ EXPECT_FALSE(eval_pipeline);
+
+ status = builder.WithAccuracyEval(&eval).Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+ EXPECT_TRUE(eval_pipeline);
+}
+
+TEST(EvalPipeline, InputStageFailure) {
+ FailingStage input_stage("input_stage", "input_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(scope.status().ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(0, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PreprocessingFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ FailingStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(0, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, GraphEvalFailure) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ FailingStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+
+ EXPECT_FALSE(status.ok());
+ // None of the other stages would have been called.
+ EXPECT_EQ(1, input_stage.times_called());
+ EXPECT_EQ(1, preprocess_stage.times_called());
+ EXPECT_EQ(1, run_model_stage.times_called());
+}
+
+TEST(EvalPipeline, PipelineHasCorrectSequence) {
+ IdentityStage input_stage("input_stage", "input_stage_out");
+ IdentityStage preprocess_stage("preprocess_stage", "preprocess_stage_out");
+ IdentityStage run_model_stage("run_model", "run_model_out");
+ const string pipeline_input = "pipeline_input";
+
+ SimpleAccuracyEval eval;
+
+ Scope scope = Scope::NewRootScope();
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+ EvalPipelineBuilder builder;
+ auto status = builder.WithInputStage(&input_stage)
+ .WithRunModelStage(&run_model_stage)
+ .WithPreprocessingStage(&preprocess_stage)
+ .WithInput(pipeline_input, DT_FLOAT)
+ .WithAccuracyEval(&eval)
+ .Build(scope, &eval_pipeline);
+ TF_CHECK_OK(status);
+
+ ASSERT_EQ(1, input_stage.times_called());
+ ASSERT_EQ(1, run_model_stage.times_called());
+ ASSERT_EQ(1, preprocess_stage.times_called());
+
+ EXPECT_EQ(pipeline_input, input_stage.input_params()[0]);
+ EXPECT_EQ(input_stage.output_name(), preprocess_stage.input_params()[0]);
+ EXPECT_EQ(preprocess_stage.output_name(), run_model_stage.input_params()[0]);
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc
new file mode 100644
index 0000000000..ea0f6e19df
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/eval_pipeline_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include <gtest/gtest.h>
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+Tensor CreateFloatTensor(float value) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = value;
+ return tensor;
+}
+
+class NoOpAccuracyEval : public AccuracyEval {
+ public:
+ explicit NoOpAccuracyEval(const Status& status_to_return)
+ : status_to_return_(status_to_return) {}
+
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override {
+ model_outputs_ = model_outputs;
+ ground_truth_ = ground_truth;
+ was_called_ = true;
+ return status_to_return_;
+ }
+
+ bool WasCalled() { return was_called_; }
+ std::vector<Tensor> model_outputs() { return model_outputs_; }
+ Tensor ground_truth() { return ground_truth_; }
+
+ private:
+ std::vector<Tensor> model_outputs_;
+ Tensor ground_truth_;
+ Status status_to_return_;
+ bool was_called_ = false;
+};
+
+TEST(EvalPipeline, AccuracyEvalIsCalled) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(Status::OK());
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+ TF_CHECK_OK(eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27)));
+
+ EXPECT_TRUE(accuracy_eval.WasCalled());
+ auto outputs = accuracy_eval.model_outputs();
+ ASSERT_EQ(1, outputs.size());
+ EXPECT_EQ(6.0f, outputs[0].scalar<float>()());
+ // Ground truth is unchanged.
+ EXPECT_EQ(27, accuracy_eval.ground_truth().scalar<float>()());
+}
+
+TEST(EvalPipeline, EvalIsNotCalledOnGraphRunFailure) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(Status::OK());
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+
+ // Pass a string tensor instead of a float tensor.
+ Tensor string_tensor(DT_STRING, TensorShape{});
+ auto status = eval_pipeline.Run(string_tensor, CreateFloatTensor(27));
+ EXPECT_FALSE(accuracy_eval.WasCalled());
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(EvalPipeline, AccuracyEvalFailureResultsInFailure) {
+ Scope scope = Scope::NewRootScope();
+ // A graph that adds 1 to input.
+ auto input = ops::Placeholder(scope.WithOpName("input"), DT_FLOAT);
+ auto add_node = ops::Add(scope.WithOpName("output"), input, 1.0f);
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ EvalPipeline::Params params;
+ params.model_input_node_name = "input";
+ params.model_output_node_name = "output";
+ NoOpAccuracyEval accuracy_eval(errors::Internal("accuracy_fail"));
+
+ EvalPipeline eval_pipeline(graph_def, params, &accuracy_eval);
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(eval_pipeline.AttachSession(std::move(session)));
+ auto status = eval_pipeline.Run(CreateFloatTensor(5), CreateFloatTensor(27));
+
+ EXPECT_TRUE(accuracy_eval.WasCalled());
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc
new file mode 100644
index 0000000000..61bed369f8
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.cc
@@ -0,0 +1,29 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+void FileReaderStage::AddToGraph(const Scope& scope, const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+ this->stage_output_ = ops::ReadFile(s.WithOpName(output_name()), input);
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h
new file mode 100644
index 0000000000..18db5837c1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+// A stage for reading a file into |string|.
+// Inputs: a string tensor: |file_name|.
+// Outputs: a string tensor: contents of |file_name|.
+class FileReaderStage : public Stage {
+ public:
+ string name() const override { return "stage_filereader"; }
+ string output_name() const override { return "stage_filereader_output"; }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+};
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_FILE_READER_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc
new file mode 100644
index 0000000000..a75f99187d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/file_reader_stage_test.cc
@@ -0,0 +1,110 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cstdio>
+#include <fstream>
+#include <memory>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+class TempFile {
+ public:
+ TempFile() {
+ string file_path;
+ if (Env::Default()->LocalTempFilename(&file_path)) {
+ file_path_ = file_path;
+ created_ = true;
+ }
+ }
+
+ string filepath() { return file_path_; }
+ bool CreateFileWithContents(const std::string& contents) {
+ if (!created_) {
+ return false;
+ }
+ std::fstream file(file_path_, std::ios_base::out);
+ if (file) {
+ file << contents;
+ }
+ return file.good();
+ }
+
+ ~TempFile() {
+ if (created_) {
+ std::remove(file_path_.c_str());
+ }
+ }
+
+ private:
+ bool created_ = false;
+ string file_path_;
+};
+
+TEST(FileReaderStageTest, FileIsRead) {
+ TempFile file;
+ const string kFileContents = "Hello world.";
+ ASSERT_TRUE(file.CreateFileWithContents(kFileContents));
+ Scope scope = Scope::NewRootScope();
+ FileReaderStage reader_stage;
+ reader_stage.AddToGraph(scope, file.filepath());
+ TF_CHECK_OK(scope.status());
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {reader_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ string contents = outputs[0].scalar<string>()();
+ EXPECT_EQ(kFileContents, contents);
+}
+
+TEST(FileReaderStageTest, InvalidFile) {
+ Scope scope = Scope::NewRootScope();
+ FileReaderStage reader_stage;
+ reader_stage.AddToGraph(scope, string("non_existent_file"));
+ TF_CHECK_OK(scope.status());
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {reader_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ EXPECT_FALSE(run_status.ok());
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
new file mode 100644
index 0000000000..98e2835b2e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/BUILD
@@ -0,0 +1,182 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "tflite_linkopts")
+load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite")
+
+common_linkopts = tflite_linkopts() + select({
+ "//conditions:default": [],
+ "//tensorflow:android": [
+ "-pie",
+ "-llog",
+ ],
+})
+
+cc_library(
+ name = "inception_preprocessing",
+ srcs = ["inception_preprocessing.cc"],
+ hdrs = ["inception_preprocessing.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "//tensorflow/contrib/lite/tools/accuracy:stage",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_tensorflow_image_op",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:ops",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "inception_preprocessing_test",
+ srcs = ["inception_preprocessing_test.cc"],
+ args = [
+ "--test_image=$(location :testdata/grace_hopper.jpg)",
+ ],
+ data = [":testdata/grace_hopper.jpg"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = [
+ "no_oss", # b/114307765
+ "tflite_not_portable_android",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":inception_preprocessing",
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "imagenet_topk_eval",
+ srcs = ["imagenet_topk_eval.cc"],
+ hdrs = ["imagenet_topk_eval.h"],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/tools/accuracy:accuracy_eval_stage",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ },
+ ),
+)
+
+tf_cc_test(
+ name = "imagenet_topk_eval_test",
+ srcs = ["imagenet_topk_eval_test.cc"],
+ linkopts = common_linkopts,
+ linkstatic = 1,
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":imagenet_topk_eval",
+ "@com_google_googletest//:gtest",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core:android_tensorflow_test_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:framework",
+ ],
+ },
+ ),
+)
+
+cc_library(
+ name = "imagenet_model_evaluator",
+ srcs = ["imagenet_model_evaluator.cc"],
+ hdrs = ["imagenet_model_evaluator.h"],
+ copts = tflite_copts(),
+ deps = [
+ ":imagenet_topk_eval",
+ ":inception_preprocessing",
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline",
+ "//tensorflow/contrib/lite/tools/accuracy:eval_pipeline_builder",
+ "//tensorflow/contrib/lite/tools/accuracy:file_reader_stage",
+ "//tensorflow/contrib/lite/tools/accuracy:run_tflite_model_stage",
+ "//tensorflow/contrib/lite/tools/accuracy:utils",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:scope",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ "//tensorflow/core/kernels:android_whole_file_read_ops",
+ "//tensorflow/core/kernels:android_tensorflow_image_op",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:core_cpu",
+ ],
+ },
+ ),
+)
+
+tf_cc_binary(
+ name = "imagenet_accuracy_eval",
+ srcs = ["imagenet_accuracy_eval.cc"],
+ copts = tflite_copts(),
+ linkopts = common_linkopts,
+ deps = [
+ ":imagenet_model_evaluator",
+ ":imagenet_topk_eval",
+ "@com_google_absl//absl/memory",
+ "//tensorflow/contrib/lite/tools/accuracy:android_required_build_flags",
+ "//tensorflow/contrib/lite/tools/accuracy:csv_writer",
+ ] + select(
+ {
+ "//tensorflow:android": [
+ "//tensorflow/core:android_tensorflow_lib",
+ ],
+ "//conditions:default": [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:framework_internal",
+ ],
+ },
+ ),
+)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
new file mode 100644
index 0000000000..362ea3ac34
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/README.md
@@ -0,0 +1,146 @@
+## Accuracy evaluation for ILSVRC 2012 (Imagenet Large Scale Visual Recognition Challenge) image classification task
+
+This binary can evaluate the accuracy of TFLite models trained for the [ILSVRC 2012 image classification task]
+(http://www.image-net.org/challenges/LSVRC/2012/).
+The binary takes the path to validation images and labels as inputs. It outputs the accuracy after running the TFLite model on the validation sets.
+
+To run the binary download the ILSVRC 2012 devkit [see instructions](#downloading-ilsvrc) and run the [`generate_validation_ground_truth` script](#ground-truth-label-generation) to generate the ground truth labels.
+
+## Parameters
+The binary takes the following parameters:
+
+* `model_file` : `string` \
+ Path to the TFlite model file.
+
+* `ground_truth_images_path`: `string` \
+ The path to the directory containing ground truth images.
+
+* `ground_truth_labels`: `string` \
+ Path to ground truth labels file. This file should contain the same number of labels as the number images in the ground truth directory. The labels are assumed to be in the
+ same order as the sorted filename of images. See [ground truth label generation](#ground-truth-label-generation)
+ section for more information about how to generate labels for images.
+
+* `model_output_labels`: `string` \
+ Path to the file containing labels, that is used to interpret the output of
+ the model. E.g. in case of mobilenets, this is the path to
+ `mobilenet_labels.txt` where each label is in the same order as the output
+ 1001 dimension tensor.
+
+* `output_path`: `string` \
+ This is the path to the output file. The output is a CSV file that has top-10 accuracies in each row. Each line of output file is the cumulative accuracy after processing images in a sorted order. So first line is accuracy after processing the first image, second line is accuracy after procesing first two images. The last line of the file is accuracy after processing the entire validation set.
+
+and the following optional parameters:
+
+* `blacklist_file_path`: `string` \
+ Path to blacklist file. This file contains the indices of images that are blacklisted for evaluation. 1762 images are blacklisted in ILSVRC dataset. For details please refer to readme.txt of ILSVRC2014 devkit.
+
+* `num_images`: `int` (default=0) \
+ The number of images to process, if 0, all images in the directory are processed otherwise only num_images will be processed.
+
+* `num_threads`: `int` (default=4) \
+ The number of threads to use for evaluation.
+
+
+## Downloading ILSVRC
+In order to use this tool to run evaluation on the full 50K ImageNet dataset,
+download the data set from http://image-net.org/request.
+
+## Ground truth label generation
+The ILSVRC 2012 devkit `validation_ground_truth.txt` contains IDs that correspond to synset of the image.
+The accuracy binary however expects the ground truth labels to contain the actual name of
+category instead of synset ids. A conversion script has been provided to convert the validation ground truth to
+category labels. The `validation_ground_truth.txt` can be converted by the following steps:
+
+```
+ILSVRC_2012_DEVKIT_DIR=[set to path to ILSVRC 2012 devkit]
+VALIDATION_LABELS=[set to path to output]
+
+python generate_validation_labels.py -- \
+--ilsvrc_devkit_dir=${ILSVRC_2012_DEVKIT_DIR} \
+--validation_labels_output=${VALIDATION_LABELS}
+```
+
+## Running the binary
+
+### On Android
+
+(0) Refer to https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android for configuring NDK and SDK.
+
+(1) Build using the following command:
+
+```
+bazel build -c opt \
+ --config=android_arm \
+ --config=monolithic \
+ --cxxopt='--std=c++11' \
+ --copt=-D__ANDROID_TYPES_FULL__ \
+ --copt=-DSUPPORT_SELECTIVE_REGISTRATION \
+ //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval
+```
+
+(2) Connect your phone. Push the binary to your phone with adb push
+ (make the directory if required):
+
+```
+adb push bazel-bin/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval /data/local/tmp
+```
+
+(3) Make the binary executable.
+
+```
+adb shell chmod +x /data/local/tmp/imagenet_accuracy_eval
+```
+
+(4) Push the TFLite model that you need to test. For example:
+
+```
+adb push mobilenet_quant_v1_224.tflite /data/local/tmp
+```
+
+(5) Push the imagenet images to device, make sure device has sufficient storage available before pushing the dataset:
+
+```
+adb shell mkdir /data/local/tmp/ilsvrc_images && \
+adb push ${IMAGENET_IMAGES_DIR} /data/local/tmp/ilsvrc_images
+```
+
+(6) Push the generated validation ground labels to device.
+
+```
+adb push ${VALIDATION_LABELS} /data/local/tmp/ilsvrc_validation_labels.txt
+```
+
+(7) Push the model labels text file to device.
+
+```
+adb push ${MODEL_LABELS_TXT} /data/local/tmp/model_output_labels.txt
+```
+
+(8) Run the binary.
+
+```
+adb shell /data/local/tmp/imagenet_accuracy_eval \
+ --model_file=/data/local/tmp/mobilenet_quant_v1_224.tflite \
+ --ground_truth_images_path=/data/local/tmp/ilsvrc_images \
+ --ground_truth_labels=/data/local/tmp/ilsvrc_validation_labels.txt \
+ --model_output_labels=/data/local/tmp/model_output_labels.txt \
+ --output_file_path=/data/local/tmp/accuracy_output.txt \
+ --num_images=0 # Run on all images.
+```
+
+### On Desktop
+
+(1) Build and run using the following command:
+
+```
+bazel run -c opt \
+ --cxxopt='--std=c++11' \
+ -- \
+ //tensorflow/contrib/lite/tools/accuracy/ilsvrc:imagenet_accuracy_eval \
+ --model_file=mobilenet_quant_v1_224.tflite \
+ --ground_truth_images_path=${IMAGENET_IMAGES_DIR} \
+ --ground_truth_labels=${VALIDATION_LABELS} \
+ --model_output_labels=${MODEL_LABELS_TXT} \
+ --output_file_path=/tmp/accuracy_output.txt \
+ --num_images=0 # Run on all images.
+```
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
new file mode 100644
index 0000000000..b2f00e034e
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/clsloc_validation_blacklist.txt
@@ -0,0 +1,1762 @@
+36
+50
+56
+103
+127
+195
+199
+226
+230
+235
+251
+254
+288
+397
+485
+543
+556
+601
+605
+652
+653
+663
+666
+697
+699
+705
+745
+774
+815
+816
+845
+848
+951
+977
+1006
+1008
+1018
+1056
+1066
+1079
+1102
+1128
+1133
+1188
+1193
+1194
+1266
+1271
+1372
+1382
+1405
+1426
+1430
+1441
+1477
+1502
+1518
+1606
+1621
+1642
+1658
+1716
+1722
+1734
+1750
+1807
+1880
+1882
+1936
+1951
+1970
+1977
+1983
+2086
+2112
+2146
+2152
+2217
+2304
+2321
+2404
+2526
+2554
+2563
+2647
+2675
+2732
+2733
+2827
+2839
+2854
+2865
+2872
+2880
+2886
+2893
+2915
+2973
+2993
+3019
+3020
+3044
+3047
+3049
+3117
+3167
+3197
+3201
+3282
+3311
+3315
+3344
+3345
+3378
+3425
+3477
+3497
+3514
+3525
+3531
+3587
+3637
+3650
+3657
+3686
+3720
+3732
+3798
+3802
+3823
+3847
+3971
+4007
+4059
+4072
+4087
+4099
+4124
+4126
+4156
+4195
+4197
+4241
+4275
+4321
+4333
+4352
+4356
+4368
+4377
+4428
+4440
+4497
+4509
+4513
+4526
+4528
+4565
+4570
+4596
+4633
+4677
+4696
+4743
+4759
+4778
+4835
+4976
+5032
+5058
+5061
+5066
+5140
+5145
+5177
+5197
+5219
+5226
+5228
+5240
+5289
+5292
+5385
+5433
+5445
+5448
+5465
+5488
+5549
+5553
+5609
+5638
+5666
+5683
+5711
+5729
+5760
+5793
+5819
+5837
+5855
+5858
+5961
+5966
+6048
+6197
+6199
+6201
+6206
+6215
+6220
+6264
+6278
+6280
+6305
+6388
+6411
+6466
+6490
+6509
+6523
+6529
+6625
+6754
+6818
+6886
+6890
+6893
+6902
+6912
+6942
+7067
+7141
+7144
+7214
+7217
+7278
+7312
+7320
+7329
+7342
+7345
+7369
+7408
+7428
+7463
+7556
+7557
+7582
+7613
+7621
+7624
+7647
+7671
+7679
+7734
+7736
+7747
+7750
+7777
+7851
+7854
+7883
+7889
+7902
+7985
+7999
+8070
+8087
+8096
+8100
+8128
+8180
+8195
+8367
+8377
+8465
+8497
+8508
+8528
+8538
+8581
+8657
+8692
+8742
+8784
+8839
+8861
+8912
+8970
+8982
+8987
+9103
+9155
+9180
+9248
+9284
+9300
+9357
+9382
+9414
+9450
+9463
+9493
+9522
+9543
+9563
+9630
+9643
+9653
+9693
+9747
+9787
+9847
+9851
+9892
+9913
+9929
+9965
+10026
+10027
+10055
+10154
+10189
+10243
+10297
+10337
+10346
+10347
+10377
+10403
+10483
+10518
+10540
+10559
+10567
+10568
+10580
+10606
+10615
+10618
+10645
+10685
+10707
+10710
+10807
+10837
+10856
+10873
+10989
+11046
+11054
+11132
+11163
+11218
+11243
+11255
+11265
+11292
+11306
+11307
+11310
+11343
+11349
+11407
+11411
+11422
+11427
+11431
+11439
+11496
+11644
+11662
+11690
+11692
+11725
+11743
+11767
+11812
+11867
+11871
+11897
+11975
+12001
+12046
+12076
+12119
+12158
+12216
+12252
+12261
+12264
+12293
+12296
+12306
+12357
+12358
+12371
+12415
+12422
+12472
+12497
+12499
+12538
+12540
+12544
+12569
+12645
+12647
+12652
+12699
+12727
+12750
+12832
+12849
+12873
+12889
+12902
+12996
+13029
+13065
+13073
+13075
+13079
+13268
+13338
+13372
+13529
+13530
+13537
+13623
+13626
+13637
+13644
+13646
+13681
+13778
+13782
+13805
+13846
+13853
+13881
+13914
+13961
+13975
+13979
+14011
+14135
+14143
+14144
+14161
+14170
+14207
+14212
+14215
+14260
+14311
+14368
+14373
+14400
+14509
+14523
+14566
+14594
+14628
+14629
+14633
+14649
+14652
+14705
+14709
+14732
+14734
+14802
+14834
+14865
+14883
+14933
+14965
+15003
+15100
+15159
+15178
+15272
+15289
+15308
+15319
+15327
+15353
+15357
+15363
+15408
+15429
+15438
+15469
+15485
+15495
+15501
+15524
+15530
+15551
+15598
+15613
+15614
+15631
+15646
+15647
+15661
+15679
+15684
+15758
+15775
+15826
+15838
+15840
+15931
+15940
+15969
+15976
+16003
+16037
+16045
+16116
+16200
+16233
+16247
+16339
+16340
+16345
+16361
+16400
+16408
+16430
+16468
+16474
+16500
+16521
+16565
+16569
+16584
+16613
+16645
+16662
+16671
+16719
+16724
+16760
+16764
+16805
+16849
+16893
+16896
+16954
+16979
+17023
+17026
+17034
+17038
+17049
+17054
+17061
+17073
+17074
+17133
+17163
+17176
+17177
+17217
+17237
+17246
+17298
+17312
+17324
+17337
+17365
+17415
+17442
+17449
+17576
+17578
+17581
+17588
+17589
+17591
+17593
+17605
+17661
+17688
+17689
+17695
+17697
+17703
+17736
+17746
+17758
+17788
+17798
+17828
+17841
+17884
+17898
+17924
+17956
+17960
+18001
+18013
+18025
+18052
+18097
+18106
+18158
+18211
+18223
+18240
+18261
+18266
+18297
+18325
+18329
+18335
+18340
+18351
+18433
+18462
+18466
+18524
+18569
+18581
+18631
+18696
+18748
+18766
+18787
+18793
+18950
+18961
+19001
+19008
+19011
+19154
+19177
+19217
+19255
+19286
+19320
+19333
+19360
+19403
+19407
+19419
+19464
+19499
+19510
+19519
+19555
+19564
+19605
+19610
+19689
+19699
+19705
+19707
+19725
+19732
+19741
+19774
+19799
+19838
+19877
+19903
+19940
+19945
+19952
+19973
+19987
+20024
+20086
+20111
+20114
+20174
+20193
+20201
+20245
+20299
+20329
+20439
+20485
+20534
+20562
+20575
+20578
+20601
+20604
+20605
+20648
+20658
+20665
+20677
+20693
+20697
+20699
+20791
+20794
+20808
+20876
+20890
+20906
+20914
+20990
+21065
+21128
+21144
+21151
+21156
+21175
+21199
+21204
+21207
+21225
+21236
+21241
+21342
+21351
+21429
+21533
+21550
+21622
+21676
+21727
+21764
+21785
+21822
+21830
+21845
+21853
+21867
+21909
+21910
+21923
+21924
+21937
+21948
+21955
+21962
+22008
+22017
+22026
+22037
+22072
+22075
+22135
+22138
+22160
+22167
+22190
+22287
+22375
+22440
+22457
+22460
+22471
+22481
+22484
+22488
+22515
+22553
+22679
+22703
+22714
+22730
+22735
+22752
+22768
+22809
+22813
+22817
+22846
+22902
+22910
+22944
+22986
+23026
+23053
+23065
+23088
+23117
+23124
+23126
+23132
+23142
+23165
+23172
+23223
+23264
+23280
+23322
+23335
+23439
+23453
+23455
+23474
+23501
+23518
+23580
+23589
+23608
+23614
+23641
+23649
+23660
+23698
+23728
+23766
+23809
+23859
+23874
+23902
+23946
+24040
+24105
+24132
+24137
+24151
+24153
+24157
+24171
+24271
+24281
+24296
+24303
+24308
+24328
+24332
+24338
+24402
+24440
+24453
+24466
+24504
+24531
+24543
+24547
+24556
+24562
+24610
+24649
+24660
+24693
+24706
+24745
+24834
+24948
+24963
+25056
+25057
+25083
+25093
+25120
+25150
+25161
+25197
+25219
+25220
+25253
+25257
+25290
+25327
+25332
+25344
+25387
+25390
+25422
+25453
+25481
+25489
+25587
+25599
+25600
+25622
+25681
+25686
+25702
+25708
+25740
+25776
+25870
+25918
+25973
+25978
+25986
+25987
+26033
+26038
+26041
+26087
+26113
+26155
+26162
+26184
+26235
+26299
+26301
+26318
+26364
+26383
+26430
+26511
+26528
+26561
+26618
+26653
+26688
+26697
+26778
+26940
+26951
+27023
+27029
+27037
+27046
+27051
+27118
+27244
+27252
+27258
+27272
+27283
+27303
+27381
+27392
+27403
+27422
+27437
+27440
+27476
+27493
+27494
+27501
+27506
+27550
+27559
+27571
+27581
+27596
+27604
+27612
+27665
+27687
+27701
+27711
+27732
+27759
+27766
+27772
+27797
+27813
+27854
+27864
+27865
+27879
+27894
+27907
+27958
+27963
+27969
+28003
+28027
+28032
+28051
+28058
+28079
+28093
+28120
+28132
+28194
+28227
+28324
+28328
+28331
+28360
+28373
+28419
+28431
+28436
+28451
+28467
+28471
+28527
+28541
+28588
+28640
+28649
+28662
+28670
+28678
+28722
+28768
+28780
+28835
+28863
+28879
+28885
+28928
+28948
+28954
+28963
+28969
+29020
+29065
+29077
+29105
+29117
+29143
+29166
+29172
+29299
+29302
+29342
+29357
+29378
+29410
+29411
+29414
+29415
+29447
+29473
+29488
+29499
+29505
+29533
+29537
+29601
+29637
+29650
+29667
+29671
+29681
+29686
+29708
+29721
+29749
+29755
+29771
+29853
+29886
+29894
+29919
+29928
+29990
+30008
+30064
+30067
+30107
+30150
+30160
+30164
+30186
+30195
+30219
+30243
+30282
+30314
+30324
+30389
+30418
+30497
+30550
+30592
+30615
+30624
+30640
+30650
+30695
+30720
+30741
+30750
+30751
+30767
+30830
+30856
+30885
+30901
+30907
+30953
+30985
+31005
+31027
+31034
+31045
+31057
+31071
+31109
+31119
+31227
+31230
+31250
+31303
+31320
+31371
+31401
+31440
+31447
+31464
+31478
+31487
+31494
+31525
+31553
+31554
+31558
+31572
+31588
+31639
+31641
+31683
+31698
+31704
+31708
+31717
+31722
+31781
+31786
+31788
+31791
+31803
+31850
+31853
+31862
+31886
+31901
+31944
+32020
+32048
+32052
+32073
+32094
+32116
+32147
+32180
+32212
+32218
+32256
+32270
+32305
+32411
+32414
+32430
+32465
+32484
+32534
+32584
+32589
+32608
+32612
+32613
+32615
+32641
+32674
+32697
+32708
+32757
+32763
+32796
+32824
+32861
+32877
+32944
+32945
+32946
+32984
+33004
+33012
+33029
+33050
+33090
+33096
+33097
+33124
+33139
+33161
+33170
+33173
+33179
+33191
+33293
+33367
+33370
+33371
+33373
+33399
+33415
+33436
+33440
+33443
+33488
+33551
+33563
+33564
+33629
+33643
+33664
+33685
+33696
+33714
+33722
+33728
+33764
+33809
+33868
+33883
+33913
+33942
+33956
+33994
+34081
+34089
+34091
+34098
+34178
+34207
+34269
+34287
+34348
+34392
+34445
+34447
+34455
+34529
+34579
+34591
+34643
+34659
+34692
+34729
+34758
+34836
+34857
+34862
+34883
+34930
+34942
+34957
+34963
+35003
+35089
+35180
+35187
+35209
+35220
+35239
+35247
+35253
+35263
+35380
+35393
+35394
+35408
+35452
+35485
+35486
+35557
+35578
+35639
+35663
+35688
+35746
+35832
+35862
+35890
+35903
+35917
+35929
+35946
+35984
+36060
+36084
+36090
+36124
+36135
+36151
+36197
+36249
+36269
+36303
+36364
+36377
+36398
+36402
+36418
+36421
+36435
+36499
+36511
+36521
+36544
+36556
+36601
+36627
+36640
+36660
+36673
+36676
+36787
+36790
+36797
+36821
+36840
+36901
+36921
+36934
+37006
+37041
+37051
+37112
+37160
+37167
+37213
+37231
+37242
+37274
+37313
+37332
+37391
+37416
+37522
+37594
+37621
+37664
+37699
+37731
+37915
+37968
+38030
+38070
+38117
+38128
+38135
+38172
+38184
+38224
+38277
+38295
+38311
+38428
+38464
+38529
+38549
+38599
+38623
+38673
+38681
+38713
+38722
+38726
+38762
+38867
+38872
+38944
+38947
+39015
+39023
+39028
+39043
+39068
+39080
+39097
+39118
+39171
+39197
+39236
+39254
+39271
+39277
+39280
+39336
+39338
+39340
+39341
+39358
+39364
+39497
+39503
+39537
+39541
+39559
+39560
+39562
+39596
+39600
+39613
+39623
+39656
+39670
+39781
+39810
+39832
+39861
+39875
+39892
+39918
+39919
+40008
+40016
+40082
+40091
+40095
+40164
+40213
+40234
+40274
+40279
+40324
+40332
+40341
+40349
+40365
+40438
+40446
+40482
+40501
+40510
+40516
+40541
+40544
+40545
+40574
+40617
+40659
+40668
+40742
+40754
+40758
+40764
+40765
+40795
+40858
+40901
+40985
+40986
+41080
+41112
+41121
+41136
+41196
+41199
+41219
+41233
+41246
+41278
+41376
+41401
+41409
+41434
+41470
+41492
+41502
+41517
+41571
+41572
+41608
+41648
+41699
+41773
+41779
+41801
+41837
+41843
+41849
+41855
+41873
+41881
+41901
+41924
+41926
+41935
+41962
+42008
+42062
+42069
+42072
+42094
+42097
+42104
+42112
+42117
+42137
+42147
+42170
+42185
+42224
+42237
+42250
+42254
+42257
+42276
+42282
+42298
+42321
+42351
+42372
+42378
+42420
+42446
+42453
+42466
+42470
+42502
+42514
+42518
+42527
+42662
+42721
+42727
+42743
+42794
+42840
+42843
+42871
+42872
+42897
+42950
+42956
+42967
+42969
+42975
+42995
+43005
+43008
+43046
+43052
+43091
+43103
+43124
+43198
+43225
+43228
+43385
+43394
+43402
+43405
+43408
+43423
+43503
+43529
+43557
+43647
+43656
+43704
+43706
+43714
+43745
+43748
+43759
+43812
+43927
+43950
+43997
+43998
+44016
+44018
+44025
+44060
+44066
+44099
+44128
+44149
+44150
+44169
+44184
+44198
+44254
+44272
+44293
+44310
+44352
+44389
+44399
+44400
+44442
+44451
+44470
+44474
+44522
+44569
+44590
+44713
+44738
+44787
+44823
+44829
+44845
+44895
+44918
+44975
+45024
+45121
+45148
+45154
+45179
+45208
+45210
+45215
+45218
+45220
+45235
+45265
+45282
+45283
+45285
+45286
+45303
+45351
+45359
+45396
+45407
+45414
+45472
+45519
+45522
+45564
+45621
+45641
+45660
+45678
+45695
+45696
+45710
+45780
+45800
+45823
+45828
+45862
+45947
+45964
+46001
+46050
+46084
+46113
+46132
+46146
+46198
+46221
+46234
+46236
+46256
+46272
+46298
+46325
+46337
+46347
+46374
+46386
+46388
+46437
+46491
+46560
+46561
+46589
+46600
+46656
+46660
+46664
+46673
+46690
+46700
+46808
+46809
+46828
+46918
+46963
+46979
+46984
+47005
+47088
+47097
+47100
+47143
+47147
+47261
+47320
+47369
+47450
+47503
+47533
+47538
+47576
+47601
+47608
+47618
+47621
+47624
+47659
+47681
+47698
+47708
+47745
+47817
+47826
+47879
+47883
+47917
+47937
+47957
+48000
+48023
+48076
+48099
+48130
+48133
+48281
+48298
+48321
+48349
+48351
+48353
+48358
+48371
+48426
+48455
+48522
+48526
+48544
+48573
+48606
+48609
+48646
+48667
+48699
+48701
+48740
+48773
+48777
+48785
+48847
+48886
+48940
+48986
+49029
+49054
+49100
+49121
+49137
+49157
+49191
+49222
+49291
+49315
+49347
+49374
+49376
+49381
+49407
+49427
+49481
+49497
+49624
+49785
+49791
+49835
+49875
+49877
+49981
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py
new file mode 100644
index 0000000000..c32a41e50d
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/generate_validation_labels.py
@@ -0,0 +1,105 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tool to convert ILSVRC devkit validation ground truth to synset labels."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+from os import path
+import sys
+import scipy.io
+
+_SYNSET_ARRAYS_RELATIVE_PATH = 'data/meta.mat'
+_VALIDATION_FILE_RELATIVE_PATH = 'data/ILSVRC2012_validation_ground_truth.txt'
+
+
+def _synset_to_word(filepath):
+ """Returns synset to word dictionary by reading sysnset arrays."""
+ mat = scipy.io.loadmat(filepath)
+ entries = mat['synsets']
+ # These fields are listed in devkit readme.txt
+ fields = [
+ 'synset_id', 'WNID', 'words', 'gloss', 'num_children', 'children',
+ 'wordnet_height', 'num_train_images'
+ ]
+ synset_index = fields.index('synset_id')
+ words_index = fields.index('words')
+ synset_to_word = {}
+ for entry in entries:
+ entry = entry[0]
+ synset_id = int(entry[synset_index][0])
+ first_word = entry[words_index][0].split(',')[0]
+ synset_to_word[synset_id] = first_word
+ return synset_to_word
+
+
+def _validation_file_path(ilsvrc_dir):
+ return path.join(ilsvrc_dir, _VALIDATION_FILE_RELATIVE_PATH)
+
+
+def _synset_array_path(ilsvrc_dir):
+ return path.join(ilsvrc_dir, _SYNSET_ARRAYS_RELATIVE_PATH)
+
+
+def _generate_validation_labels(ilsvrc_dir, output_file):
+ synset_to_word = _synset_to_word(_synset_array_path(ilsvrc_dir))
+ with open(_validation_file_path(ilsvrc_dir), 'r') as synset_id_file, open(
+ output_file, 'w') as output:
+ for synset_id in synset_id_file:
+ synset_id = int(synset_id)
+ output.write('%s\n' % synset_to_word[synset_id])
+
+
+def _check_arguments(args):
+ if not args.validation_labels_output:
+ raise ValueError('Invalid path to output file.')
+ ilsvrc_dir = args.ilsvrc_devkit_dir
+ if not ilsvrc_dir or not path.isdir(ilsvrc_dir):
+ raise ValueError('Invalid path to ilsvrc_dir')
+ if not path.exists(_validation_file_path(ilsvrc_dir)):
+ raise ValueError('Invalid path to ilsvrc_dir, cannot find validation file.')
+ if not path.exists(_synset_array_path(ilsvrc_dir)):
+ raise ValueError(
+ 'Invalid path to ilsvrc_dir, cannot find synset arrays file.')
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Converts ILSVRC devkit validation_ground_truth.txt to synset'
+ ' labels file that can be used by the accuracy script.')
+ parser.add_argument(
+ '--validation_labels_output',
+ type=str,
+ help='Full path for outputting validation labels.')
+ parser.add_argument(
+ '--ilsvrc_devkit_dir',
+ type=str,
+ help='Full path to ILSVRC 2012 devikit directory.')
+ args = parser.parse_args()
+ try:
+ _check_arguments(args)
+ except ValueError as e:
+ parser.print_usage()
+ file_name = path.basename(sys.argv[0])
+ sys.stderr.write('{0}: error: {1}\n'.format(file_name, str(e)))
+ sys.exit(1)
+ _generate_validation_labels(args.ilsvrc_devkit_dir,
+ args.validation_labels_output)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
new file mode 100644
index 0000000000..2a8a2b9b59
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_accuracy_eval.cc
@@ -0,0 +1,165 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iomanip>
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/tools/accuracy/csv_writer.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
+std::vector<double> GetAccuracies(
+ const ImagenetTopKAccuracy::AccuracyStats& accuracy_stats) {
+ std::vector<double> results;
+ results.reserve(accuracy_stats.number_of_images);
+ if (accuracy_stats.number_of_images > 0) {
+ for (int n : accuracy_stats.topk_counts) {
+ double accuracy = 0;
+ if (accuracy_stats.number_of_images > 0) {
+ accuracy = (n * 100.0) / accuracy_stats.number_of_images;
+ }
+ results.push_back(accuracy);
+ }
+ }
+ return results;
+}
+
+} // namespace
+
+// Writes results to a CSV file.
+class ResultsWriter : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit ResultsWriter(std::unique_ptr<CSVWriter> writer)
+ : writer_(std::move(writer)) {}
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {}
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ std::unique_ptr<CSVWriter> writer_ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
+void ResultsWriter::OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
+ mutex_lock lock(mu_);
+ TF_CHECK_OK(writer_->WriteRow(GetAccuracies(stats)));
+ writer_->Flush();
+}
+
+// Logs results to standard output with `kLogDelayUs` microseconds.
+class ResultsLogger : public ImagenetModelEvaluator::Observer {
+ public:
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override;
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override;
+
+ private:
+ uint64_t last_logged_time_us_ GUARDED_BY(mu_) = 0;
+ int total_num_images_ GUARDED_BY(mu_);
+ static constexpr int kLogDelayUs = 500 * 1000;
+ mutex mu_;
+};
+
+void ResultsLogger::OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) {
+ int total_num_images = 0;
+ for (const auto& kv : shard_id_image_count_map) {
+ total_num_images += kv.second;
+ }
+ LOG(ERROR) << "Starting model evaluation: " << total_num_images;
+ mutex_lock lock(mu_);
+ total_num_images_ = total_num_images;
+}
+
+void ResultsLogger::OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) {
+ auto now_us = Env::Default()->NowMicros();
+ int num_evaluated = stats.number_of_images;
+ mutex_lock lock(mu_);
+ if ((now_us - last_logged_time_us_) >= kLogDelayUs) {
+ last_logged_time_us_ = now_us;
+ double current_percent = num_evaluated * 100.0 / total_num_images_;
+ LOG(ERROR) << "Evaluated " << num_evaluated << "/" << total_num_images_
+ << " images, " << std::setprecision(2) << std::fixed
+ << current_percent << "%";
+ }
+}
+
+int Main(int argc, char* argv[]) {
+ // TODO(shashishekhar): Make this binary configurable and model
+ // agnostic.
+ string output_file_path;
+ int num_threads = 4;
+ std::vector<Flag> flag_list = {
+ Flag("output_file_path", &output_file_path, "Path to output file."),
+ Flag("num_threads", &num_threads, "Number of threads."),
+ };
+ Flags::Parse(&argc, argv, flag_list);
+
+ std::unique_ptr<ImagenetModelEvaluator> evaluator;
+ CHECK(!output_file_path.empty()) << "Invalid output file path.";
+
+ CHECK(num_threads > 0) << "Invalid number of threads.";
+
+ TF_CHECK_OK(
+ ImagenetModelEvaluator::Create(argc, argv, num_threads, &evaluator));
+
+ std::ofstream output_stream(output_file_path, std::ios::out);
+ CHECK(output_stream) << "Unable to open output file path: '"
+ << output_file_path << "'";
+
+ output_stream << std::setprecision(3) << std::fixed;
+ std::vector<string> columns;
+ columns.reserve(evaluator->params().num_ranks);
+ for (int i = 0; i < evaluator->params().num_ranks; i++) {
+ string column_name = "Top ";
+ tensorflow::strings::StrAppend(&column_name, i + 1);
+ columns.push_back(column_name);
+ }
+
+ ResultsWriter results_writer(
+ absl::make_unique<CSVWriter>(columns, &output_stream));
+ ResultsLogger logger;
+ evaluator->AddObserver(&results_writer);
+ evaluator->AddObserver(&logger);
+ LOG(ERROR) << "Starting evaluation with: " << num_threads << " threads.";
+ TF_CHECK_OK(evaluator->EvaluateModel());
+ return 0;
+}
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char* argv[]) {
+ return tensorflow::metrics::Main(argc, argv);
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
new file mode 100644
index 0000000000..63616fc3b4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.cc
@@ -0,0 +1,351 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h"
+
+#include <fstream>
+#include <iomanip>
+#include <string>
+#include <vector>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline.h"
+#include "tensorflow/contrib/lite/tools/accuracy/eval_pipeline_builder.h"
+#include "tensorflow/contrib/lite/tools/accuracy/file_reader_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
+#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/lib/core/blocking_counter.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+using tensorflow::string;
+
+string StripTrailingSlashes(const string& path) {
+ int end = path.size();
+ while (end > 0 && path[end - 1] == '/') {
+ end--;
+ }
+ return path.substr(0, end);
+}
+
+tensorflow::Tensor CreateStringTensor(const string& value) {
+ tensorflow::Tensor tensor(tensorflow::DT_STRING, tensorflow::TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+template <typename T>
+std::vector<T> GetFirstN(const std::vector<T>& v, int n) {
+ if (n >= v.size()) return v;
+ std::vector<T> result(v.begin(), v.begin() + n);
+ return result;
+}
+
+template <typename T>
+std::vector<std::vector<T>> Split(const std::vector<T>& v, int n) {
+ CHECK_GT(n, 0);
+ std::vector<std::vector<T>> vecs(n);
+ int input_index = 0;
+ int vec_index = 0;
+ while (input_index < v.size()) {
+ vecs[vec_index].push_back(v[input_index]);
+ vec_index = (vec_index + 1) % n;
+ input_index++;
+ }
+ CHECK_EQ(vecs.size(), n);
+ return vecs;
+}
+
+// File pattern for imagenet files.
+const char* const kImagenetFilePattern = "*.[jJ][pP][eE][gG]";
+
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+
+class CompositeObserver : public ImagenetModelEvaluator::Observer {
+ public:
+ explicit CompositeObserver(const std::vector<Observer*>& observers)
+ : observers_(observers) {}
+
+ void OnEvaluationStart(const std::unordered_map<uint64_t, int>&
+ shard_id_image_count_map) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnEvaluationStart(shard_id_image_count_map);
+ }
+ }
+
+ void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) override {
+ mutex_lock lock(mu_);
+ for (auto observer : observers_) {
+ observer->OnSingleImageEvaluationComplete(shard_id, stats, image);
+ }
+ }
+
+ private:
+ const std::vector<ImagenetModelEvaluator::Observer*>& observers_
+ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
+/*static*/ Status ImagenetModelEvaluator::Create(
+ int argc, char* argv[], int num_threads,
+ std::unique_ptr<ImagenetModelEvaluator>* model_evaluator) {
+ Params params;
+ const std::vector<Flag> flag_list = {
+ Flag("model_output_labels", &params.model_output_labels_path,
+ "Path to labels that correspond to output of model."
+ " E.g. in case of mobilenet, this is the path to label "
+ "file where each label is in the same order as the output"
+ " of the model."),
+ Flag("ground_truth_images_path", &params.ground_truth_images_path,
+ "Path to ground truth images."),
+ Flag("ground_truth_labels", &params.ground_truth_labels_path,
+ "Path to ground truth labels."),
+ Flag("num_images", &params.number_of_images,
+ "Number of examples to evaluate, pass 0 for all "
+ "examples. Default: 100"),
+ Flag("blacklist_file_path", &params.blacklist_file_path,
+ "Path to blacklist file (optional)."
+ "Path to blacklist file where each line is a single integer that is "
+ "equal to number of blacklisted image."),
+ Flag("model_file", &params.model_file_path,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = Flags::Parse(&argc, argv, flag_list);
+ if (!parse_result)
+ return errors::InvalidArgument("Invalid command line flags");
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->IsDirectory(params.ground_truth_images_path),
+ "Invalid ground truth data path.");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.ground_truth_labels_path),
+ "Invalid ground truth labels path.");
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.model_output_labels_path),
+ "Invalid model output labels path.");
+
+ if (!params.blacklist_file_path.empty()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ Env::Default()->FileExists(params.blacklist_file_path),
+ "Invalid blacklist path.");
+ }
+
+ if (params.number_of_images < 0) {
+ return errors::InvalidArgument("Invalid: num_examples");
+ }
+
+ utils::ModelInfo model_info;
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ utils::GetTFliteModelInfo(params.model_file_path, &model_info),
+ "Invalid TFLite model.");
+
+ *model_evaluator = absl::make_unique<ImagenetModelEvaluator>(
+ model_info, params, num_threads);
+ return Status::OK();
+}
+
+struct ImageLabel {
+ string image;
+ string label;
+};
+
+Status EvaluateModelForShard(const uint64_t shard_id,
+ const std::vector<ImageLabel>& image_labels,
+ const std::vector<string>& model_labels,
+ const utils::ModelInfo& model_info,
+ const ImagenetModelEvaluator::Params& params,
+ ImagenetModelEvaluator::Observer* observer,
+ ImagenetTopKAccuracy* eval) {
+ const TensorShape& input_shape = model_info.input_shapes[0];
+ const int image_height = input_shape.dim_size(1);
+ const int image_width = input_shape.dim_size(2);
+ const bool is_quantized = (model_info.input_types[0] == DT_UINT8);
+
+ RunTFLiteModelStage::Params tfl_model_params;
+ tfl_model_params.model_file_path = params.model_file_path;
+ if (is_quantized) {
+ tfl_model_params.input_type = {DT_UINT8};
+ tfl_model_params.output_type = {DT_UINT8};
+ } else {
+ tfl_model_params.input_type = {DT_FLOAT};
+ tfl_model_params.output_type = {DT_FLOAT};
+ }
+
+ Scope root = Scope::NewRootScope();
+ FileReaderStage reader;
+ InceptionPreprocessingStage inc(image_height, image_width, is_quantized);
+ RunTFLiteModelStage tfl_model_stage(tfl_model_params);
+ EvalPipelineBuilder builder;
+
+ std::unique_ptr<EvalPipeline> eval_pipeline;
+
+ auto build_status = builder.WithInputStage(&reader)
+ .WithPreprocessingStage(&inc)
+ .WithRunModelStage(&tfl_model_stage)
+ .WithAccuracyEval(eval)
+ .WithInput("input_file", DT_STRING)
+ .Build(root, &eval_pipeline);
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(build_status,
+ "Failure while building eval pipeline.");
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+
+ TF_RETURN_IF_ERROR(eval_pipeline->AttachSession(std::move(session)));
+
+ for (const auto& image_label : image_labels) {
+ TF_CHECK_OK(eval_pipeline->Run(CreateStringTensor(image_label.image),
+ CreateStringTensor(image_label.label)));
+ observer->OnSingleImageEvaluationComplete(
+ shard_id, eval->GetTopKAccuracySoFar(), image_label.image);
+ }
+ return Status::OK();
+}
+
+Status FilterBlackListedImages(const string& blacklist_file_path,
+ std::vector<ImageLabel>* image_labels) {
+ if (!blacklist_file_path.empty()) {
+ std::vector<string> lines;
+ TF_RETURN_IF_ERROR(utils::ReadFileLines(blacklist_file_path, &lines));
+ std::vector<int> blacklist_ids;
+ blacklist_ids.reserve(lines.size());
+ // Populate blacklist_ids with indices of images.
+ std::transform(lines.begin(), lines.end(),
+ std::back_inserter(blacklist_ids),
+ [](const string& val) { return std::stoi(val) - 1; });
+
+ std::vector<ImageLabel> filtered_images;
+ std::sort(blacklist_ids.begin(), blacklist_ids.end());
+ const size_t size_post_filtering =
+ image_labels->size() - blacklist_ids.size();
+ filtered_images.reserve(size_post_filtering);
+ int blacklist_index = 0;
+ for (int image_index = 0; image_index < image_labels->size();
+ image_index++) {
+ if (blacklist_index < blacklist_ids.size() &&
+ blacklist_ids[blacklist_index] == image_index) {
+ blacklist_index++;
+ continue;
+ }
+ filtered_images.push_back((*image_labels)[image_index]);
+ }
+
+ if (filtered_images.size() != size_post_filtering) {
+ return errors::Internal("Invalid number of filtered images");
+ }
+ *image_labels = filtered_images;
+ }
+ return Status::OK();
+}
+
+Status ImagenetModelEvaluator::EvaluateModel() const {
+ if (model_info_.input_shapes.size() != 1) {
+ return errors::InvalidArgument("Invalid input shape");
+ }
+
+ const TensorShape& input_shape = model_info_.input_shapes[0];
+ // Input should be of the shape {1, height, width, 3}
+ if (input_shape.dims() != 4 || input_shape.dim_size(3) != 3) {
+ return errors::InvalidArgument("Invalid input shape for the model.");
+ }
+
+ string data_path =
+ StripTrailingSlashes(params_.ground_truth_images_path) + "/";
+
+ const string imagenet_file_pattern = data_path + kImagenetFilePattern;
+ std::vector<string> image_files;
+ TF_CHECK_OK(
+ Env::Default()->GetMatchingPaths(imagenet_file_pattern, &image_files));
+ std::vector<string> ground_truth_image_labels;
+ TF_CHECK_OK(utils::ReadFileLines(params_.ground_truth_labels_path,
+ &ground_truth_image_labels));
+ CHECK_EQ(image_files.size(), ground_truth_image_labels.size());
+
+ // Process files in filename sorted order.
+ std::sort(image_files.begin(), image_files.end());
+
+ std::vector<ImageLabel> image_labels;
+ image_labels.reserve(image_files.size());
+ for (int i = 0; i < image_files.size(); i++) {
+ image_labels.push_back({image_files[i], ground_truth_image_labels[i]});
+ }
+
+ // Filter any blacklisted images.
+ TF_CHECK_OK(
+ FilterBlackListedImages(params_.blacklist_file_path, &image_labels));
+
+ if (params_.number_of_images > 0) {
+ image_labels = GetFirstN(image_labels, params_.number_of_images);
+ }
+
+ std::vector<string> model_labels;
+ TF_RETURN_IF_ERROR(
+ utils::ReadFileLines(params_.model_output_labels_path, &model_labels));
+ if (model_labels.size() != 1001) {
+ return errors::InvalidArgument("Invalid number of labels: ",
+ model_labels.size());
+ }
+
+ ImagenetTopKAccuracy eval(model_labels, params_.num_ranks);
+
+ auto img_labels = Split(image_labels, num_threads_);
+
+ BlockingCounter counter(num_threads_);
+
+ CompositeObserver observer(observers_);
+
+ ::tensorflow::thread::ThreadPool pool(Env::Default(), "evaluation_pool",
+ num_threads_);
+ std::unordered_map<uint64_t, int> shard_id_image_count_map;
+ std::vector<std::function<void()>> thread_funcs;
+ thread_funcs.reserve(num_threads_);
+ for (int i = 0; i < num_threads_; i++) {
+ const auto& image_label = img_labels[i];
+ const uint64_t shard_id = i + 1;
+ shard_id_image_count_map[shard_id] = image_label.size();
+ auto func = [shard_id, &image_label, &model_labels, this, &observer, &eval,
+ &counter]() {
+ TF_CHECK_OK(EvaluateModelForShard(shard_id, image_label, model_labels,
+ model_info_, params_, &observer,
+ &eval));
+ counter.DecrementCount();
+ };
+ thread_funcs.push_back(func);
+ }
+
+ observer.OnEvaluationStart(shard_id_image_count_map);
+ for (const auto& func : thread_funcs) {
+ pool.Schedule(func);
+ }
+
+ counter.Wait();
+
+ return Status::OK();
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
new file mode 100644
index 0000000000..97e4232b35
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_model_evaluator.h
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_MODEL_EVALUATOR_H_
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// Evaluates models accuracy for ILSVRC dataset.
+//
+// Generates the top-1, top-k accuracy counts where k is
+// controlled by |num_ranks|.
+// Usage:
+// ModelInfo model_info = ..
+// ImagenetModelEvaluator::Params params;
+// .. set params to image, label, output label and model file path..
+// SomeObserver observer;
+// ImagenetModelEvaluator evaluator(model_info, params);
+// evaluator.AddObserver(&observer);
+// TF_CHECK_OK(evaluator.EvaluateModel());
+class ImagenetModelEvaluator {
+ public:
+ struct Params {
+ // Path to ground truth images.
+ string ground_truth_images_path;
+
+ // Path to labels file for ground truth image.
+ // This file should be generated with the scripts.
+ string ground_truth_labels_path;
+
+ // This is word labels generated by the model. The category
+ // indices of output probabilities generated by the model maybe different
+ // from the indices in the imagenet dataset.
+ string model_output_labels_path;
+
+ // Path to the model file.
+ string model_file_path;
+
+ // Path to black list file. 1762 images were blacklisted from
+ // original ILSVRC dataset. This black list file is present in
+ // ILSVRC2014 devkit. Please refer to readme.txt of the ILSVRC2014
+ // devkit for details.
+ // This file is a list of image indices in a sorted order.
+ string blacklist_file_path;
+
+ // The maximum number of images to calculate accuracy.
+ // 0 means all images, a positive number means only the specified
+ // number of images.
+ int number_of_images = 0;
+
+ // Number of ranks, top K.
+ int num_ranks = 10;
+ };
+
+ // An evaluation observer.
+ // Observers can be called from multiple threads and need to be thread safe.
+ class Observer {
+ public:
+ Observer() = default;
+ Observer(const Observer&) = delete;
+ Observer& operator=(const Observer&) = delete;
+
+ Observer(const Observer&&) = delete;
+ Observer& operator=(const Observer&&) = delete;
+
+ // Called on start of evaluation.
+ // `shard_id_image_count_map` map from shard id to image count.
+ virtual void OnEvaluationStart(
+ const std::unordered_map<uint64_t, int>& shard_id_image_count_map) = 0;
+
+ // Called when evaluation was complete for `image`.
+ virtual void OnSingleImageEvaluationComplete(
+ uint64_t shard_id, const ImagenetTopKAccuracy::AccuracyStats& stats,
+ const string& image) = 0;
+
+ virtual ~Observer() = default;
+ };
+
+ ImagenetModelEvaluator(const utils::ModelInfo& model_info,
+ const Params& params, const int num_threads)
+ : model_info_(model_info), params_(params), num_threads_(num_threads) {}
+
+ // Factory method to create the evaluator by parsing command line arguments.
+ static Status Create(int argc, char* argv[], int num_threads,
+ std::unique_ptr<ImagenetModelEvaluator>* evaluator);
+
+ // Adds an observer that can observe evaluation events..
+ void AddObserver(Observer* observer) { observers_.push_back(observer); }
+
+ const Params& params() const { return params_; }
+
+ // Evaluates the provided model over the dataset.
+ Status EvaluateModel() const;
+
+ private:
+ const utils::ModelInfo model_info_;
+ const Params params_;
+ const int num_threads_;
+ std::vector<Observer*> observers_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_MODEL_EVALUATOR_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
new file mode 100644
index 0000000000..c75baa82b1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.cc
@@ -0,0 +1,114 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+
+#include <numeric>
+
+namespace {
+constexpr int kNumCategories = 1001;
+std::vector<int> GetTopK(const std::vector<float>& values, int k) {
+ CHECK_LE(k, values.size());
+ std::vector<int> indices(values.size());
+
+ std::iota(indices.begin(), indices.end(), 0);
+ std::sort(indices.begin(), indices.end(),
+ [&values](int a, int b) { return values[a] > values[b]; });
+
+ indices.resize(k);
+ return indices;
+}
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+ImagenetTopKAccuracy::ImagenetTopKAccuracy(
+ const std::vector<string>& ground_truth_labels, int k)
+ : ground_truth_labels_(ground_truth_labels),
+ k_(k),
+ accuracy_counts_(k_, 0),
+ num_samples_(0) {
+ CHECK_EQ(kNumCategories, ground_truth_labels.size());
+}
+
+Status ImagenetTopKAccuracy::ComputeEval(
+ const std::vector<Tensor>& model_outputs, const Tensor& ground_truth) {
+ if (model_outputs.size() != 1) {
+ return errors::InvalidArgument("Invalid model output: ",
+ model_outputs.size());
+ }
+ const Tensor& output = model_outputs[0];
+ if (!output.shape().IsSameSize({1, kNumCategories})) {
+ return errors::InvalidArgument("Invalid shape of model output: ",
+ output.shape().DebugString());
+ }
+ if (ground_truth.dtype() != DT_STRING && ground_truth.dims() != 0) {
+ return errors::InvalidArgument("Invalid ground truth type: ",
+ ground_truth.DebugString());
+ }
+ string ground_truth_label = ground_truth.scalar<string>()();
+
+ std::vector<float> probabilities;
+ probabilities.reserve(kNumCategories);
+ if (output.dtype() == DT_FLOAT) {
+ auto probs = output.flat<float>();
+ for (size_t i = 0; i < probs.size(); i++) {
+ probabilities.push_back(probs(i));
+ }
+ } else {
+ auto probs = output.flat<uint8>();
+ for (size_t i = 0; i < probs.size(); i++) {
+ probabilities.push_back(probs(i));
+ }
+ }
+
+ CHECK_EQ(kNumCategories, probabilities.size());
+ std::vector<int> topK = GetTopK(probabilities, k_);
+ int ground_truth_index = GroundTruthIndex(ground_truth_label);
+ UpdateSamples(topK, ground_truth_index);
+ return Status::OK();
+}
+
+const ImagenetTopKAccuracy::AccuracyStats
+ImagenetTopKAccuracy::GetTopKAccuracySoFar() const {
+ mutex_lock lock(mu_);
+ AccuracyStats stats;
+ stats.number_of_images = num_samples_;
+ stats.topk_counts = accuracy_counts_;
+ return stats;
+}
+
+void ImagenetTopKAccuracy::UpdateSamples(const std::vector<int>& counts,
+ int ground_truth_index) {
+ mutex_lock lock(mu_);
+ for (size_t i = 0; i < counts.size(); ++i) {
+ if (ground_truth_index == counts[i]) {
+ for (size_t j = i; j < counts.size(); j++) {
+ accuracy_counts_[j] += 1;
+ }
+ break;
+ }
+ }
+ num_samples_++;
+}
+
+int ImagenetTopKAccuracy::GroundTruthIndex(const string& label) const {
+ auto index = std::find(ground_truth_labels_.cbegin(),
+ ground_truth_labels_.cend(), label);
+ CHECK(index != ground_truth_labels_.end()) << "Invalid label: " << label;
+ return std::distance(ground_truth_labels_.cbegin(), index);
+}
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
new file mode 100644
index 0000000000..cad646a30c
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_IMAGENET_TOPK_EVAL_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/tools/accuracy/accuracy_eval_stage.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+namespace metrics {
+// An |AccuracyEval| stage that calculates the top K error rate for model
+// evaluations on imagenet like datasets.
+// Inputs: A {1, 1001} shaped tensor that contains the probabilities for objects
+// predicted by the model.
+// Ground truth: A |string| label for the image.
+// From the input object probabilities, the stage computes the predicted labels
+// and finds the top K error rates by comparing the predictions with ground
+// truths.
+class ImagenetTopKAccuracy : public AccuracyEval {
+ public:
+ // Accuracy statistics.
+ struct AccuracyStats {
+ // Number of images evaluated.
+ int number_of_images;
+ // A vector of size |k| that contains the number of images
+ // that have correct labels in top K.
+ // E.g. topk_counts[0] contains number of images for which
+ // model returned the correct label as the first result.
+ // Similarly topk_counts[4] contains the number of images for which
+ // model returned the correct label in top 5 results.
+ // This can be used to compute the top K error-rate for the model.
+ std::vector<int> topk_counts;
+ };
+
+ // Creates a new instance of |ImagenetTopKAccuracy| with the given
+ // |ground_truth_labels| and |k|.
+ // Args:
+ // |ground_truth_labels| : an ordered vector of labels for images. This is
+ // used to compute the index for the predicted labels and ground_truth label.
+ ImagenetTopKAccuracy(const std::vector<string>& ground_truth_labels, int k);
+
+ // Computes accuracy for a given image. The |model_outputs| should
+ // be a vector containing exactly one Tensor of shape: {1, 1001} where each
+ // item is a probability of the predicted object representing the image as
+ // output by the model.
+ // Uses |ground_truth_labels| to compute the index of |model_outputs| and
+ // |ground_truth| and computes the top K error rate.
+ Status ComputeEval(const std::vector<Tensor>& model_outputs,
+ const Tensor& ground_truth) override;
+
+ // Gets the topK accuracy for images that have been evaluated till now.
+ const AccuracyStats GetTopKAccuracySoFar() const;
+
+ private:
+ int GroundTruthIndex(const string& label) const;
+ void UpdateSamples(const std::vector<int>& counts, int ground_truth_index);
+ const std::vector<string> ground_truth_labels_;
+ const int k_;
+ std::vector<int> accuracy_counts_ GUARDED_BY(mu_);
+ int num_samples_ GUARDED_BY(mu_);
+ mutable mutex mu_;
+};
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_ILSVRC_IMAGENET_TOPK_EVAL_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc
new file mode 100644
index 0000000000..ff332af5c5
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval_test.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/imagenet_topk_eval.h"
+#include <gtest/gtest.h>
+
+namespace tensorflow {
+namespace metrics {
+namespace {
+
+const int kNumCategories = 1001;
+
+Tensor CreateStringTensor(const string& value) {
+ Tensor tensor(DT_STRING, TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+Tensor CreateOutputTensor() {
+ Tensor tensor(DT_FLOAT, TensorShape({1, kNumCategories}));
+ for (int i = 0; i < kNumCategories; i++) {
+ tensor.flat<float>()(i) = 0;
+ }
+ return tensor;
+}
+
+std::vector<string> CreateGroundTruth() {
+ std::vector<string> ground_truth;
+ ground_truth.reserve(kNumCategories);
+ for (int i = 0; i < kNumCategories; i++) {
+ string category;
+ strings::StrAppend(&category, i);
+ ground_truth.push_back(category);
+ }
+ return ground_truth;
+}
+
+TEST(ImagenetTopKAccuracy, AllCorrect) {
+ ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
+ auto accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(0, accuracies.number_of_images);
+ EXPECT_EQ(5, accuracies.topk_counts.size());
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(0, i);
+ }
+ // First image was correctly identified as "0".
+ Tensor tensor = CreateOutputTensor();
+ tensor.flat<float>()(0) = 0.8;
+
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(1, accuracies.number_of_images);
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(1, i);
+ }
+ tensor.flat<float>()(1) = 0.9;
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(2, accuracies.number_of_images);
+
+ for (int i : accuracies.topk_counts) {
+ EXPECT_EQ(2, i);
+ }
+}
+
+TEST(ImagenetTopKAccuracy, Top5) {
+ ImagenetTopKAccuracy acc_top_5(CreateGroundTruth(), 5);
+ auto accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(0, accuracies.number_of_images);
+ EXPECT_EQ(5, accuracies.topk_counts.size());
+
+ // For first image, with ground truth "0" probabilities were
+ // 0.5 for "0",
+ // "0.6" for 1,
+ // "0.7" for 2,
+ // "0.8" for 3,
+ // "0.9" for 4.
+ // remaining all zeroes.
+
+ // First image was correctly identified as "0".
+ Tensor tensor = CreateOutputTensor();
+ tensor.flat<float>()(0) = 0.5;
+ tensor.flat<float>()(1) = 0.6;
+ tensor.flat<float>()(2) = 0.7;
+ tensor.flat<float>()(3) = 0.8;
+ tensor.flat<float>()(4) = 0.9;
+
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("0")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(1, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[4]);
+
+ for (int i = 0; i < 4; i++) {
+ EXPECT_EQ(0, accuracies.topk_counts[i]);
+ }
+
+ // Now for "1" only last two buckets are going to be affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("1")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(2, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[3]);
+ EXPECT_EQ(2, accuracies.topk_counts[4]);
+ for (int i = 0; i < 3; i++) {
+ EXPECT_EQ(0, accuracies.topk_counts[i]);
+ }
+
+ // All buckets will be affected.
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("4")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(3, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[0]);
+ EXPECT_EQ(1, accuracies.topk_counts[1]);
+ EXPECT_EQ(1, accuracies.topk_counts[2]);
+ EXPECT_EQ(2, accuracies.topk_counts[3]);
+ EXPECT_EQ(3, accuracies.topk_counts[4]);
+
+ // No buckets will be affected
+ TF_CHECK_OK(acc_top_5.ComputeEval({tensor}, CreateStringTensor("10")));
+ accuracies = acc_top_5.GetTopKAccuracySoFar();
+ EXPECT_EQ(4, accuracies.number_of_images);
+ EXPECT_EQ(1, accuracies.topk_counts[0]);
+ EXPECT_EQ(1, accuracies.topk_counts[1]);
+ EXPECT_EQ(1, accuracies.topk_counts[2]);
+ EXPECT_EQ(2, accuracies.topk_counts[3]);
+ EXPECT_EQ(3, accuracies.topk_counts[4]);
+}
+
+} // namespace
+
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
new file mode 100644
index 0000000000..7512b39c32
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
+
+#include <memory>
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+void CentralCropImage(const Scope& s, const tensorflow::Output& decoded_image,
+ double crop_fraction, tensorflow::Output* cropped_image) {
+ auto image_dims = ops::Slice(s, ops::Shape(s, decoded_image), {0}, {2});
+ auto height_width = ops::Cast(s, image_dims, DT_DOUBLE);
+ auto cropped_begin = ops::Div(
+ s, ops::Sub(s, height_width, ops::Mul(s, height_width, crop_fraction)),
+ 2.0);
+ auto bbox_begin = ops::Cast(s, cropped_begin, DT_INT32);
+ auto bbox_size = ops::Sub(s, image_dims, ops::Mul(s, bbox_begin, 2));
+ auto slice_begin = ops::Concat(s, {bbox_begin, Input({0})}, 0);
+ auto slice_size = ops::Concat(s, {bbox_size, {-1}}, 0);
+ *cropped_image = ops::Slice(s, decoded_image, slice_begin, slice_size);
+}
+
+} // namespace
+
+void InceptionPreprocessingStage::AddToGraph(const Scope& scope,
+ const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+ ops::DecodeJpeg::Attrs attrs;
+ attrs.channels_ = 3;
+ auto decoded_jpeg = ops::DecodeJpeg(s, input, attrs);
+ tensorflow::Output cropped_image;
+ CentralCropImage(s, decoded_jpeg, params_.cropping_fraction, &cropped_image);
+ auto dims_expander = ops::ExpandDims(s, cropped_image, 0);
+ auto resized_image = ops::ResizeBilinear(
+ s, dims_expander,
+ ops::Const(s.WithOpName("size"), {image_height_, image_width_}));
+ if (is_quantized_) {
+ this->stage_output_ =
+ ops::Cast(s.WithOpName(output_name()), resized_image, DT_UINT8);
+ } else {
+ auto squeezed_image = ops::Squeeze(s, resized_image);
+ auto normalized_image =
+ ops::Div(s,
+ ops::Sub(s, squeezed_image,
+ {params_.input_means[0], params_.input_means[1],
+ params_.input_means[2]}),
+ {params_.scale});
+ this->stage_output_ =
+ ops::ExpandDims(s.WithOpName(output_name()), normalized_image, {0});
+ }
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h
new file mode 100644
index 0000000000..15df719817
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
+
+#include <utility>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A stage that does inception preprocessing.
+// Inputs: A tensor containing bytes of a JPEG image.
+// Outputs: A tensor containing rescaled and preprocessed image that has
+// shape {1, image_height, image_width, 3}, where 3 is the number of channels.
+class InceptionPreprocessingStage : public Stage {
+ public:
+ struct Params {
+ std::vector<float> input_means;
+ float scale;
+ double cropping_fraction;
+ };
+
+ static Params DefaultParams() {
+ return {.input_means = {127.5, 127.5, 127.5},
+ .scale = 127.5,
+ .cropping_fraction = 0.875};
+ }
+
+ // Creates a new preprocessing stage object with provided |image_width|
+ // |image_height| as the size of output image.
+ // If |is_quantized| is set to true then |params| is ignored since quantized
+ // images don't go through any preprocessing.
+ InceptionPreprocessingStage(int image_width, int image_height,
+ bool is_quantized,
+ Params params = DefaultParams())
+ : image_width_(image_width),
+ image_height_(image_height),
+ is_quantized_(is_quantized),
+ params_(std::move(params)) {}
+
+ string name() const override { return "stage_inception_preprocess"; }
+ string output_name() const override {
+ return "stage_inception_preprocess_output";
+ }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+
+ private:
+ int image_width_;
+ int image_height_;
+ bool is_quantized_;
+ Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_INCEPTION_PREPROCESSING_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc
new file mode 100644
index 0000000000..3587878ba3
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing_test.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <fstream>
+#include <string>
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/ilsvrc/inception_preprocessing.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_image_file = nullptr;
+} // namespace
+
+namespace tensorflow {
+namespace metrics {
+
+namespace {
+
+using tensorflow::Status;
+using tensorflow::Tensor;
+
+Status GetContents(const string& filename, string* output) {
+ std::ifstream input(filename, std::ios::binary);
+ const int kBufferSize = 2048;
+ char buffer[kBufferSize];
+ while (true) {
+ input.read(buffer, kBufferSize);
+ output->append(buffer, input.gcount());
+ if (!input.good()) {
+ if (input.eof()) return Status::OK();
+ return Status(tensorflow::error::ABORTED, "Failed to read file.");
+ }
+ }
+}
+
+TEST(InceptionPreprocessingTest, TestImagePreprocessQuantized) {
+ ASSERT_TRUE(g_test_image_file != nullptr);
+ string image_contents;
+ string image_path = *g_test_image_file;
+ auto status = GetContents(image_path, &image_contents);
+ ASSERT_TRUE(status.ok()) << status.error_message();
+ const int width = 224;
+ const int height = 224;
+ const bool is_quantized = true;
+ InceptionPreprocessingStage preprocess_stage(width, height, is_quantized);
+ Scope scope = Scope::NewRootScope();
+ preprocess_stage.AddToGraph(scope, image_contents);
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {preprocess_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ EXPECT_EQ(DT_UINT8, outputs[0].dtype());
+ EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3}));
+}
+
+TEST(InceptionPreprocessingTest, TestImagePreprocessFloat) {
+ ASSERT_TRUE(g_test_image_file != nullptr);
+ string image_contents;
+ string image_path = *g_test_image_file;
+ auto status = GetContents(image_path, &image_contents);
+ ASSERT_TRUE(status.ok()) << status.error_message();
+ const int width = 224;
+ const int height = 224;
+ const bool is_quantized = false;
+ InceptionPreprocessingStage preprocess_stage(width, height, is_quantized);
+ Scope scope = Scope::NewRootScope();
+ preprocess_stage.AddToGraph(scope, image_contents);
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+ std::vector<Tensor> outputs;
+ auto run_status =
+ session->Run({}, /*inputs*/
+ {preprocess_stage.output_name()}, {}, /*target node names */
+ &outputs);
+ TF_CHECK_OK(run_status);
+ EXPECT_EQ(1, outputs.size());
+ EXPECT_EQ(DT_FLOAT, outputs[0].dtype());
+ EXPECT_TRUE(outputs[0].shape().IsSameSize({1, 224, 224, 3}));
+}
+
+} // namespace
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_image_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_image", g_test_image_file,
+ "Path to image file for test."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ CHECK(parse_result) << "Required test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg
new file mode 100644
index 0000000000..d2a427810f
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/ilsvrc/testdata/grace_hopper.jpg
Binary files differ
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc
new file mode 100644
index 0000000000..da4258f1c1
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+namespace {
+Status ValidateInputsMatch(const OpInputList& input_tensors,
+ const tflite::Interpreter& interpreter) {
+ std::vector<int> tflite_tensor_indices = interpreter.inputs();
+ if (tflite_tensor_indices.size() != input_tensors.size()) {
+ return errors::InvalidArgument(
+ "size mismatch, interpreter size: ", tflite_tensor_indices.size(),
+ " actual: ", input_tensors.size());
+ }
+
+ for (int i = 0; i < input_tensors.size(); i++) {
+ const TfLiteTensor* tflite_tensor =
+ interpreter.tensor(tflite_tensor_indices[i]);
+ if (tflite_tensor == nullptr) {
+ return errors::InvalidArgument("Tensor is null at index: ", i);
+ }
+
+ const Tensor& tensor = input_tensors[i];
+ auto i_type = metrics::utils::GetTFDataType(tflite_tensor->type);
+ auto i_shape = metrics::utils::GetTFLiteTensorShape(*tflite_tensor);
+ if (i_type != tensor.dtype()) {
+ return errors::InvalidArgument("Data types mismatch for tensors: ", i,
+ " expected: ", i_type,
+ " got: ", tensor.dtype());
+ }
+
+ if (i_shape != tensor.shape()) {
+ return errors::InvalidArgument("Data shapes mismatch for tensors: ", i,
+ " expected: ", i_shape,
+ " got: ", tensor.shape());
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+class RunTFLiteModelOp : public OpKernel {
+ public:
+ explicit RunTFLiteModelOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string model_file_path;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("model_file_path", &model_file_path));
+ model_ = tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
+ OP_REQUIRES(ctx, model_,
+ errors::InvalidArgument(
+ "Model loading failed. Invalid model file path: ",
+ model_file_path));
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+
+ tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+ OP_REQUIRES(ctx, interpreter_,
+ errors::Internal("Interpreter creation failed."));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ OpInputList input_tensors;
+ OP_REQUIRES_OK(context, context->input_list("model_input", &input_tensors));
+
+ OP_REQUIRES_OK(context, ValidateInputsMatch(input_tensors, *interpreter_));
+ OpOutputList output_tensors;
+ OP_REQUIRES_OK(context,
+ context->output_list("model_output", &output_tensors));
+ auto tfl_outputs = interpreter_->outputs();
+ OP_REQUIRES(context, output_tensors.size() == tfl_outputs.size(),
+ errors::InvalidArgument(
+ "Invalid output size, expected: ", tfl_outputs.size(),
+ " got: ", output_tensors.size()));
+ for (int i = 0; i < output_tensors.size(); i++) {
+ DataType tfl_type = metrics::utils::GetTFDataType(
+ interpreter_->tensor(tfl_outputs[i])->type);
+ DataType otype = output_tensors.expected_output_dtype(i);
+ OP_REQUIRES(
+ context, tfl_type == otype,
+ errors::InvalidArgument("Invalid data type for output at index: ", i,
+ " expected: ", tfl_type, " got: ", otype));
+ }
+
+ auto allocation_status = interpreter_->AllocateTensors();
+ OP_REQUIRES(context, allocation_status == kTfLiteOk,
+ errors::Internal("Unable to allocate tensors."));
+ for (int i = 0; i < input_tensors.size(); i++) {
+ const int tfl_index = interpreter_->inputs()[i];
+ TfLiteTensor* tflite_tensor = interpreter_->tensor(tfl_index);
+ auto tensor_bytes = input_tensors[i].tensor_data();
+ OP_REQUIRES(context, tflite_tensor->bytes == tensor_bytes.size(),
+ errors::InvalidArgument(
+ "Size mismatch, expected: ", tflite_tensor->bytes,
+ " got: ", tensor_bytes.size()));
+ std::memcpy(tflite_tensor->data.raw, tensor_bytes.data(),
+ tensor_bytes.size());
+ }
+ auto invocation_status = interpreter_->Invoke();
+ OP_REQUIRES(context, invocation_status == kTfLiteOk,
+ errors::Internal("Interpreter invocation failed."));
+ for (int i = 0; i < output_tensors.size(); i++) {
+ auto tfl_tensor = interpreter_->tensor(tfl_outputs[i]);
+ TensorShape shape = metrics::utils::GetTFLiteTensorShape(*tfl_tensor);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, output_tensors.allocate(i, shape, &output));
+ auto tensor_bytes = output->tensor_data();
+ OP_REQUIRES(context, tensor_bytes.size() == tfl_tensor->bytes,
+ errors::Internal("Invalid size"));
+ std::memcpy(const_cast<char*>(tensor_bytes.data()), tfl_tensor->data.raw,
+ tfl_tensor->bytes);
+ }
+ }
+
+ private:
+ std::unique_ptr<tflite::FlatBufferModel> model_;
+ std::unique_ptr<tflite::Interpreter> interpreter_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("RunTFLiteModel").Device(DEVICE_CPU),
+ RunTFLiteModelOp);
+
+REGISTER_OP("RunTFLiteModel")
+ .Input("model_input: input_type")
+ .Output("model_output: output_type")
+ .Attr("model_file_path: string")
+ .Attr("input_type : list(type)")
+ .Attr("output_type: list(type)")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ // TODO(shashishekhar): Infer the correct shape based on output_type and
+ // maybe another attribute.
+ return shape_inference::UnknownShape(c);
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc
new file mode 100644
index 0000000000..88175984a0
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_op_test.cc
@@ -0,0 +1,200 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_model_file = nullptr;
+}
+
+namespace tensorflow {
+namespace {
+
+TEST(RunTfliteModelOpTest, ModelIsRun) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 1.0f, {1, 8, 8, 3}), // a
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_FLOAT};
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(
+ session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_EQ(2, outputs.size());
+
+ for (const auto& tensor : outputs) {
+ EXPECT_TRUE(tensor.shape().IsSameSize({1, 8, 8, 3}));
+ }
+ auto output_x = outputs[0].flat<float>();
+ auto output_y = outputs[1].flat<float>();
+ EXPECT_EQ(1 * 8 * 8 * 3, output_x.size());
+ EXPECT_EQ(1 * 8 * 8 * 3, output_y.size());
+ for (int i = 0; i < output_x.size(); i++) {
+ EXPECT_NEAR(6.3f, output_x(i), 1e-6f); // a+b+c
+ EXPECT_NEAR(9.6f, output_y(i), 1e-6f); // b+c+d
+ }
+}
+
+TEST(RunTfliteModelOpTest, NumInputsMismatch) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+ // Remove a from input.
+
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT};
+
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ auto status =
+ (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(RunTfliteModelOpTest, InputSizesMismatch) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+
+ Scope scope = Scope::NewRootScope();
+ TF_CHECK_OK(scope.status());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+ // Set a to be invalid size.
+ std::vector<Input> graph_inputs = {
+ ops::Const(scope, 1.0f, {1, 8, 8, 4}), // a invalid size,
+ ops::Const(scope, 2.1f, {1, 8, 8, 3}), // b
+ ops::Const(scope, 3.2f, {1, 8, 8, 3}), // c
+ ops::Const(scope, 4.3f, {1, 8, 8, 3}), // d
+ };
+
+ std::vector<NodeBuilder::NodeOut> input_data;
+ std::transform(graph_inputs.begin(), graph_inputs.end(),
+ std::back_inserter(input_data), [&scope](Input model_input) {
+ return ops::AsNodeOut(scope, model_input);
+ });
+
+ std::vector<DataType> model_input_type = {DT_FLOAT, DT_FLOAT, DT_FLOAT,
+ DT_FLOAT};
+ ::tensorflow::Node* ret;
+ auto builder = ::tensorflow::NodeBuilder("run_model_op", "RunTFLiteModel")
+ .Input(input_data)
+ .Attr("model_file_path", test_model_file)
+ .Attr("input_type", model_input_type)
+ .Attr("output_type", {DT_FLOAT, DT_FLOAT});
+
+ scope.UpdateBuilder(&builder);
+ scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
+ TF_CHECK_OK(scope.status());
+
+ GraphDef graph_def;
+ TF_CHECK_OK(scope.ToGraphDef(&graph_def));
+ std::unique_ptr<Session> session(NewSession(SessionOptions()));
+ TF_CHECK_OK(session->Create(graph_def));
+
+ std::vector<Tensor> outputs;
+ auto status =
+ (session->Run({}, {"run_model_op:0", "run_model_op:1"}, {}, &outputs));
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_model_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_model_file", g_test_model_file,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ CHECK(parse_result) << "Required test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc
new file mode 100644
index 0000000000..c96795d499
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h"
+
+#include <vector>
+
+#include "tensorflow/cc/framework/scope.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+
+namespace tensorflow {
+namespace metrics {
+void RunTFLiteModelStage::AddToGraph(const Scope& scope, const Input& input) {
+ if (!scope.ok()) return;
+ Scope s = scope.WithOpName(name());
+
+ std::vector<NodeBuilder::NodeOut> _data = {ops::AsNodeOut(s, input)};
+ ::tensorflow::Node* ret;
+ auto builder = NodeBuilder(output_name(), "RunTFLiteModel")
+ .Input(_data)
+ .Attr("model_file_path", params_.model_file_path)
+ .Attr("input_type", params_.input_type)
+ .Attr("output_type", params_.output_type);
+
+ s.UpdateBuilder(&builder);
+ s.UpdateStatus(builder.Finalize(s.graph(), &ret));
+ if (!s.ok()) return;
+ s.UpdateStatus(s.DoShapeInference(ret));
+ this->stage_output_ = ::tensorflow::Output(ret, 0);
+}
+
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h
new file mode 100644
index 0000000000..90d12d6f42
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/run_tflite_model_stage.h
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
+
+#include <string>
+
+#include "tensorflow/contrib/lite/tools/accuracy/stage.h"
+
+namespace tensorflow {
+namespace metrics {
+// Stage that loads and runs a TFLite model.
+// Inputs: The input to TFLite model.
+// Outputs: The output of running the TFLite model.
+class RunTFLiteModelStage : public Stage {
+ public:
+ // The parameters for the stage.
+ struct Params {
+ string model_file_path;
+ std::vector<TensorShape> output_shape;
+ std::vector<DataType> input_type;
+ std::vector<DataType> output_type;
+ };
+
+ explicit RunTFLiteModelStage(const Params& params) : params_(params) {}
+
+ string name() const override { return "stage_run_tfl_model"; }
+ // TODO(shashishekhar): This stage can have multiple inputs and
+ // outputs, perhaps change the definition of stage.
+ string output_name() const override { return "stage_run_tfl_model_output"; }
+
+ void AddToGraph(const Scope& scope, const Input& input) override;
+
+ private:
+ Params params_;
+};
+
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_RUN_TFLITE_MODEL_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/stage.h b/tensorflow/contrib/lite/tools/accuracy/stage.h
new file mode 100644
index 0000000000..8292ea2ec7
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/stage.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
+
+#include "tensorflow/cc/framework/scope.h"
+
+namespace tensorflow {
+namespace metrics {
+
+// A stage in an evaluation pipeline.
+// Each stage adds a subgraph to the pipeline. Stages can be chained
+// together.
+class Stage {
+ public:
+ Stage() = default;
+ Stage(const Stage&) = delete;
+ Stage& operator=(const Stage&) = delete;
+
+ Stage(const Stage&&) = delete;
+ Stage& operator=(const Stage&&) = delete;
+
+ // Adds a subgraph to given scope that takes in `input` as a parameter.
+ virtual void AddToGraph(const Scope& scope, const Input& input) = 0;
+ virtual ~Stage() {}
+
+ // The name of the stage.
+ // Can be used by derived classes for naming the subscope for the stage
+ // graph.
+ virtual string name() const = 0;
+
+ // The name of the output for the stage.
+ virtual string output_name() const = 0;
+
+ const ::tensorflow::Output& Output() const { return stage_output_; }
+
+ protected:
+ ::tensorflow::Output stage_output_;
+};
+} // namespace metrics
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_STAGE_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.cc b/tensorflow/contrib/lite/tools/accuracy/utils.cc
new file mode 100644
index 0000000000..f5493301fc
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils.cc
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+
+#include <sys/stat.h>
+
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <string>
+
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/op_resolver.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace utils {
+
+DataType GetTFDataType(TfLiteType tflite_type) {
+ switch (tflite_type) {
+ case kTfLiteFloat32:
+ return DT_FLOAT;
+ case kTfLiteUInt8:
+ return DT_UINT8;
+ default:
+ return DT_INVALID;
+ }
+}
+
+TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor) {
+ TensorShape shape;
+ for (int i = 0; i < tflite_tensor.dims->size; i++) {
+ shape.AddDim(tflite_tensor.dims->data[i]);
+ }
+ return shape;
+}
+
+Status ReadFileLines(const string& file_path,
+ std::vector<string>* lines_output) {
+ if (!lines_output) {
+ return errors::InvalidArgument("Invalid output");
+ }
+ std::vector<string> lines;
+ std::ifstream stream(file_path, std::ios_base::in);
+ if (!stream) {
+ return errors::InvalidArgument("Unable to open file: ", file_path);
+ }
+ std::string line;
+ while (std::getline(stream, line)) {
+ lines_output->push_back(line);
+ }
+ return Status::OK();
+}
+
+Status GetTFliteModelInfo(const string& model_file_path,
+ ModelInfo* model_info) {
+ if (model_file_path.empty()) {
+ return errors::InvalidArgument("Invalid model file.");
+ }
+ struct stat stat_buf;
+ if (stat(model_file_path.c_str(), &stat_buf) != 0) {
+ int error_num = errno;
+ return errors::InvalidArgument("Invalid model file: ", model_file_path,
+ std::strerror(error_num));
+ }
+
+ std::unique_ptr<tflite::FlatBufferModel> model;
+ std::unique_ptr<tflite::Interpreter> interpreter;
+ model = tflite::FlatBufferModel::BuildFromFile(model_file_path.data());
+ tflite::ops::builtin::BuiltinOpResolver resolver;
+
+ tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+ if (!interpreter) {
+ return errors::InvalidArgument("Invalid model", model_file_path);
+ }
+ for (int i : interpreter->inputs()) {
+ TfLiteTensor* tensor = interpreter->tensor(i);
+ model_info->input_shapes.push_back(utils::GetTFLiteTensorShape(*tensor));
+ model_info->input_types.push_back(utils::GetTFDataType(tensor->type));
+ }
+ return Status::OK();
+}
+
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils.h b/tensorflow/contrib/lite/tools/accuracy/utils.h
new file mode 100644
index 0000000000..37cbad4d51
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace tensorflow {
+namespace metrics {
+
+namespace utils {
+
+struct ModelInfo {
+ std::vector<TensorShape> input_shapes;
+ std::vector<DataType> input_types;
+};
+
+Status GetTFliteModelInfo(const string& model_file_path, ModelInfo* model_info);
+
+DataType GetTFDataType(TfLiteType tflite_type);
+
+TensorShape GetTFLiteTensorShape(const TfLiteTensor& tflite_tensor);
+
+Status ReadFileLines(const string& file_path,
+ std::vector<string>* lines_output);
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_ACCURACY_UTILS_H_
diff --git a/tensorflow/contrib/lite/tools/accuracy/utils_test.cc b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc
new file mode 100644
index 0000000000..727eba21b6
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/accuracy/utils_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/tools/accuracy/utils.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace {
+tensorflow::string* g_test_model_file = nullptr;
+}
+
+namespace tensorflow {
+namespace metrics {
+namespace utils {
+namespace {
+
+TEST(UtilsTest, GetTFLiteModelInfoReturnsCorrectly) {
+ ASSERT_TRUE(g_test_model_file != nullptr);
+ string test_model_file = *g_test_model_file;
+ ASSERT_FALSE(test_model_file.empty());
+ // Passed graph has 4 inputs : a,b,c,d and 2 outputs x,y
+ // x = a+b+c, y=b+c+d
+ // Input and outputs have shape : {1,8,8,3}
+ ModelInfo model_info;
+ auto status = GetTFliteModelInfo(test_model_file, &model_info);
+ TF_CHECK_OK(status);
+ ASSERT_EQ(4, model_info.input_shapes.size());
+ ASSERT_EQ(4, model_info.input_types.size());
+
+ for (int i = 0; i < 4; i++) {
+ const TensorShape& shape = model_info.input_shapes[i];
+ DataType dataType = model_info.input_types[i];
+ EXPECT_TRUE(shape.IsSameSize({1, 8, 8, 3}));
+ EXPECT_EQ(DT_FLOAT, dataType);
+ }
+}
+
+TEST(UtilsTest, GetTFliteModelInfoIncorrectFile) {
+ ModelInfo model_info;
+ auto status = GetTFliteModelInfo("non_existent_file", &model_info);
+ EXPECT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace utils
+} // namespace metrics
+} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ g_test_model_file = new tensorflow::string();
+ const std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag("test_model_file", g_test_model_file,
+ "Path to test tflite model file."),
+ };
+ const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ CHECK(parse_result) << "Required test_model_file";
+ ::tensorflow::port::InitMain(argv[0], &argc, &argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tools/benchmark/README.md b/tensorflow/contrib/lite/tools/benchmark/README.md
index f1e257ad10..8d997639fb 100644
--- a/tensorflow/contrib/lite/tools/benchmark/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/README.md
@@ -9,7 +9,7 @@ of runs. Aggregrate latency statistics are reported after running the benchmark.
The instructions below are for running the binary on Desktop and Android,
for iOS please use the
-[iOS benchmark app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
+[iOS benchmark app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark/ios).
## Parameters
@@ -17,11 +17,6 @@ The binary takes the following required parameters:
* `graph`: `string` \
The path to the TFLite model file.
-* `input_layer`: `string` \
- The name of the input layer, this is typically the first layer of the model.
-* `input_layer_shape`: `string` \
- The shape of the input layer. This is a comma separated string of the shape
- of tensor of input layer.
and the following optional parameters:
@@ -29,11 +24,13 @@ and the following optional parameters:
The number of threads to use for running TFLite interpreter.
* `warmup_runs`: `int` (default=1) \
The number of warmup runs to do before starting the benchmark.
+* `num_runs`: `int` (default=50) \
+ The number of runs. Increase this to reduce variance.
* `run_delay`: `float` (default=-1.0) \
The delay in seconds between subsequent benchmark runs. Non-positive values
mean use no delay.
* `use_nnapi`: `bool` (default=false) \
- Whether to use [Android NNAPI] (https://developer.android.com/ndk/guides/neuralnetworks/).
+ Whether to use [Android NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/).
This API is available on recent Android devices.
## To build/install/run
@@ -75,8 +72,6 @@ adb push mobilenet_quant_v1_224.tflite /data/local/tmp
```
adb shell /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
@@ -93,13 +88,10 @@ For example:
```
bazel-bin/tensorflow/contrib/lite/tools/benchmark/benchmark_model \
--graph=mobilenet_quant_v1_224.tflite \
- --input_layer="Placeholder" \
- --input_layer_shape="1,224,224,3" \
--num_threads=4
```
-The MobileNet graph used as an example here may be downloaded from
-https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
+The MobileNet graph used as an example here may be downloaded from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip).
## Reducing variance between runs on Android.
@@ -117,8 +109,6 @@ can use the following command:
```
adb shell taskset f0 /data/local/tmp/benchmark_model \
--graph=/data/local/tmp/mobilenet_quant_v1_224.tflite \
- --input_layer="input" \
- --input_layer_shape="1,224,224,3" \
--num_threads=1
```
@@ -205,5 +195,3 @@ Memory (bytes): count=0
Average inference timings in us: Warmup: 83235, Init: 38467, no stats: 79760.9
```
-
-
diff --git a/tensorflow/contrib/lite/tools/benchmark/ios/README.md b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
index c8d3307e29..46144f7bf8 100644
--- a/tensorflow/contrib/lite/tools/benchmark/ios/README.md
+++ b/tensorflow/contrib/lite/tools/benchmark/ios/README.md
@@ -17,8 +17,8 @@ Mobilenet_1.0_224 model
## To build/install/run
-- Follow instructions at [iOS build for TFLite]
-(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
+- Follow instructions at
+[iOS build for TFLite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/ios.md)
to build TFLite.
Running
diff --git a/tensorflow/contrib/lite/tools/make/Makefile b/tensorflow/contrib/lite/tools/make/Makefile
index e30cc1d70e..59bdb10811 100644
--- a/tensorflow/contrib/lite/tools/make/Makefile
+++ b/tensorflow/contrib/lite/tools/make/Makefile
@@ -24,6 +24,21 @@ HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32
TARGET := $(HOST_OS)
TARGET_ARCH := $(HOST_ARCH)
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/eigen \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
+-I$(MAKEFILE_DIR)/downloads/farmhash/src \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
# These are the default libraries needed, but they can be added to or
# overridden by the platform-specific settings in target makefiles.
LIBS := \
@@ -44,55 +59,17 @@ ARFLAGS := -r
TARGET_TOOLCHAIN_PREFIX :=
CC_PREFIX :=
-# These target-specific makefiles should modify or replace options like
-# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
-# based on platforms or architectures should happen within these files, to
-# keep this main makefile focused on the sources and dependencies.
-include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
-
-# Where compiled objects are stored.
-GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
-OBJDIR := $(GENDIR)obj/
-BINDIR := $(GENDIR)bin/
-LIBDIR := $(GENDIR)lib/
-
-INCLUDES := \
--I. \
--I$(MAKEFILE_DIR)/../../../../../ \
--I$(MAKEFILE_DIR)/../../../../../../ \
--I$(MAKEFILE_DIR)/downloads/ \
--I$(MAKEFILE_DIR)/downloads/eigen \
--I$(MAKEFILE_DIR)/downloads/gemmlowp \
--I$(MAKEFILE_DIR)/downloads/neon_2_sse \
--I$(MAKEFILE_DIR)/downloads/farmhash/src \
--I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
--I$(OBJDIR)
-# This is at the end so any globally-installed frameworks like protobuf don't
-# override local versions in the source tree.
-INCLUDES += -I/usr/local/include
-
-CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
-CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
-AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
-
# This library is the main target for this makefile. It will contain a minimal
# runtime that can be linked in to other programs.
LIB_NAME := libtensorflow-lite.a
-LIB_PATH := $(LIBDIR)$(LIB_NAME)
-
-# A small example program that shows how to link against the library.
-MINIMAL_PATH := $(BINDIR)minimal
# Benchmark static library and binary
BENCHMARK_LIB_NAME := benchmark-lib.a
BENCHMARK_BINARY_NAME := benchmark_model
-BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
-BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+# A small example program that shows how to link against the library.
MINIMAL_SRCS := \
tensorflow/contrib/lite/examples/minimal/minimal.cc
-MINIMAL_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
@@ -105,7 +82,9 @@ PROFILE_SUMMARIZER_SRCS := \
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
-$(wildcard tensorflow/contrib/lite/*.c)
+$(wildcard tensorflow/contrib/lite/*.c) \
+$(wildcard tensorflow/contrib/lite/c/*.c) \
+$(wildcard tensorflow/contrib/lite/core/api/*.cc)
ifneq ($(BUILD_TYPE),micro)
CORE_CC_ALL_SRCS += \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
@@ -136,10 +115,6 @@ tensorflow/contrib/lite/nnapi_delegate.cc
endif
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
-# File names of the intermediate files target compilation generates.
-TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
-$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
-LIB_OBJS := $(TF_LITE_CC_OBJS)
# Benchmark sources
BENCHMARK_SRCS_DIR := tensorflow/contrib/lite/tools/benchmark
@@ -151,6 +126,40 @@ BENCHMARK_SRCS := $(filter-out \
$(wildcard $(BENCHMARK_SRCS_DIR)/*_test.cc), \
$(BENCHMARK_ALL_SRCS))
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+ $(MINIMAL_SRCS) \
+ $(PROFILER_SRCS) \
+ $(PROFILER_SUMMARY_SRCS) \
+ $(TF_LITE_CC_SRCS) \
+ $(BENCHMARK_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+LIB_PATH := $(LIBDIR)$(LIB_NAME)
+BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME)
+BENCHMARK_BINARY := $(BINDIR)$(BENCHMARK_BINARY_NAME)
+MINIMAL_BINARY := $(BINDIR)minimal
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MINIMAL_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MINIMAL_SRCS))))
+
+LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
+
BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
@@ -164,7 +173,7 @@ $(OBJDIR)%.o: %.c
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there's no command-line arguments.
-all: $(LIB_PATH) $(MINIMAL_PATH) $(BENCHMARK_BINARY)
+all: $(LIB_PATH) $(MINIMAL_BINARY) $(BENCHMARK_BINARY)
# The target that's compiled for micro-controllers
micro: $(LIB_PATH)
@@ -178,19 +187,18 @@ $(LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
-$(MINIMAL_PATH): $(MINIMAL_OBJS) $(LIB_PATH)
+$(MINIMAL_BINARY): $(MINIMAL_OBJS) $(LIB_PATH)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
- -o $(MINIMAL_PATH) $(MINIMAL_OBJS) \
+ -o $(MINIMAL_BINARY) $(MINIMAL_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
-
$(BENCHMARK_LIB) : $(LIB_PATH) $(BENCHMARK_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(BENCHMARK_LIB) $(LIB_OBJS) $(BENCHMARK_OBJS)
benchmark_lib: $(BENCHMARK_LIB)
-$(info $(BENCHMARK_BINARY))
+
$(BENCHMARK_BINARY) : $(BENCHMARK_LIB)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
@@ -213,4 +221,4 @@ cleantarget:
$(DEPDIR)/%.d: ;
.PRECIOUS: $(DEPDIR)/%.d
--include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_CC_SRCS)))
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/tools/optimize/BUILD b/tensorflow/contrib/lite/tools/optimize/BUILD
new file mode 100644
index 0000000000..51ccaedc23
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/BUILD
@@ -0,0 +1,25 @@
+# TODO(suharshs): Write quantize_weights tests that use small exportable files.
+# Then we can remove this file.
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+
+cc_library(
+ name = "quantize_weights",
+ srcs = ["quantize_weights.cc"],
+ hdrs = ["quantize_weights.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/core:tflite_portable_logging",
+ "@com_google_absl//absl/memory",
+ "@flatbuffers",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md b/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md
new file mode 100644
index 0000000000..93fe576583
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/g3doc/quantize_weights.md
@@ -0,0 +1,70 @@
+# TFLite Quantize Weights Tool
+
+## Recommended usage
+
+The Quantize Weights transformation is integrated with
+[tflite_convert](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/toco/g3doc/cmdline_reference.md#transformation-flags).
+
+The recommended way of invoking this tool is by simply adding the
+`--post_training_quantize` flag to your original tflite_convert invocation. For
+example,
+
+```
+tflite_convert \
+ --output_file=/tmp/foo.tflite \
+ --saved_model_dir=/tmp/saved_model \
+ --post_training_quantize
+```
+
+## Overview
+
+The Quantize Weights tool provides a simple way to quantize the weights for a
+float TFLite model.
+
+TODO(raghuramank): Add link to weight quantization tutorial.
+
+### Size reduction
+
+float32 weights will be converted to 8 bit integers. This results in a model
+that is around 1/4th the size of the original model.
+
+### Latency reduction
+
+TFLite also has "hybrid" kernels implemented for many operations. These "hybrid"
+kernels take 8 bit integer weights and float inputs, dynamically quantize the
+inputs tensor (based on the input tensor's min and max elements), and does
+computations using the 8 bit integer values. This results in a 2-4x reduction in
+latency for "hybrid" kernels. In this mode the inference type is still FLOAT
+since the inputs and output to each operation is still float.
+
+For operations that do not yet have "hybrid" kernels implemented, we introduce a
+Dequantize operation after 8 bit integer weights. These convert weights back to
+float32 during inference to allow original float32 kernels to run. Since we
+cache dequantized results, the result of each of this dequantized path will be
+on-par with the original float model.
+
+TODO(yunluli): Fill in latency results from latency experiments.
+
+### Accuracy
+
+Since this technique quantizes weights after the model has already been trained,
+there can be accuracy drops depending on the model. For common CNN networks, the
+observed accuracy drops are small and can be seen below.
+
+TODO(yunluli): Fill in accuracy results from accuracy experiments.
+
+## Direct usage
+
+One can also invoke the Quantize Weights directly via C++ if they have a float
+`::tflite::Model` that they want to convert. They must provide a
+`flatbuffers::FlatBufferBuilder` which owns the underlying buffer of the created
+model. Here is an example invocation:
+
+```
+::tflite::Model* input_model = ...;
+flatbuffers::FlatBufferBuilder builder;
+TfLiteStatus status = ::tflite::optimize::QuantizeWeights(&builder, input_model);
+CHECK(status, kTfLiteStatusOk);
+const uint8_t* buffer = builder->GetBufferPointer();
+tflite::Model* output_model = ::tflite::GetModel(buffer);
+```
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
new file mode 100644
index 0000000000..b863108aa4
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc
@@ -0,0 +1,432 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "flatbuffers/flexbuffers.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tflite {
+namespace optimize {
+
+namespace {
+
+typedef struct {
+ TensorT* tensor;
+ // The index of the tensor to quantize in subgraph->tensors.
+ int32_t tensor_idx;
+ // The index of the tensor of the weight tensor to be quantize in op->inputs.
+ int32_t op_input_idx;
+ // True if the tensor supports hybrid evaluation.
+ bool eval_hybrid;
+} TensorInfo;
+
+// The default minimum number of elements a weights array must have to be
+// quantized by this transformation.
+const int kWeightsMinNumElementsDefault = 1024;
+
+// Nudge min and max so that floating point 0 falls exactly on a quantized
+// value, returning the nudges scale and zero_point.
+//
+// Although this code originates from FakeQuantization in quantized training,
+// we may deviate from that implementation as we please since we do not fine
+// tune the weights with quantized training.
+void GetAsymmetricQuantizationParams(
+ const float min, const float max, const int quant_min, const int quant_max,
+ QuantizationParametersT* quantization_params) {
+ // Adjust the boundaries to guarantee 0 is included.
+ const float quant_min_float = std::min(static_cast<float>(quant_min), 0.0f);
+ const float quant_max_float = std::max(static_cast<float>(quant_max), 0.0f);
+ const float scale = (max - min) / (quant_max_float - quant_min_float);
+ const float zero_point_from_min = quant_min_float - min / scale;
+ int64_t zero_point;
+ if (zero_point_from_min < quant_min_float) {
+ zero_point = static_cast<int64_t>(quant_min);
+ } else if (zero_point_from_min > quant_max_float) {
+ zero_point = static_cast<int64_t>(quant_max);
+ } else {
+ zero_point = static_cast<int64_t>(std::round(zero_point_from_min));
+ }
+ quantization_params->scale = std::vector<float>(1, scale);
+ quantization_params->zero_point = std::vector<int64_t>(1, zero_point);
+}
+
+// Returns the number of elements in tensor.
+uint64_t NumElements(const TensorT* tensor) {
+ if (tensor->shape.empty()) {
+ LOG(FATAL) << "Tensor has no shape information.";
+ }
+ uint64_t num_elements = 1;
+ for (const uint64_t dim : tensor->shape) {
+ num_elements *= dim;
+ }
+ return num_elements;
+}
+
+uint64_t CountTensorConsumers(const ModelT* model, const SubGraphT* subgraph,
+ int32_t tensor_idx) {
+ uint64_t count = 0;
+ for (int op_idx = 0; op_idx < subgraph->operators.size(); ++op_idx) {
+ const OperatorT* op = subgraph->operators[op_idx].get();
+ if (op == nullptr) {
+ continue;
+ }
+ for (int i = 0; i < op->inputs.size(); ++i) {
+ if (op->inputs[i] == tensor_idx) {
+ count++;
+ }
+ }
+ }
+ return count;
+}
+
+// Gets the list of op->inputs indices of the weights inputs to be quantized for
+// the provided op.
+std::vector<int32_t> GetWeightInputIndices(const BuiltinOperator& op_code) {
+ if (op_code == BuiltinOperator_CONV_2D ||
+ op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+ op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP) {
+ return {1};
+ } else if (op_code == BuiltinOperator_SVDF) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/svdf.cc
+ return {1, 2};
+ } else if (op_code == BuiltinOperator_LSTM ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/lstm.cc
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16};
+ } else if (op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/basic_rnn.cc
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc
+ return {1, 2};
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+ return {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16,
+ 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 33};
+ } else if (op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN) {
+ // https://www.tensorflow.org/code/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+ return {1, 2, 4, 5};
+ }
+ return {};
+}
+
+// Returns true if the operator supports hybrid evaluation.
+bool IsHybridEvaluationOp(const OperatorT* op, const BuiltinOperator& op_code) {
+ // Operations that support hybrid evaluation.
+ bool eval_hybrid = false;
+ if (op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_CONV_2D || op_code == BuiltinOperator_SVDF ||
+ op_code == BuiltinOperator_EMBEDDING_LOOKUP ||
+ op_code == BuiltinOperator_RNN ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM ||
+ op_code == BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM ||
+ op_code == BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN) {
+ eval_hybrid = true;
+ } else if (op_code == BuiltinOperator_LSTM) {
+ const LSTMOptionsT* options = op->builtin_options.AsLSTMOptions();
+ // Only lstm kernel_type full supports hybrid evaluation.
+ if (options->kernel_type == LSTMKernelType_FULL) {
+ eval_hybrid = true;
+ }
+ }
+ return eval_hybrid;
+}
+
+// Returns a vector of TensorInfos for each input tensor of op that should be
+// quantized.
+std::vector<TensorInfo> GetQuantizableTensorsFromOperator(
+ const ModelT* model, const OperatorT* op, uint64_t weights_min_num_elements,
+ bool use_hybrid_evaluation) {
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+ const BuiltinOperator op_code =
+ model->operator_codes[op->opcode_index]->builtin_code;
+
+ std::vector<TensorInfo> tensor_infos;
+
+ bool eval_hybrid = use_hybrid_evaluation && IsHybridEvaluationOp(op, op_code);
+
+ std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code);
+ for (const int32_t op_input_idx : op_input_indices) {
+ int32_t tensor_idx = op->inputs[op_input_idx];
+
+ if (tensor_idx == -1) {
+ LOG(INFO) << "Skipping optional tensor input " << op_input_idx
+ << " of operation " << EnumNameBuiltinOperator(op_code);
+ continue;
+ }
+
+ TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ // TODO(suharshs): Support shared weights, i.e. If two tensors share the
+ // same weight array, things may break. (i.e. SSD object detection)
+ if (!eval_hybrid &&
+ CountTensorConsumers(model, subgraph, tensor_idx) != 1) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " that is shared between multiple multiple operations.";
+ continue;
+ }
+
+ if (tensor->type != TensorType_FLOAT32) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " that is not type float.";
+ continue;
+ }
+
+ const uint64_t num_elements = NumElements(tensor);
+ if (num_elements < weights_min_num_elements) {
+ LOG(INFO) << "Skipping quantization of tensor " << tensor->name
+ << " because it has fewer than " << weights_min_num_elements
+ << " elements (" << num_elements << ").";
+ // If one of the weights isn't quantized, then we cannot use the hybrid
+ // kernel for this operation, since it expects everything to be quantized.
+ eval_hybrid = false;
+ continue;
+ }
+
+ TensorInfo tensor_info;
+ tensor_info.eval_hybrid = eval_hybrid;
+ tensor_info.op_input_idx = op_input_idx;
+ tensor_info.tensor_idx = tensor_idx;
+ tensor_info.tensor = tensor;
+
+ tensor_infos.push_back(tensor_info);
+ }
+
+ return tensor_infos;
+}
+
+// Quantizes tensor using asymmetric quantization with the min and max elements
+// of the tensor. This is needed to pass to Dequantize operations.
+TfLiteStatus AsymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* float_data = reinterpret_cast<float*>(buffer->data.data());
+ const uint64_t num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
+ << " elements for float evaluation.";
+
+ // Compute the quantization params.
+ float min_value = *std::min_element(float_data, float_data + num_elements);
+ float max_value = *std::max_element(float_data, float_data + num_elements);
+
+ if (tensor->quantization == nullptr) {
+ tensor->quantization = absl::make_unique<QuantizationParametersT>();
+ }
+ GetAsymmetricQuantizationParams(min_value, max_value, 0, 255,
+ tensor->quantization.get());
+
+ // Quantize the buffer.
+ std::vector<uint8_t> quantized_buffer;
+ quantized_buffer.resize(num_elements);
+ const double inverse_scale = 1. / tensor->quantization->scale[0];
+ for (std::size_t i = 0; i < num_elements; i++) {
+ const float src_val = float_data[i];
+ double scaled_val;
+ if (tensor->quantization->scale[0] == 0) {
+ scaled_val = tensor->quantization->zero_point[0];
+ } else {
+ scaled_val =
+ tensor->quantization->zero_point[0] + inverse_scale * src_val;
+ }
+ uint8_t integer_val = static_cast<uint8_t>(std::round(scaled_val));
+ quantized_buffer[i] = integer_val;
+ }
+ model->buffers[tensor->buffer]->data = quantized_buffer;
+
+ // Update the tensor type.
+ tensor->type = TensorType_UINT8;
+
+ return kTfLiteOk;
+}
+
+// Quantizes tensor using symmetric quantization with the min and max elements
+// of the tensor. This is need for operations with hybrid evaluation
+// implemented.
+TfLiteStatus SymmetricQuantizeTensor(ModelT* model, TensorT* tensor) {
+ BufferT* buffer = model->buffers[tensor->buffer].get();
+ float* float_data = reinterpret_cast<float*>(buffer->data.data());
+ const uint64_t num_elements = NumElements(tensor);
+ LOG(INFO) << "Quantizing tensor " << tensor->name << " with " << num_elements
+ << " elements for hybrid evaluation.";
+
+ std::vector<int8_t> quantized_buffer;
+ quantized_buffer.resize(num_elements);
+
+ float min_value, max_value, scaling_factor;
+ tensor_utils::SymmetricQuantizeFloats(float_data, num_elements,
+ quantized_buffer.data(), &min_value,
+ &max_value, &scaling_factor);
+
+ if (tensor->quantization == nullptr) {
+ tensor->quantization = absl::make_unique<QuantizationParametersT>();
+ }
+ tensor->quantization->scale = std::vector<float>(1, scaling_factor);
+ tensor->quantization->zero_point = std::vector<int64_t>(1, 0);
+
+ uint8_t* uint8_buffer = reinterpret_cast<uint8_t*>(quantized_buffer.data());
+ model->buffers[tensor->buffer]->data.assign(uint8_buffer,
+ uint8_buffer + num_elements);
+
+ // Update the tensor type.
+ tensor->type = TensorType_UINT8;
+
+ return kTfLiteOk;
+}
+
+// Returns the index of the Dequantize op_code.
+// If a Dequantize op_code doesn't exist, adds it and returns its index.
+int32_t GetOrInsertDequantizeOpCodeIndex(ModelT* model) {
+ for (int i = 0; i < model->operator_codes.size(); ++i) {
+ if (model->operator_codes[i]->builtin_code == BuiltinOperator_DEQUANTIZE) {
+ return i;
+ }
+ }
+ model->operator_codes.push_back(absl::make_unique<OperatorCodeT>());
+ int op_code_idx = model->operator_codes.size() - 1;
+ model->operator_codes[op_code_idx]->builtin_code = BuiltinOperator_DEQUANTIZE;
+ // TODO(suharshs): How should the version be set in this op_code?
+
+ // Return the index of the newly placed OperatorCodeT.
+ return op_code_idx;
+}
+
+// Creates a Dequantize OperatorT object.
+void MakeDequantizeOperator(ModelT* model, std::unique_ptr<OperatorT>* op,
+ int32_t input, int32_t output) {
+ OperatorT* op_raw = new OperatorT;
+ op_raw->opcode_index = GetOrInsertDequantizeOpCodeIndex(model);
+ op_raw->inputs = {input};
+ op_raw->outputs = {output};
+
+ op->reset(op_raw);
+}
+
+// Create a new TensorT object.
+void MakeTensor(const string& name, const std::vector<int32_t>& shape,
+ std::unique_ptr<TensorT>* tensor) {
+ TensorT* tensor_raw = new TensorT;
+ tensor_raw->name = name;
+ tensor_raw->shape = shape;
+
+ tensor->reset(tensor_raw);
+}
+
+TfLiteStatus QuantizeWeightsInternal(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation,
+ uint64_t weights_min_num_elements) {
+ std::unique_ptr<ModelT> model;
+ model.reset(input_model->UnPack());
+
+ // TODO(suharshs): When models support multiple subgraphs, add support.
+ if (model->subgraphs.size() != 1) {
+ LOG(ERROR) << "Quantize weights tool only supports tflite models with one "
+ "subgraph.";
+ return kTfLiteError;
+ }
+
+ SubGraphT* subgraph = model->subgraphs.at(0).get();
+
+ std::vector<std::unique_ptr<OperatorT>> new_operators;
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+
+ std::vector<TensorInfo> tensor_infos = GetQuantizableTensorsFromOperator(
+ model.get(), op, weights_min_num_elements, use_hybrid_evaluation);
+
+ for (const TensorInfo& tensor_info : tensor_infos) {
+ if (tensor_info.eval_hybrid) {
+ // Quantize the tensor.
+ TF_LITE_ENSURE_STATUS(
+ SymmetricQuantizeTensor(model.get(), tensor_info.tensor));
+ } else {
+ // Quantize the tensor.
+ TF_LITE_ENSURE_STATUS(
+ AsymmetricQuantizeTensor(model.get(), tensor_info.tensor));
+
+ // Create a new tensor to be the output of the dequantize op.
+ std::unique_ptr<TensorT> dequantize_output;
+ MakeTensor(tensor_info.tensor->name + "_dequantize",
+ tensor_info.tensor->shape, &dequantize_output);
+ const int32_t dequantize_output_idx = subgraph->tensors.size();
+ subgraph->tensors.push_back(std::move(dequantize_output));
+
+ // Create the Dequantize operation.
+ std::unique_ptr<OperatorT> dequantize_op;
+ MakeDequantizeOperator(model.get(), &dequantize_op,
+ tensor_info.tensor_idx, dequantize_output_idx);
+
+ // Update the op_input of tensor_idx to dequantize_output_idx.
+ op->inputs[tensor_info.op_input_idx] = dequantize_output_idx;
+
+ // Insert the newly created Dequantize operation.
+ new_operators.push_back(std::move(dequantize_op));
+ }
+ }
+ // After (maybe) quantizing inputs, we copy the operator into the new list.
+ new_operators.push_back(std::move(subgraph->operators[i]));
+ }
+
+ // At this point all unique_ptrs in the original operators are invalid, and
+ // we need to replace it with the new_operators vector.
+ subgraph->operators = std::move(new_operators);
+
+ flatbuffers::Offset<Model> output_model_location =
+ Model::Pack(*builder, model.get());
+ FinishModelBuffer(*builder, output_model_location);
+
+ return kTfLiteOk;
+}
+
+} // namespace
+
+namespace internal {
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation) {
+ // By default we require that only weights with more than
+ // kWeightsMinSizeDefault elements are quantized.
+ return QuantizeWeightsInternal(builder, input_model, use_hybrid_evaluation,
+ kWeightsMinNumElementsDefault);
+}
+} // namespace internal
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements) {
+ return QuantizeWeightsInternal(builder, input_model, true,
+ weights_min_num_elements);
+}
+
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model) {
+ // By default we require that only weights with more than
+ // kWeightsMinSizeDefault elements are quantized.
+ return QuantizeWeightsInternal(builder, input_model, true,
+ kWeightsMinNumElementsDefault);
+}
+
+} // namespace optimize
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.h b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
new file mode 100644
index 0000000000..706f10b87b
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.h
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
+#define TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
+
+#include <memory>
+#include "flatbuffers/flexbuffers.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace optimize {
+
+// Quantizes input_model and populates the provided builder with the new model.
+// By default only weights tensors weight more than 1024 elements will be
+// quantized.
+//
+// A tflite::Model can be obtained from the builder with:
+// const uint8_t* buffer = builder->GetBufferPointer();
+// tflite::Model* model = GetModel(buffer);
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model);
+
+// Same as above, but only weights with greater than or equal
+// weights_min_num_elements elements will be quantized.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ uint64_t weights_min_num_elements);
+
+namespace internal {
+// If use_hybrid_evaluation is false, will disable using hybrid eval for
+// operations that support it.
+//
+// We use this internal QuantizeWeights call to test models with hybrid
+// evaluation disabled.
+TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
+ const Model* input_model,
+ bool use_hybrid_evaluation);
+} // namespace internal
+
+} // namespace optimize
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_TOOLS_OPTIMIZE_QUANTIZE_WEIGHTS_H_
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
new file mode 100644
index 0000000000..387b3471c2
--- /dev/null
+++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights_test.cc
@@ -0,0 +1,226 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/tools/optimize/quantize_weights.h"
+
+#include <memory>
+
+#include "flatbuffers/flexbuffers.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+namespace optimize {
+namespace {
+
+class QuantizeWeightsTest : public ::testing::Test {
+ protected:
+ int GetElementsNum(const TensorT* tensor) {
+ int tensor_size = 1;
+ for (const int dim : tensor->shape) {
+ tensor_size *= dim;
+ }
+ return tensor_size;
+ }
+
+ const OperatorT* GetOpWithOutput(const SubGraphT* subgraph,
+ int32_t output_tensor_idx) {
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+ if (std::find(op->outputs.begin(), op->outputs.end(),
+ output_tensor_idx) != op->outputs.end()) {
+ return op;
+ }
+ }
+ return nullptr;
+ }
+
+ void SymmetricDequantizeAndCompare(const BufferT* input_buffer,
+ const BufferT* output_buffer,
+ float scale) {
+ const float* input_buffer_data =
+ reinterpret_cast<const float*>(input_buffer->data.data());
+ const int8_t* output_buffer_data =
+ reinterpret_cast<const int8_t*>(output_buffer->data.data());
+ for (int i = 0; i < output_buffer->data.size(); i++) {
+ float diff = input_buffer_data[i] - (output_buffer_data[i] * scale);
+ ASSERT_TRUE(std::abs(diff) <= scale);
+ }
+ }
+
+ void AsymmetricDequantizeAndCompare(const BufferT* input_buffer,
+ const BufferT* output_buffer, float scale,
+ int64_t zero_point) {
+ const float* input_buffer_data =
+ reinterpret_cast<const float*>(input_buffer->data.data());
+ const uint8_t* output_buffer_data = output_buffer->data.data();
+ for (int i = 0; i < output_buffer->data.size(); i++) {
+ float diff =
+ input_buffer_data[i] - ((output_buffer_data[i] - zero_point) * scale);
+ ASSERT_TRUE(std::abs(diff) <= scale);
+ }
+ }
+
+ void CheckWeights(const Model* input_model_packed,
+ const Model* output_model_packed,
+ bool use_hybrid_evaluation,
+ uint64_t weights_min_num_elements = 1024) {
+ std::unique_ptr<ModelT> input_model;
+ input_model.reset(input_model_packed->UnPack());
+
+ std::unique_ptr<ModelT> output_model;
+ output_model.reset(output_model_packed->UnPack());
+
+ SubGraphT* subgraph = output_model->subgraphs.at(0).get();
+
+ for (int i = 0; i < subgraph->operators.size(); ++i) {
+ OperatorT* op = subgraph->operators[i].get();
+ const BuiltinOperator op_code =
+ output_model->operator_codes[op->opcode_index]->builtin_code;
+
+ // These are the operations that should be quantized.
+ // TODO(suharshs): Right now this test only checks the relevant operations
+ // for the mobilenet v1 model used in the tests below.
+ int32_t tensor_idx;
+ if (op_code == BuiltinOperator_CONV_2D ||
+ op_code == BuiltinOperator_DEPTHWISE_CONV_2D ||
+ op_code == BuiltinOperator_FULLY_CONNECTED) {
+ tensor_idx = op->inputs[1];
+ } else {
+ continue;
+ }
+
+ bool eval_hybrid = false;
+ // These are the ops that support hybrid evaluation.
+ if (op_code == BuiltinOperator_FULLY_CONNECTED ||
+ op_code == BuiltinOperator_CONV_2D) {
+ eval_hybrid = true;
+ }
+
+ const TensorT* tensor = subgraph->tensors[tensor_idx].get();
+ int tensor_size = GetElementsNum(tensor);
+ // If the tensor_size is less than 1024 we expect the tensor to remain
+ // unquantized.
+ if (tensor_size < weights_min_num_elements) {
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32)
+ << tensor->name << " of type " << tensor->type;
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ // The weight tensor should not come from a dequantize op.
+ ASSERT_TRUE(preceding_op == nullptr);
+ } else if (use_hybrid_evaluation && eval_hybrid) {
+ // The input to the op should still be uint8.
+ ASSERT_TRUE(tensor->type == TensorType_UINT8) << tensor->name;
+ // The weight tensor should not come from a dequantize op.
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ ASSERT_TRUE(preceding_op == nullptr);
+
+ // Test symmetric quantization.
+ SymmetricDequantizeAndCompare(
+ input_model->buffers[tensor->buffer].get(),
+ output_model->buffers[tensor->buffer].get(),
+ tensor->quantization->scale[0]);
+
+ } else {
+ // The input to the op should still be float.
+ ASSERT_TRUE(tensor->type == TensorType_FLOAT32) << tensor->name;
+ const OperatorT* preceding_op = GetOpWithOutput(subgraph, tensor_idx);
+ ASSERT_TRUE(preceding_op != nullptr);
+ // The float input should be the dequantize output.
+ ASSERT_TRUE(output_model->operator_codes[preceding_op->opcode_index]
+ ->builtin_code == BuiltinOperator_DEQUANTIZE);
+ // Finally, ensure that the input to the dequantize operation is
+ // quantized.
+ const TensorT* quantized_tensor =
+ subgraph->tensors[preceding_op->inputs[0]].get();
+ ASSERT_TRUE(quantized_tensor->type == TensorType_UINT8);
+
+ // Test the assymetric quantization.
+ AsymmetricDequantizeAndCompare(
+ input_model->buffers[quantized_tensor->buffer].get(),
+ output_model->buffers[quantized_tensor->buffer].get(),
+ quantized_tensor->quantization->scale[0],
+ quantized_tensor->quantization->zero_point[0]);
+ }
+ }
+ }
+};
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithHybrid) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ EXPECT_EQ(QuantizeWeights(&builder, input_model), kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+
+ CheckWeights(input_model, output_model, true);
+}
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithoutHybrid) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ // Disable hybrid evaluation.
+ EXPECT_EQ(internal::QuantizeWeights(&builder, input_model, false), kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+
+ CheckWeights(input_model, output_model, false);
+}
+
+TEST_F(QuantizeWeightsTest, SimpleTestWithWeightsMinNumElements) {
+ string model_path =
+ "third_party/tensorflow/contrib/lite/tools/optimize/testdata/"
+ "mobilenet_v1_0.25_128.tflite";
+ std::unique_ptr<FlatBufferModel> input_fb =
+ FlatBufferModel::BuildFromFile(model_path.data());
+ const Model* input_model = input_fb->GetModel();
+
+ flatbuffers::FlatBufferBuilder builder;
+ // Make weights_min_size sufficiently large such that no quantization should
+ // happen, i.e. the original model is the same size as the old one.
+ const uint64_t kWeightsMinNumElements = 1000000;
+ EXPECT_EQ(QuantizeWeights(&builder, input_model, kWeightsMinNumElements),
+ kTfLiteOk);
+
+ const uint8_t* buffer = builder.GetBufferPointer();
+ const Model* output_model = GetModel(buffer);
+ CheckWeights(input_model, output_model, true, kWeightsMinNumElements);
+}
+
+// TODO(suharshs): Add tests that run the resulting model.
+
+} // namespace
+} // namespace optimize
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ // On Linux, add: FLAGS_logtostderr = true;
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD
new file mode 100644
index 0000000000..67ff1ea124
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/BUILD
@@ -0,0 +1,20 @@
+# Example Estimator model
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "mnist_tflite",
+ srcs = [
+ "dataset.py",
+ "mnist_tflite.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py
new file mode 100644
index 0000000000..ba49dfcc9b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/dataset.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf.data.Dataset interface to the MNIST dataset.
+
+ This is cloned from
+ https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+
+def read32(bytestream):
+ """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
+ dt = np.dtype(np.uint32).newbyteorder('>')
+ return np.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+def check_image_file_header(filename):
+ """Validate that filename corresponds to images for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_images, unused
+ rows = read32(f)
+ cols = read32(f)
+ if magic != 2051:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+ if rows != 28 or cols != 28:
+ raise ValueError(
+ 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
+ (f.name, rows, cols))
+
+
+def check_labels_file_header(filename):
+ """Validate that filename corresponds to labels for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_items, unused
+ if magic != 2049:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+
+
+def download(directory, filename):
+ """Download (and unzip) a file from the MNIST dataset if not already done."""
+ filepath = os.path.join(directory, filename)
+ if tf.gfile.Exists(filepath):
+ return filepath
+ if not tf.gfile.Exists(directory):
+ tf.gfile.MakeDirs(directory)
+ # CVDF mirror of http://yann.lecun.com/exdb/mnist/
+ url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
+ _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
+ print('Downloading %s to %s' % (url, zipped_filepath))
+ urllib.request.urlretrieve(url, zipped_filepath)
+ with gzip.open(zipped_filepath, 'rb') as f_in, \
+ tf.gfile.Open(filepath, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ os.remove(zipped_filepath)
+ return filepath
+
+
+def dataset(directory, images_file, labels_file):
+ """Download and parse MNIST dataset."""
+
+ images_file = download(directory, images_file)
+ labels_file = download(directory, labels_file)
+
+ check_image_file_header(images_file)
+ check_labels_file_header(labels_file)
+
+ def decode_image(image):
+ # Normalize from [0, 255] to [0.0, 1.0]
+ image = tf.decode_raw(image, tf.uint8)
+ image = tf.cast(image, tf.float32)
+ image = tf.reshape(image, [784])
+ return image / 255.0
+
+ def decode_label(label):
+ label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
+ label = tf.reshape(label, []) # label is a scalar
+ return tf.to_int32(label)
+
+ images = tf.data.FixedLengthRecordDataset(
+ images_file, 28 * 28, header_bytes=16).map(decode_image)
+ labels = tf.data.FixedLengthRecordDataset(
+ labels_file, 1, header_bytes=8).map(decode_label)
+ return tf.data.Dataset.zip((images, labels))
+
+
+def train(directory):
+ """tf.data.Dataset object for MNIST training data."""
+ return dataset(directory, 'train-images-idx3-ubyte',
+ 'train-labels-idx1-ubyte')
+
+
+def test(directory):
+ """tf.data.Dataset object for MNIST test data."""
+ return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000000..7b8bf5b5db
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+ 'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+ 'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+ # Generates an iterator over images
+ with tf.Session() as sess:
+ input_data = dataset.test(
+ flags.data_dir).make_one_shot_iterator().get_next()
+ try:
+ while True:
+ yield sess.run(input_data)
+ except tf.errors.OutOfRangeError:
+ pass
+
+
+def run_eval(interpreter, input_image):
+ """Performs evaluation for input image over specified model.
+
+ Args:
+ interpreter: TFLite interpreter initialized with model to execute.
+ input_image: Image input to the model.
+
+ Returns:
+ output: output tensor of model being executed.
+ """
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Test model on the input images.
+ input_image = np.reshape(input_image, input_details[0]['shape'])
+ interpreter.set_tensor(input_details[0]['index'], input_image)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ output = np.squeeze(output_data)
+ return output
+
+
+def main(_):
+ interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+ interpreter.allocate_tensors()
+ num_correct, total = 0, 0
+ for input_data in test_image_generator():
+ output = run_eval(interpreter, input_data[0])
+ total += 1
+ if output == input_data[1]:
+ num_correct += 1
+ if total % 500 == 0:
+ print('Accuracy after %i images: %f' %
+ (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main)
diff --git a/tensorflow/contrib/lite/util.h b/tensorflow/contrib/lite/util.h
index f5b208afbb..6d81f844f8 100644
--- a/tensorflow/contrib/lite/util.h
+++ b/tensorflow/contrib/lite/util.h
@@ -22,7 +22,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_UTIL_H_
#include <vector>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
namespace tflite {
diff --git a/tensorflow/contrib/lite/util_test.cc b/tensorflow/contrib/lite/util_test.cc
index 32bf917a59..c5c1709f1d 100644
--- a/tensorflow/contrib/lite/util_test.cc
+++ b/tensorflow/contrib/lite/util_test.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
-#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
#include "tensorflow/contrib/lite/util.h"
namespace tflite {
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 0a54bb1f5e..89b538d1ba 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -44,7 +44,7 @@ from tensorflow.python.training.checkpointable import util as checkpointable
class HashTableOpTest(test.TestCase):
def testHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -68,7 +68,7 @@ class HashTableOpTest(test.TestCase):
self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
def testHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -86,7 +86,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
@@ -105,7 +105,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
@@ -122,7 +122,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -150,7 +150,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -165,7 +165,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -188,7 +188,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -210,7 +210,7 @@ class HashTableOpTest(test.TestCase):
lookup.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
with self.assertRaises(TypeError):
lookup.HashTable(
@@ -218,7 +218,7 @@ class HashTableOpTest(test.TestCase):
dtypes.int64), default_val)
def testNotInitialized(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
table = lookup.HashTable(
lookup.KeyValueTensorInitializer(
@@ -232,7 +232,7 @@ class HashTableOpTest(test.TestCase):
output.eval()
def testInitializeTwice(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -244,7 +244,7 @@ class HashTableOpTest(test.TestCase):
table.init.run()
def testInitializationWithInvalidDimensions(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2, 3, 4], dtypes.int64)
@@ -283,7 +283,7 @@ class HashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testHashTableInt32String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int32)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -301,7 +301,7 @@ class HashTableOpTest(test.TestCase):
class MutableHashTableOpTest(test.TestCase):
def testMutableHashTable(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -470,7 +470,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([b"-", b"a", b"b"], output.eval())
def testMutableHashTableOfTensors(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -500,7 +500,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
def testMutableHashTableExportInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1], [2, 3], [4, 5]], dtypes.int64)
@@ -531,7 +531,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(expected_output, output2.eval())
def testMutableHashTableOfTensorsInvalidShape(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
@@ -563,7 +563,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(3, table.size().eval())
def testMutableHashTableInvalidDefaultValue(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([[-1, -1]], dtypes.int64)
table = lookup.MutableHashTable(dtypes.string, dtypes.int64,
default_val)
@@ -571,7 +571,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual(0, table.size().eval())
def testMutableHashTableDuplicateInsert(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery", "brain"])
values = constant_op.constant([0, 1, 2, 3], dtypes.int64)
@@ -589,7 +589,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([3, 1, -1], result)
def testMutableHashTableFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -608,7 +608,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testMutableHashTableInsertHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant([["brain", "salad"], ["surgery", "tank"]])
values = constant_op.constant([[0, 1], [2, 3]], dtypes.int64)
@@ -625,7 +625,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, 3, -1], result)
def testMutableHashTableOfTensorsFindHighRank(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant([-1, -1, -1], dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([[0, 1, 2], [2, 3, 4], [4, 5, 6]],
@@ -646,7 +646,7 @@ class MutableHashTableOpTest(test.TestCase):
[[[0, 1, 2], [2, 3, 4]], [[-1, -1, -1], [-1, -1, -1]]], result)
def testMultipleMutableHashTables(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -676,7 +676,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testMutableHashTableWithTensorDefault(self):
- with self.test_session():
+ with self.cached_session():
default_val = constant_op.constant(-1, dtypes.int64)
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -693,7 +693,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testSignatureMismatch(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1, 2], dtypes.int64)
@@ -734,7 +734,7 @@ class MutableHashTableOpTest(test.TestCase):
lookup.MutableHashTable(dtypes.string, dtypes.int64, "UNK")
def testMutableHashTableStringFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.5
keys = constant_op.constant(["brain", "salad", "surgery"])
values = constant_op.constant([0, 1.1, 2.2], dtypes.float32)
@@ -752,7 +752,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, default_val], result)
def testMutableHashTableIntFloat(self):
- with self.test_session():
+ with self.cached_session():
default_val = -1.0
keys = constant_op.constant([3, 7, 0], dtypes.int64)
values = constant_op.constant([7.5, -1.2, 9.9], dtypes.float32)
@@ -770,7 +770,7 @@ class MutableHashTableOpTest(test.TestCase):
self.assertAllClose([-1.2, 9.9, default_val], result)
def testMutableHashTableInt64String(self):
- with self.test_session():
+ with self.cached_session():
default_val = "n/a"
keys = constant_op.constant([0, 1, 2], dtypes.int64)
values = constant_op.constant(["brain", "salad", "surgery"])
@@ -791,7 +791,7 @@ class MutableHashTableOpTest(test.TestCase):
class MutableDenseHashTableOpTest(test.TestCase):
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -809,7 +809,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testBasicBool(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([True, True, True], dtypes.bool)
table = lookup.MutableDenseHashTable(
@@ -827,7 +827,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([True, True, False], result)
def testLookupUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -843,7 +843,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testMapStringToFloat(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant(["a", "b", "c"], dtypes.string)
values = constant_op.constant([0.0, 1.1, 2.2], dtypes.float32)
default_value = constant_op.constant(-1.5, dtypes.float32)
@@ -866,7 +866,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
def testMapInt64ToFloat(self):
for float_dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0.0, 1.1, 2.2], float_dtype)
default_value = constant_op.constant(-1.5, float_dtype)
@@ -885,7 +885,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllClose([0, 1.1, -1.5], result)
def testVectorValues(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([[0, 1, 2, 3], [3, 4, 5, 6], [6, 7, 8, 9]],
dtypes.int64)
@@ -918,7 +918,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
result)
def testVectorKeys(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([[0, 1], [1, 2], [1, 3]], dtypes.int64)
values = constant_op.constant([10, 11, 12], dtypes.int64)
empty_key = constant_op.constant([0, 3], dtypes.int64)
@@ -949,7 +949,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([10, 11, -1], result)
def testResize(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -977,7 +977,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 0, 1, 3, 4, 5, 6, 7, -1], output.eval())
def testExport(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 12, 13], dtypes.int64)
values = constant_op.constant([1, 2, 3], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1238,7 +1238,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1, 2, -1], output.eval())
def testReprobe(self):
- with self.test_session():
+ with self.cached_session():
# Insert 6 keys into a table with 8 buckets.
# The values are chosen to make sure collisions occur when using GCC STL
keys = constant_op.constant([11, 12, 13, 19, 20, 21], dtypes.int64)
@@ -1263,7 +1263,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([-1, 51, 52, 53, -1, 54, 55, 56, -1], result)
def testCustomEmptyKey(self):
- with self.test_session():
+ with self.cached_session():
keys = constant_op.constant([11, 0, 13], dtypes.int64)
values = constant_op.constant([0, 1, 2], dtypes.int64)
table = lookup.MutableDenseHashTable(
@@ -1281,7 +1281,7 @@ class MutableDenseHashTableOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], result)
def testErrors(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.MutableDenseHashTable(
dtypes.int64, dtypes.int64, default_value=-1, empty_key=0)
@@ -1328,7 +1328,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1339,7 +1339,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_tensor_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_file = constant_op.constant(vocabulary_file)
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1)
@@ -1353,7 +1353,7 @@ class IndexTableFromFile(test.TestCase):
def test_string_index_table_from_file_placeholder_filename(self):
vocabulary_file = self._createVocabFile("f2i_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
vocabulary_placeholder = array_ops.placeholder(dtypes.string, [])
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_placeholder, num_oov_buckets=1)
@@ -1370,7 +1370,7 @@ class IndexTableFromFile(test.TestCase):
def test_int32_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab2.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int32)
@@ -1384,7 +1384,7 @@ class IndexTableFromFile(test.TestCase):
def test_int64_index_table_from_file(self):
vocabulary_file = self._createVocabFile(
"f2i_vocab3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1,
key_dtype=dtypes.int64)
@@ -1398,7 +1398,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_default_value(self):
default_value = -42
vocabulary_file = self._createVocabFile("f2i_vocab4.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1409,7 +1409,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_oov_buckets(self):
vocabulary_file = self._createVocabFile("f2i_vocab5.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, num_oov_buckets=1000)
ids = table.lookup(
@@ -1439,7 +1439,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_small(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=2)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1451,7 +1451,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
@@ -1466,7 +1466,7 @@ class IndexTableFromFile(test.TestCase):
vocabulary_file=vocabulary_file,
vocab_size=0)
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1478,7 +1478,7 @@ class IndexTableFromFile(test.TestCase):
def test_index_table_from_file_with_invalid_hashers(self):
vocabulary_file = self._createVocabFile("invalid_hasher.txt")
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_file(
vocabulary_file=vocabulary_file,
@@ -1499,21 +1499,21 @@ class IndexTableFromFile(test.TestCase):
class KeyValueTensorInitializerTest(test.TestCase):
def test_string(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
("brain", "salad", "surgery"), (0, 1, 2), dtypes.string, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int64(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int64, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
table.init.run()
def test_int32(self):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
init = lookup.KeyValueTensorInitializer(
(42, 1, -1000), (0, 1, 2), dtypes.int32, dtypes.int64)
table = lookup.HashTable(init, default_value=-1)
@@ -1542,7 +1542,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), self.evaluate(ids))
def test_int32_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int32)
ids = table.lookup(
@@ -1553,7 +1553,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), num_oov_buckets=1, dtype=dtypes.int64)
ids = table.lookup(
@@ -1565,7 +1565,7 @@ class IndexTableFromTensor(test.TestCase):
def test_index_table_from_tensor_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"], default_value=default_value)
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
@@ -1575,12 +1575,12 @@ class IndexTableFromTensor(test.TestCase):
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(ValueError, "mapping must be specified"):
lookup.index_table_from_tensor(mapping=None, num_oov_buckets=1)
def test_index_table_from_tensor_empty_mapping(self):
- with self.test_session():
+ with self.cached_session():
table = lookup.index_table_from_tensor(
mapping=np.array([], dtype=np.str_), num_oov_buckets=1)
ids = table.lookup(constant_op.constant(["salad", "surgery", "brain"]))
@@ -1590,7 +1590,7 @@ class IndexTableFromTensor(test.TestCase):
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(TypeError):
lookup.index_table_from_tensor(
mapping=["brain", "salad", "surgery"],
@@ -1609,7 +1609,7 @@ class IndexTableFromTensor(test.TestCase):
class StringToIndexTest(test.TestCase):
def test_string_to_index(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1620,7 +1620,7 @@ class StringToIndexTest(test.TestCase):
self.assertAllEqual((1, 2, -1), indices.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
feats = constant_op.constant(["hello", "hola"])
_ = lookup.string_to_index(feats, mapping=mapping_strings)
@@ -1630,7 +1630,7 @@ class StringToIndexTest(test.TestCase):
def test_string_to_index_with_default_value(self):
default_value = -42
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
feats = constant_op.constant(["salad", "surgery", "tarkus"])
indices = lookup.string_to_index(
@@ -1651,7 +1651,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table(self):
vocabulary_file = self._createVocabFile("i2f_vocab1.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
@@ -1663,7 +1663,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_default_value(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1675,7 +1675,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_small(self):
default_value = b"NONE"
vocabulary_file = self._createVocabFile("f2i_vocab2.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file,
vocab_size=2,
@@ -1688,7 +1688,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size_too_large(self):
vocabulary_file = self._createVocabFile("f2i_vocab6.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=4)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1700,7 +1700,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
def test_index_to_string_table_with_vocab_size(self):
vocabulary_file = self._createVocabFile("f2i_vocab7.txt")
- with self.test_session():
+ with self.cached_session():
table = lookup.index_to_string_table_from_file(
vocabulary_file=vocabulary_file, vocab_size=3)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
@@ -1713,7 +1713,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_table_from_tensor(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1727,7 +1727,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings)
@@ -1738,7 +1738,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
table = lookup.index_to_string_table_from_tensor(
mapping=mapping_strings, default_value=default_value)
@@ -1754,7 +1754,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
class IndexToStringTest(test.TestCase):
def test_index_to_string(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1766,7 +1766,7 @@ class IndexToStringTest(test.TestCase):
feats.eval())
def test_duplicate_entries(self):
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
@@ -1778,7 +1778,7 @@ class IndexToStringTest(test.TestCase):
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
- with self.test_session():
+ with self.cached_session():
mapping_strings = constant_op.constant(["brain", "salad", "surgery"])
indices = constant_op.constant([1, 2, 4], dtypes.int64)
feats = lookup.index_to_string(
@@ -1818,7 +1818,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
vocabulary_file = self._createVocabFile(
"one_column_int64.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer(vocabulary_file, dtypes.int64,
@@ -1837,7 +1837,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeIndexTable(self):
vocabulary_file = self._createVocabFile("one_column_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.LINE_NUMBER
value_index = lookup.TextFileIndex.WHOLE_LINE
@@ -1858,7 +1858,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1
value_index = 2
@@ -1880,7 +1880,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
with open(vocabulary_file, "w") as f:
f.write("\n".join(["0\tbrain\t1", "1\tsalad\t5", "2\tsurgery\t6"]) + "\n")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 2
value_index = 1
@@ -1894,7 +1894,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidDataType(self):
vocabulary_file = self._createVocabFile("one_column_3.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
key_index = lookup.TextFileIndex.WHOLE_LINE
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1907,7 +1907,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidIndex(self):
vocabulary_file = self._createVocabFile("one_column_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
key_index = 1 # second column of the line
value_index = lookup.TextFileIndex.LINE_NUMBER
@@ -1922,7 +1922,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInitializeSameTableWithMultipleNodes(self):
vocabulary_file = self._createVocabFile("one_column_5.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shared_name = "shared-one-columm"
default_value = -1
table1 = lookup.HashTable(
@@ -1961,7 +1961,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
self.assertAllEqual([0, 1, -1], out3)
def testInitializeTableWithNoFilename(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
with self.assertRaises(ValueError):
lookup.HashTable(
@@ -1971,7 +1971,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value)
def testInitializeWithVocabSize(self):
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
vocabulary_file1 = self._createVocabFile("one_column6.txt")
@@ -2022,7 +2022,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testFeedVocabularyName(self):
vocabulary_file = self._createVocabFile("feed_vocabulary.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
table = lookup.HashTable(
lookup.TextFileInitializer("old_file.txt", dtypes.string,
@@ -2049,7 +2049,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInvalidFilenames(self):
vocabulary_file = self._createVocabFile("filename_shape.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
# Invalid data type
@@ -2072,7 +2072,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testIdToStringTable(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = "UNK"
vocab_size = 3
table = lookup.HashTable(
@@ -2090,7 +2090,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testStringToIdTable(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2108,7 +2108,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
def testInt64ToIdTable(self):
vocab_file = self._createVocabFile(
"feat_to_id_3.txt", values=("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
table = lookup.HashTable(
@@ -2133,7 +2133,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testStringIdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_1.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2154,7 +2154,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_2.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2176,7 +2176,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64IdTableWithHashBuckets(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt", ("42", "1", "-1000"))
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2196,7 +2196,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testStringIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2217,7 +2217,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testInt32IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
oov_buckets = 5
# Set a table that only uses hash buckets, for each input value returns
@@ -2239,20 +2239,20 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertEquals(oov_buckets, table.size().eval())
def testFloat64IdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.float64)
def testBoolIdTableWithOnlyHashBucket(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesRegexp(TypeError, "Invalid key_dtype"):
lookup.IdTableWithHashBuckets(
None, num_oov_buckets=5, key_dtype=dtypes.bool)
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
@@ -2294,7 +2294,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2316,7 +2316,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
@@ -2340,7 +2340,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
@@ -2378,7 +2378,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant(["brain", "salad", "brain", "surgery", "tarkus"],
@@ -2407,7 +2407,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt32SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int32),
@@ -2436,7 +2436,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testInt64SparseTensor(self):
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sp_features = sparse_tensor.SparseTensor(
constant_op.constant(input_indices, dtypes.int64),
constant_op.constant([42, 1, 42, -1000, 11], dtypes.int64),
@@ -2464,7 +2464,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
- with self.test_session():
+ with self.cached_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
diff --git a/tensorflow/contrib/losses/python/losses/loss_ops_test.py b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
index 2a442a8fc8..c0aec09778 100644
--- a/tensorflow/contrib/losses/python/losses/loss_ops_test.py
+++ b/tensorflow/contrib/losses/python/losses/loss_ops_test.py
@@ -43,68 +43,68 @@ class AbsoluteDifferenceLossTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.absolute_difference(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.absolute_difference(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.absolute_difference(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.absolute_difference(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2,])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 0.0], shape=[2, 1])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(5.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(16.6, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(6.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.absolute_difference(self._predictions, self._labels,
weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -117,12 +117,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
labels = constant_op.constant([[1, 0, 0],
[0, 1, 0],
[0, 0, 1]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -141,7 +141,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -154,7 +154,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -166,7 +166,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels,
constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -179,7 +179,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -191,7 +191,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -203,12 +203,12 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
[1, 0, 0],
[0, 1, 0]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testSoftmaxWithMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -223,7 +223,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
loss_ops.softmax_cross_entropy(logits, labels, weights=weights).eval()
def testSoftmaxLabelSmoothing(self):
- with self.test_session():
+ with self.cached_session():
# Softmax Cross Entropy Loss is:
# -\sum_i p_i \log q_i
# where for a softmax activation
@@ -253,7 +253,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights = [2.3, 2.4, 2.5]
weights_placeholder = array_ops.placeholder(dtypes.float32, shape=[None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -268,7 +268,7 @@ class SoftmaxCrossEntropyLossTest(test.TestCase):
weights_placeholder = array_ops.placeholder(
dtypes.float32, shape=[None, None])
loss = loss_ops.softmax_cross_entropy(logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -280,12 +280,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[0], [1], [2]])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.sparse_softmax_cross_entropy(logits, labels, weights=None)
def testAllCorrectInt32Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -295,7 +295,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectInt64Labels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -305,7 +305,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testAllCorrectNonColumnLabels(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[10.0, 0.0, 0.0],
[0.0, 10.0, 0.0],
[0.0, 0.0, 10.0]])
@@ -320,7 +320,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int32)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -331,7 +331,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]], dtype=dtypes.int64)
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -342,7 +342,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([2, 0, 1])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sparse_softmax_cross_entropy_loss/value')
self.assertAlmostEqual(loss.eval(), 10.0, 3)
@@ -353,7 +353,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -363,7 +363,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = 2.3
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, constant_op.constant(weights))
self.assertAlmostEqual(weights * 10.0, loss.eval(), 3)
@@ -374,7 +374,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 3.4, 5.6], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -384,7 +384,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([[1.2], [3.4], [5.6]])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual((1.2 + 3.4 + 5.6) * 10.0 / 3.0, loss.eval(), 3)
@@ -394,7 +394,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([0, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -404,12 +404,12 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
[0.0, 0.0, 10.0]])
labels = constant_op.constant([[2], [0], [1]])
weights = constant_op.constant([1.2, 0, 0], shape=[3])
- with self.test_session():
+ with self.cached_session():
loss = loss_ops.sparse_softmax_cross_entropy(logits, labels, weights)
self.assertAlmostEqual(12.0, loss.eval(), 3)
def testMeasurementSpecificWeightsRaisesException(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -422,7 +422,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightSizeRaisesException(self):
"""The weight tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -435,7 +435,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelSizeRaisesException(self):
"""The label tensor has incorrect number of elements."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -448,7 +448,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentWeightShapeRaisesException(self):
"""The weight tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -462,7 +462,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
def testInconsistentLabelShapeRaisesException(self):
"""The label tensor has incorrect shape."""
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0, -100.0],
[-100.0, -100.0, 100.0, -100.0],
@@ -484,7 +484,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -498,7 +498,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
dtypes.float32, shape=[None, None])
loss = loss_ops.sparse_softmax_cross_entropy(
logits, labels, weights_placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, {weights_placeholder: weights})
self.assertAlmostEqual(np.average(weights) * 10.0, loss, 3)
@@ -506,7 +506,7 @@ class SparseSoftmaxCrossEntropyLossTest(test.TestCase):
class SigmoidCrossEntropyLossTest(test.TestCase):
def testAllCorrectSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -522,7 +522,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 1)),
@@ -537,7 +537,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: np.ones((32, 2)),
@@ -546,7 +546,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(0.313, loss, 3)
def testAllWrongSigmoid(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -558,7 +558,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), 600.0 / 9.0, 3)
def testAllWrongSigmoidWithMeasurementSpecificWeights(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0],
[-100.0, 100.0, -100.0],
[-100.0, -100.0, 100.0]])
@@ -582,11 +582,11 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
loss = loss_ops.sigmoid_cross_entropy(logits, labels)
self.assertEquals(loss.op.name, 'sigmoid_cross_entropy_loss/value')
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(loss.eval(), 0.0, 3)
def testSigmoidLabelSmoothingCorrect(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[100.0, -100.0, -100.0]])
labels = constant_op.constant([[1, 0, 1]])
# Sigmoid cross entropy loss is:
@@ -608,7 +608,7 @@ class SigmoidCrossEntropyLossTest(test.TestCase):
self.assertAlmostEqual(loss.eval(), expected_value, 3)
def testSigmoidLabelSmoothingEqualsSoftmaxTwoLabel(self):
- with self.test_session():
+ with self.cached_session():
label_smoothing = 0.1
sigmoid_logits = constant_op.constant([[100.0, -100.0, -100.0]])
sigmoid_labels = constant_op.constant([[1, 0, 1]])
@@ -641,33 +641,33 @@ class LogLossTest(test.TestCase):
self._labels = constant_op.constant(labels)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._labels, self._labels, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.log_loss(self._labels, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testAllCorrectNoLossWeightWithPlaceholder(self):
tf_predictions = array_ops.placeholder(
dtypes.float32, shape=self._np_labels.shape)
loss = loss_ops.log_loss(tf_predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(
0.0, loss.eval(feed_dict={tf_predictions: self._np_labels}), 3)
def testNonZeroLoss(self):
loss = loss_ops.log_loss(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -675,7 +675,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss.eval(), 3)
@@ -685,7 +685,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -695,7 +695,7 @@ class LogLossTest(test.TestCase):
weights = 2.3
loss = loss_ops.log_loss(tf_predictions, self._labels,
constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(weights * -np.sum(self._expected_losses) / 6.0,
loss, 3)
@@ -706,7 +706,7 @@ class LogLossTest(test.TestCase):
self._expected_losses,
np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 6.0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsSomeZero(self):
@@ -715,7 +715,7 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeightsSomeZero(self):
@@ -724,12 +724,12 @@ class LogLossTest(test.TestCase):
np.asarray([1.2, 1.2, 1.2, 0, 0, 0]).reshape(
(2, 3)))
loss = loss_ops.log_loss(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 3.0, loss.eval(), 3)
def testWeightsWithSameNumDimsButWrongShapeThrowsException(self):
weights = constant_op.constant(np.random.normal(size=(2, 4)), shape=[2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.log_loss(self._predictions, self._labels, weights)
@@ -742,7 +742,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss.eval(), 3)
def testNonZeroLossWithMeasurementSpecificWeightsWithPlaceholder(self):
@@ -756,7 +756,7 @@ class LogLossTest(test.TestCase):
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses) / 5.0, loss, 3)
@@ -769,7 +769,7 @@ class LogLossTest(test.TestCase):
self._labels,
constant_op.constant(
weights, shape=(2, 3)))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(-np.sum(expected_losses), loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZeroWithPlaceholder(self):
@@ -780,35 +780,35 @@ class LogLossTest(test.TestCase):
tf_weights = constant_op.constant(weights, shape=(2, 3))
loss = loss_ops.log_loss(tf_predictions, self._labels, tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._np_predictions})
self.assertAlmostEqual(-np.sum(expected_losses), loss, 3)
def testLossWithSampleSpecificWeightsAllZero(self):
tf_weights = array_ops.zeros(shape=(2, 3))
loss = loss_ops.log_loss(self._predictions, self._labels, tf_weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
class HingeLossTest(test.TestCase):
def testIncompatibleShapes(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-1.0], [2.1]])
labels = constant_op.constant([0.0, 1.0])
with self.assertRaises(ValueError):
_ = loss_ops.hinge_loss(logits, labels).eval()
def testAllOutsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([1.2, -1.4, -1.0, 2.1])
labels = constant_op.constant([1.0, 0.0, 0.0, 1.0])
loss = loss_ops.hinge_loss(logits, labels)
self.assertAllClose(loss.eval(), [0.0, 0.0, 0.0, 0.0], atol=1e-3)
def testSomeInsideMargin(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[-0.7], [-1.4], [1.4], [0.6]])
labels = constant_op.constant([[0.0], [0.0], [1.0], [1.0]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -817,7 +817,7 @@ class HingeLossTest(test.TestCase):
self.assertAllClose(loss.eval(), [[0.3], [0.0], [0.0], [0.4]], atol=1e-3)
def testSomeMisclassified(self):
- with self.test_session():
+ with self.cached_session():
logits = constant_op.constant([[[1.2], [0.4], [-1.0], [-1.1]]])
labels = constant_op.constant([[[1.0], [0.0], [0.0], [1.0]]])
loss = loss_ops.hinge_loss(logits, labels)
@@ -834,62 +834,62 @@ class MeanSquaredErrorTest(test.TestCase):
self._labels = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_squared_error(
self._predictions, self._predictions, weights=None)
def testAllCorrectNoLossWeight(self):
loss = loss_ops.mean_squared_error(self._predictions, self._predictions)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_squared_error(self._predictions, self._labels)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5, loss.eval(), 3)
def testNonZeroLossWithPythonScalarWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeight(self):
weights = 2.3
loss = loss_ops.mean_squared_error(self._predictions, self._labels,
constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(49.5 * weights, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2,])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithTwoDimBatchSpecificWeights(self):
weights = constant_op.constant([1.2, 3.4], shape=[2, 1])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(767.8 / 6.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeights(self):
weights = constant_op.constant([3, 6, 5, 0, 4, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(587 / 5.0, loss.eval(), 3)
def testNonZeroLossWithSampleSpecificWeightsMostZero(self):
weights = constant_op.constant([0, 0, 0, 0, 0, 2], shape=[2, 3])
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(18.0, loss.eval(), 3)
def testLossWithSampleSpecificWeightsAllZero(self):
weights = array_ops.zeros((2, 3))
loss = loss_ops.mean_squared_error(self._predictions, self._labels, weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
@@ -914,7 +914,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
self._expected_losses = np.divide(total, 9.0)
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
@@ -925,14 +925,14 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testNonZeroLoss(self):
loss = loss_ops.mean_pairwise_squared_error(
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(self._expected_losses), loss.eval(), 3)
def testGradientWithZeroWeight(self):
@@ -954,7 +954,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
init_op = variables.global_variables_initializer()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(init_op)
for grad, _ in gradients_to_variables:
np_grad = sess.run(grad)
@@ -966,7 +966,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=weights)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -976,7 +976,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(weights * np.sum(self._expected_losses),
loss.eval(), 3)
@@ -986,7 +986,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
weights=constant_op.constant(weights))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithScalarTensorWeightWithPlaceholder(self):
@@ -998,7 +998,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=tf_predictions,
labels=tf_labels,
weights=constant_op.constant(weights))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1015,7 +1015,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(np.sum(expected_losses), loss.eval(), 3)
def testZeroLossWithOneDimBatchZeroWeights(self):
@@ -1025,7 +1025,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 3)
def testNonZeroLossWithOneDimBatchSpecificWeightsAndPlaceholders(self):
@@ -1041,7 +1041,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss,
feed_dict={
tf_predictions: self._predictions,
@@ -1056,7 +1056,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
labels=constant_op.constant(self._labels),
weights=constant_op.constant(
weights, shape=[2]))
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0.0, loss.eval(), 3)
def testLossIsAssociativeAcrossBatchElements(self):
@@ -1087,7 +1087,7 @@ class MeanPairwiseSquaresErrorTest(test.TestCase):
predictions=array_ops.concat([predictions0, predictions1], 0),
labels=array_ops.concat([labels0, labels1], 0))
- with self.test_session() as session:
+ with self.cached_session() as session:
loss0, loss1, loss0_1 = session.run([loss0, loss1, loss0_1])
self.assertTrue(loss0 > 0)
@@ -1115,7 +1115,7 @@ class CosineDistanceLossTest(test.TestCase):
[0, 1, 0]]).reshape((3, 2, 3))
def testValueErrorThrownWhenWeightIsNone(self):
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=constant_op.constant(self._labels),
@@ -1128,7 +1128,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._labels),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(0, loss.eval(), 5)
def testPartiallyCorrectWithIntegerValues(self):
@@ -1136,7 +1136,7 @@ class CosineDistanceLossTest(test.TestCase):
predictions=constant_op.constant(self._predictions),
labels=constant_op.constant(self._labels),
dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1, loss.eval(), 5)
def testPartiallyCorrectFloatingPointValues(self):
@@ -1154,7 +1154,7 @@ class CosineDistanceLossTest(test.TestCase):
labels, shape=(3, 1, 3), dtype=dtypes.float32)
loss = loss_ops.cosine_distance(tf_preds, tf_labels, dim=2)
- with self.test_session():
+ with self.cached_session():
self.assertAlmostEqual(1.0, loss.eval(), 5)
def testSampleSpecificWeights(self):
@@ -1163,7 +1163,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=constant_op.constant([1, 0, 0]))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, loss.eval())
def testMeasurementSpecificWeights(self):
@@ -1173,12 +1173,12 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(3.0 / 4.0, loss.eval())
def testValueErrorThrownWithShapelessPlaceholder(self):
tf_predictions = array_ops.placeholder(dtypes.float32)
- with self.test_session():
+ with self.cached_session():
with self.assertRaises(ValueError):
loss_ops.cosine_distance(
predictions=tf_predictions,
@@ -1196,7 +1196,7 @@ class CosineDistanceLossTest(test.TestCase):
dim=2,
weights=constant_op.constant(
[1, 0, 0, 1, 1, 1], shape=(3, 2)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
loss = sess.run(loss, feed_dict={tf_predictions: self._predictions})
self.assertEqual(3.0 / 4.0, loss)
@@ -1206,7 +1206,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3,)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
def testZeroLossWhenAllMeasurementSpecificWeightsAreZero(self):
@@ -1215,7 +1215,7 @@ class CosineDistanceLossTest(test.TestCase):
labels=constant_op.constant(self._labels),
dim=2,
weights=array_ops.zeros((3, 2)))
- with self.test_session():
+ with self.cached_session():
self.assertEqual(0, loss.eval())
@@ -1228,7 +1228,7 @@ class ComputeWeightedLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss = loss_ops.compute_weighted_loss(losses)
self.assertTrue(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [0.0, 1.4, 0.0, 2.1], atol=1e-3)
self.assertAllClose(loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1243,7 +1243,7 @@ class AddLossTest(test.TestCase):
loss_ops.add_loss(math_ops.reduce_mean(losses))
self.assertTrue(loss_ops.get_losses())
total_loss = loss_ops.get_total_loss()
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
self.assertAllClose(total_loss.eval(), 3.5 / 4.0, atol=1e-3)
@@ -1254,7 +1254,7 @@ class AddLossTest(test.TestCase):
self.assertFalse(loss_ops.get_losses())
loss_ops.add_loss(math_ops.reduce_mean(losses), loss_collection=None)
self.assertFalse(loss_ops.get_losses())
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(losses.eval(), [[0.0, 1.4, 0.0, 2.1]], atol=1e-3)
def testNoCollectLosses(self):
diff --git a/tensorflow/contrib/makefile/proto_text_cc_files.txt b/tensorflow/contrib/makefile/proto_text_cc_files.txt
index 7d26429f9c..9ea94c7433 100644
--- a/tensorflow/contrib/makefile/proto_text_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_cc_files.txt
@@ -1,62 +1,61 @@
-tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
-tensorflow/tools/proto_text/gen_proto_text_functions.cc
tensorflow/core/framework/resource_handle.cc
+tensorflow/core/lib/core/arena.cc
+tensorflow/core/lib/core/coding.cc
+tensorflow/core/lib/core/status.cc
+tensorflow/core/lib/core/threadpool.cc
+tensorflow/core/lib/hash/crc32c.cc
+tensorflow/core/lib/hash/crc32c_accelerate.cc
+tensorflow/core/lib/hash/hash.cc
+tensorflow/core/lib/histogram/histogram.cc
+tensorflow/core/lib/io/block.cc
+tensorflow/core/lib/io/block_builder.cc
+tensorflow/core/lib/io/buffered_inputstream.cc
+tensorflow/core/lib/io/compression.cc
+tensorflow/core/lib/io/format.cc
+tensorflow/core/lib/io/inputbuffer.cc
+tensorflow/core/lib/io/inputstream_interface.cc
+tensorflow/core/lib/io/iterator.cc
+tensorflow/core/lib/io/path.cc
+tensorflow/core/lib/io/random_inputstream.cc
+tensorflow/core/lib/io/record_reader.cc
+tensorflow/core/lib/io/record_writer.cc
+tensorflow/core/lib/io/table.cc
+tensorflow/core/lib/io/table_builder.cc
+tensorflow/core/lib/io/two_level_iterator.cc
+tensorflow/core/lib/io/zlib_compression_options.cc
+tensorflow/core/lib/io/zlib_inputstream.cc
+tensorflow/core/lib/io/zlib_outputbuffer.cc
+tensorflow/core/lib/random/distribution_sampler.cc
+tensorflow/core/lib/random/random.cc
+tensorflow/core/lib/random/simple_philox.cc
+tensorflow/core/lib/random/weighted_picker.cc
+tensorflow/core/lib/strings/numbers.cc
+tensorflow/core/lib/strings/ordered_code.cc
+tensorflow/core/lib/strings/proto_text_util.cc
+tensorflow/core/lib/strings/scanner.cc
+tensorflow/core/lib/strings/str_util.cc
+tensorflow/core/lib/strings/strcat.cc
+tensorflow/core/lib/strings/stringprintf.cc
+tensorflow/core/lib/wav/wav_io.cc
+tensorflow/core/platform/cpu_info.cc
+tensorflow/core/platform/default/logging.cc
+tensorflow/core/platform/default/mutex.cc
tensorflow/core/platform/default/protobuf.cc
-tensorflow/core/platform/tracing.cc
-tensorflow/core/platform/tensor_coding.cc
-tensorflow/core/platform/protobuf_util.cc
-tensorflow/core/platform/posix/posix_file_system.cc
-tensorflow/core/platform/posix/port.cc
-tensorflow/core/platform/posix/error.cc
-tensorflow/core/platform/posix/env.cc
-tensorflow/core/platform/posix/load_library.cc
-tensorflow/core/platform/posix/env_time.cc
-tensorflow/core/platform/file_system.cc
-tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/default/tracing.cc
+tensorflow/core/platform/denormal.cc
tensorflow/core/platform/env.cc
tensorflow/core/platform/env_time.cc
+tensorflow/core/platform/file_system.cc
+tensorflow/core/platform/file_system_helper.cc
+tensorflow/core/platform/posix/env.cc
+tensorflow/core/platform/posix/env_time.cc
+tensorflow/core/platform/posix/error.cc
+tensorflow/core/platform/posix/load_library.cc
+tensorflow/core/platform/posix/port.cc
+tensorflow/core/platform/posix/posix_file_system.cc
+tensorflow/core/platform/protobuf_util.cc
tensorflow/core/platform/setround.cc
-tensorflow/core/platform/denormal.cc
-tensorflow/core/platform/default/tracing.cc
-tensorflow/core/platform/default/mutex.cc
-tensorflow/core/platform/default/logging.cc
-tensorflow/core/platform/cpu_info.cc
-tensorflow/core/lib/wav/wav_io.cc
-tensorflow/core/lib/strings/stringprintf.cc
-tensorflow/core/lib/strings/strcat.cc
-tensorflow/core/lib/strings/str_util.cc
-tensorflow/core/lib/strings/scanner.cc
-tensorflow/core/lib/strings/proto_text_util.cc
-tensorflow/core/lib/strings/ordered_code.cc
-tensorflow/core/lib/strings/numbers.cc
-tensorflow/core/lib/random/weighted_picker.cc
-tensorflow/core/lib/random/simple_philox.cc
-tensorflow/core/lib/random/random.cc
-tensorflow/core/lib/random/distribution_sampler.cc
-tensorflow/core/lib/io/zlib_outputbuffer.cc
-tensorflow/core/lib/io/zlib_inputstream.cc
-tensorflow/core/lib/io/zlib_compression_options.cc
-tensorflow/core/lib/io/two_level_iterator.cc
-tensorflow/core/lib/io/table_builder.cc
-tensorflow/core/lib/io/table.cc
-tensorflow/core/lib/io/record_writer.cc
-tensorflow/core/lib/io/record_reader.cc
-tensorflow/core/lib/io/random_inputstream.cc
-tensorflow/core/lib/io/path.cc
-tensorflow/core/lib/io/iterator.cc
-tensorflow/core/lib/io/inputstream_interface.cc
-tensorflow/core/lib/io/inputbuffer.cc
-tensorflow/core/lib/io/format.cc
-tensorflow/core/lib/io/compression.cc
-tensorflow/core/lib/io/buffered_inputstream.cc
-tensorflow/core/lib/io/block_builder.cc
-tensorflow/core/lib/io/block.cc
-tensorflow/core/lib/histogram/histogram.cc
-tensorflow/core/lib/hash/hash.cc
-tensorflow/core/lib/hash/crc32c.cc
-tensorflow/core/lib/hash/crc32c_accelerate.cc
-tensorflow/core/lib/core/threadpool.cc
-tensorflow/core/lib/core/stringpiece.cc
-tensorflow/core/lib/core/status.cc
-tensorflow/core/lib/core/coding.cc
-tensorflow/core/lib/core/arena.cc
+tensorflow/core/platform/tensor_coding.cc
+tensorflow/core/platform/tracing.cc
+tensorflow/tools/proto_text/gen_proto_text_functions.cc
+tensorflow/tools/proto_text/gen_proto_text_functions_lib.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
index 938c4a53ab..0d8df93d11 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt
@@ -1,41 +1,41 @@
-tensorflow/core/util/test_log.pb.cc
-tensorflow/core/util/saved_tensor_slice.pb.cc
-tensorflow/core/util/memmapped_file_system.pb.cc
-tensorflow/core/util/event.pb.cc
-tensorflow/core/protobuf/tensorflow_server.pb.cc
-tensorflow/core/protobuf/saver.pb.cc
-tensorflow/core/protobuf/queue_runner.pb.cc
-tensorflow/core/protobuf/named_tensor.pb.cc
-tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/example/example.pb.cc
+tensorflow/core/example/feature.pb.cc
+tensorflow/core/framework/allocation_description.pb.cc
+tensorflow/core/framework/api_def.pb.cc
+tensorflow/core/framework/attr_value.pb.cc
+tensorflow/core/framework/cost_graph.pb.cc
+tensorflow/core/framework/device_attributes.pb.cc
+tensorflow/core/framework/function.pb.cc
+tensorflow/core/framework/graph.pb.cc
+tensorflow/core/framework/graph_transfer_info.pb.cc
+tensorflow/core/framework/kernel_def.pb.cc
+tensorflow/core/framework/log_memory.pb.cc
+tensorflow/core/framework/node_def.pb.cc
+tensorflow/core/framework/op_def.pb.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
+tensorflow/core/framework/resource_handle.pb.cc
+tensorflow/core/framework/step_stats.pb.cc
+tensorflow/core/framework/summary.pb.cc
+tensorflow/core/framework/tensor.pb.cc
+tensorflow/core/framework/tensor_description.pb.cc
+tensorflow/core/framework/tensor_shape.pb.cc
+tensorflow/core/framework/tensor_slice.pb.cc
+tensorflow/core/framework/types.pb.cc
+tensorflow/core/framework/variable.pb.cc
+tensorflow/core/framework/versions.pb.cc
+tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/lib/core/error_codes.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
-tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc
tensorflow/core/protobuf/device_properties.pb.cc
-tensorflow/core/lib/core/error_codes.pb.cc
-tensorflow/core/framework/versions.pb.cc
-tensorflow/core/framework/variable.pb.cc
-tensorflow/core/framework/types.pb.cc
-tensorflow/core/framework/tensor_slice.pb.cc
-tensorflow/core/framework/tensor_shape.pb.cc
-tensorflow/core/framework/tensor_description.pb.cc
-tensorflow/core/framework/tensor.pb.cc
-tensorflow/core/framework/summary.pb.cc
-tensorflow/core/framework/step_stats.pb.cc
-tensorflow/core/framework/resource_handle.pb.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
-tensorflow/core/framework/api_def.pb.cc
-tensorflow/core/framework/op_def.pb.cc
-tensorflow/core/framework/node_def.pb.cc
-tensorflow/core/framework/log_memory.pb.cc
-tensorflow/core/framework/kernel_def.pb.cc
-tensorflow/core/framework/graph_transfer_info.pb.cc
-tensorflow/core/framework/graph.pb.cc
-tensorflow/core/framework/function.pb.cc
-tensorflow/core/framework/device_attributes.pb.cc
-tensorflow/core/framework/cost_graph.pb.cc
-tensorflow/core/framework/attr_value.pb.cc
-tensorflow/core/framework/allocation_description.pb.cc
-tensorflow/core/example/feature.pb.cc
-tensorflow/core/example/example.pb.cc
-tensorflow/core/grappler/costs/op_performance_data.pb.cc
+tensorflow/core/protobuf/meta_graph.pb.cc
+tensorflow/core/protobuf/named_tensor.pb.cc
+tensorflow/core/protobuf/queue_runner.pb.cc
+tensorflow/core/protobuf/rewriter_config.pb.cc
+tensorflow/core/protobuf/saver.pb.cc
+tensorflow/core/protobuf/tensorflow_server.pb.cc
+tensorflow/core/util/event.pb.cc
+tensorflow/core/util/memmapped_file_system.pb.cc
+tensorflow/core/util/saved_tensor_slice.pb.cc
+tensorflow/core/util/test_log.pb.cc
diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
index aa91b2f954..d982df9319 100644
--- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
+++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt
@@ -1,42 +1,43 @@
-tensorflow/core/util/test_log.pb.h
-tensorflow/core/util/saved_tensor_slice.pb.h
-tensorflow/core/util/memmapped_file_system.pb.h
-tensorflow/core/util/event.pb.h
-tensorflow/core/protobuf/tensorflow_server.pb.h
-tensorflow/core/protobuf/saver.pb.h
-tensorflow/core/protobuf/queue_runner.pb.h
-tensorflow/core/protobuf/named_tensor.pb.h
-tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/example/example.pb.h
+tensorflow/core/example/feature.pb.h
+tensorflow/core/framework/allocation_description.pb.h
+tensorflow/core/framework/api_def.pb.h
+tensorflow/core/framework/attr_value.pb.h
+tensorflow/core/framework/cost_graph.pb.h
+tensorflow/core/framework/device_attributes.pb.h
+tensorflow/core/framework/function.pb.h
+tensorflow/core/framework/graph.pb.h
+tensorflow/core/framework/graph_transfer_info.pb.h
+tensorflow/core/framework/kernel_def.pb.h
+tensorflow/core/framework/log_memory.pb.h
+tensorflow/core/framework/node_def.pb.h
+tensorflow/core/framework/op_def.pb.h
+tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
+tensorflow/core/framework/resource_handle.pb.h
+tensorflow/core/framework/step_stats.pb.h
+tensorflow/core/framework/summary.pb.h
+tensorflow/core/framework/tensor.pb.h
+tensorflow/core/framework/tensor_description.pb.h
+tensorflow/core/framework/tensor_shape.pb.h
+tensorflow/core/framework/tensor_slice.pb.h
+tensorflow/core/framework/types.pb.h
+tensorflow/core/framework/variable.pb.h
+tensorflow/core/framework/versions.pb.h
+tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/lib/core/error_codes.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/device_properties.pb.h
+tensorflow/core/protobuf/meta_graph.pb.h
+tensorflow/core/protobuf/named_tensor.pb.h
+tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h
+tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/tensor_bundle.pb.h
-tensorflow/core/lib/core/error_codes.pb.h
-tensorflow/core/framework/versions.pb.h
-tensorflow/core/framework/variable.pb.h
-tensorflow/core/framework/types.pb.h
-tensorflow/core/framework/tensor_slice.pb.h
-tensorflow/core/framework/tensor_shape.pb.h
-tensorflow/core/framework/tensor_description.pb.h
-tensorflow/core/framework/tensor.pb.h
-tensorflow/core/framework/summary.pb.h
-tensorflow/core/framework/step_stats.pb.h
-tensorflow/core/framework/resource_handle.pb.h
-tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
-tensorflow/core/framework/api_def.pb.h
-tensorflow/core/framework/op_def.pb.h
-tensorflow/core/framework/node_def.pb.h
-tensorflow/core/framework/log_memory.pb.h
-tensorflow/core/framework/kernel_def.pb.h
-tensorflow/core/framework/graph_transfer_info.pb.h
-tensorflow/core/framework/graph.pb.h
-tensorflow/core/framework/function.pb.h
-tensorflow/core/framework/device_attributes.pb.h
-tensorflow/core/framework/cost_graph.pb.h
-tensorflow/core/framework/attr_value.pb.h
-tensorflow/core/framework/allocation_description.pb.h
-tensorflow/core/example/feature.pb.h
-tensorflow/core/example/example.pb.h
-tensorflow/core/grappler/costs/op_performance_data.pb.h
+tensorflow/core/protobuf/tensorflow_server.pb.h
+tensorflow/core/util/event.pb.h
+tensorflow/core/util/memmapped_file_system.pb.h
+tensorflow/core/util/saved_tensor_slice.pb.h
+tensorflow/core/util/test_log.pb.h
+
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index 66a3315700..08de54b8e1 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -4,218 +4,19 @@ tensorflow/contrib/boosted_trees/ops/quantile_ops.cc
tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc
tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc
tensorflow/contrib/boosted_trees/ops/training_ops.cc
-tensorflow/core/kernels/xent_op.cc
-tensorflow/core/kernels/where_op.cc
-tensorflow/core/kernels/variable_ops.cc
-tensorflow/core/kernels/unpack_op.cc
-tensorflow/core/kernels/unique_op.cc
-tensorflow/core/kernels/transpose_op.cc
-tensorflow/core/kernels/transpose_functor_cpu.cc
-tensorflow/core/kernels/training_op_helpers.cc
-tensorflow/core/kernels/training_ops.cc
-tensorflow/core/kernels/topk_op.cc
-tensorflow/core/kernels/tile_functor_cpu.cc
-tensorflow/core/kernels/tile_ops.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
-tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
-tensorflow/core/kernels/tensor_array_ops.cc
-tensorflow/core/kernels/tensor_array.cc
-tensorflow/core/kernels/strided_slice_op_inst_7.cc
-tensorflow/core/kernels/strided_slice_op_inst_6.cc
-tensorflow/core/kernels/strided_slice_op_inst_5.cc
-tensorflow/core/kernels/strided_slice_op_inst_4.cc
-tensorflow/core/kernels/strided_slice_op_inst_3.cc
-tensorflow/core/kernels/strided_slice_op_inst_2.cc
-tensorflow/core/kernels/strided_slice_op_inst_1.cc
-tensorflow/core/kernels/strided_slice_op_inst_0.cc
-tensorflow/core/kernels/strided_slice_op.cc
-tensorflow/core/kernels/stack_ops.cc
-tensorflow/core/kernels/split_op.cc
-tensorflow/core/kernels/split_v_op.cc
-tensorflow/core/kernels/split_lib_cpu.cc
-tensorflow/core/kernels/spectrogram_op.cc
-tensorflow/core/kernels/spectrogram.cc
-tensorflow/core/kernels/sparse_to_dense_op.cc
-tensorflow/core/kernels/sparse_matmul_op.cc
-tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
-tensorflow/core/kernels/sparse_reshape_op.c
-tensorflow/core/kernels/segment_reduction_ops.cc
-tensorflow/core/kernels/softsign_op.cc
-tensorflow/core/kernels/softplus_op.cc
-tensorflow/core/kernels/softmax_op.cc
-tensorflow/core/kernels/slice_op_cpu_impl_1.cc
-tensorflow/core/kernels/slice_op_cpu_impl_2.cc
-tensorflow/core/kernels/slice_op_cpu_impl_3.cc
-tensorflow/core/kernels/slice_op_cpu_impl_4.cc
-tensorflow/core/kernels/slice_op_cpu_impl_5.cc
-tensorflow/core/kernels/slice_op_cpu_impl_6.cc
-tensorflow/core/kernels/slice_op_cpu_impl_7.cc
-tensorflow/core/kernels/slice_op.cc
-tensorflow/core/kernels/shape_ops.cc
-tensorflow/core/kernels/session_ops.cc
-tensorflow/core/kernels/sequence_ops.cc
-tensorflow/core/kernels/sendrecv_ops.cc
-tensorflow/core/kernels/scatter_op.cc
-tensorflow/core/kernels/scatter_functor.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/scatter_nd_op.cc
-tensorflow/core/kernels/save_restore_tensor.cc
-tensorflow/core/kernels/save_restore_v2_ops.cc
-tensorflow/core/kernels/save_op.cc
-tensorflow/core/kernels/string_join_op.cc
-tensorflow/core/kernels/reverse_sequence_op.cc
-tensorflow/core/kernels/reverse_op.cc
-tensorflow/core/kernels/restore_op.cc
-tensorflow/core/kernels/resize_nearest_neighbor_op.cc
-tensorflow/core/kernels/resize_bilinear_op.cc
-tensorflow/core/kernels/reshape_util.cc
-tensorflow/core/kernels/reshape_op.cc
-tensorflow/core/kernels/relu_op.cc
-tensorflow/core/kernels/reduction_ops_sum.cc
-tensorflow/core/kernels/reduction_ops_prod.cc
-tensorflow/core/kernels/reduction_ops_min.cc
-tensorflow/core/kernels/reduction_ops_mean.cc
-tensorflow/core/kernels/reduction_ops_max.cc
-tensorflow/core/kernels/reduction_ops_common.cc
-tensorflow/core/kernels/reduction_ops_any.cc
-tensorflow/core/kernels/reduction_ops_all.cc
-tensorflow/core/kernels/roll_op.cc
-tensorflow/core/kernels/queue_op.cc
-tensorflow/core/kernels/queue_ops.cc
-tensorflow/core/kernels/queue_base.cc
-tensorflow/core/kernels/pooling_ops_common.cc
-tensorflow/core/kernels/padding_fifo_queue_op.cc
-tensorflow/core/kernels/padding_fifo_queue.cc
-tensorflow/core/kernels/pad_op.cc
-tensorflow/core/kernels/pack_op.cc
-tensorflow/core/kernels/ops_util.cc
-tensorflow/core/kernels/one_hot_op.cc
-tensorflow/core/kernels/non_max_suppression_op.cc
-tensorflow/core/kernels/no_op.cc
-tensorflow/core/kernels/mirror_pad_op.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
-tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
-tensorflow/core/kernels/mfcc_op.cc
-tensorflow/core/kernels/mfcc_mel_filterbank.cc
-tensorflow/core/kernels/mfcc_dct.cc
-tensorflow/core/kernels/mfcc.cc
-tensorflow/core/kernels/maxpooling_op.cc
-tensorflow/core/kernels/matmul_op.cc
-tensorflow/core/kernels/lrn_op.cc
-tensorflow/core/kernels/logging_ops.cc
-tensorflow/core/kernels/initializable_lookup_table.c
-tensorflow/core/kernels/lookup_table_init_op.cc
-tensorflow/core/kernels/lookup_table_op.cc
-tensorflow/core/kernels/lookup_util.cc
-tensorflow/core/kernels/inplace_ops.cc
-tensorflow/core/kernels/in_topk_op.cc
-tensorflow/core/kernels/immutable_constant_op.cc
-tensorflow/core/kernels/identity_op.cc
-tensorflow/core/kernels/identity_n_op.cc
-tensorflow/core/kernels/gather_op.cc
-tensorflow/core/kernels/gather_functor.cc
-tensorflow/core/kernels/gather_nd_op.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
-tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
-tensorflow/core/kernels/fused_batch_norm_op.cc
-tensorflow/core/kernels/function_ops.cc
-tensorflow/core/kernels/fill_functor.cc
-tensorflow/core/kernels/fifo_queue.cc
-tensorflow/core/kernels/fifo_queue_op.cc
-tensorflow/core/kernels/fake_quant_ops.cc
-tensorflow/core/kernels/example_parsing_ops.cc
-tensorflow/core/kernels/encode_wav_op.cc
-tensorflow/core/kernels/dynamic_stitch_op.cc
-tensorflow/core/kernels/dynamic_partition_op.cc
-tensorflow/core/kernels/decode_bmp_op.cc
-tensorflow/core/kernels/depthtospace_op.cc
-tensorflow/core/kernels/data_format_ops.cc
-tensorflow/core/kernels/spacetodepth_op.cc
-tensorflow/core/kernels/dense_update_functor.cc
-tensorflow/core/kernels/dense_update_ops.cc
-tensorflow/core/kernels/deep_conv2d.cc
-tensorflow/core/kernels/decode_wav_op.cc
-tensorflow/core/kernels/xsmm_conv2d.cc
-tensorflow/core/kernels/cwise_ops_common.cc
-tensorflow/core/kernels/cwise_op_tanh.cc
-tensorflow/core/kernels/cwise_op_pow.cc
-tensorflow/core/kernels/cwise_op_sub.cc
-tensorflow/core/kernels/cwise_op_squared_difference.cc
-tensorflow/core/kernels/cwise_op_square.cc
-tensorflow/core/kernels/cwise_op_sqrt.cc
-tensorflow/core/kernels/cwise_op_sigmoid.cc
-tensorflow/core/kernels/cwise_op_sign.cc
-tensorflow/core/kernels/cwise_op_select.cc
-tensorflow/core/kernels/cwise_op_round.cc
-tensorflow/core/kernels/cwise_op_rsqrt.cc
-tensorflow/core/kernels/cwise_op_reciprocal.cc
-tensorflow/core/kernels/cwise_op_neg.cc
-tensorflow/core/kernels/cwise_op_mul_2.cc
-tensorflow/core/kernels/cwise_op_mul_1.cc
-tensorflow/core/kernels/cwise_op_minimum.cc
-tensorflow/core/kernels/cwise_op_maximum.cc
-tensorflow/core/kernels/cwise_op_logical_not.cc
-tensorflow/core/kernels/cwise_op_logical_and.cc
-tensorflow/core/kernels/cwise_op_logical_or.cc
-tensorflow/core/kernels/cwise_op_log.cc
-tensorflow/core/kernels/cwise_op_less.cc
-tensorflow/core/kernels/cwise_op_less_equal.cc
-tensorflow/core/kernels/cwise_op_isnan.cc
-tensorflow/core/kernels/cwise_op_isfinite.cc
-tensorflow/core/kernels/cwise_op_invert.cc
-tensorflow/core/kernels/cwise_op_greater_equal.cc
-tensorflow/core/kernels/cwise_op_greater.cc
-tensorflow/core/kernels/cwise_op_floor_div.cc
-tensorflow/core/kernels/cwise_op_floor_mod.cc
-tensorflow/core/kernels/cwise_op_floor.cc
-tensorflow/core/kernels/cwise_op_exp.cc
-tensorflow/core/kernels/cwise_op_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
-tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
-tensorflow/core/kernels/cwise_op_div.cc
-tensorflow/core/kernels/cwise_op_bitwise_xor.cc
-tensorflow/core/kernels/cwise_op_bitwise_or.cc
-tensorflow/core/kernels/cwise_op_bitwise_and.cc
-tensorflow/core/kernels/cwise_op_left_shift.cc
-tensorflow/core/kernels/cwise_op_right_shift.cc
-tensorflow/core/kernels/cwise_op_add_2.cc
-tensorflow/core/kernels/cwise_op_add_1.cc
-tensorflow/core/kernels/cwise_op_abs.cc
-tensorflow/core/kernels/ctc_decoder_ops.cc
-tensorflow/core/kernels/crop_and_resize_op.cc
-tensorflow/core/kernels/conv_ops_using_gemm.cc
-tensorflow/core/kernels/conv_ops_fused.cc
-tensorflow/core/kernels/conv_ops.cc
-tensorflow/core/kernels/conv_grad_filter_ops.cc
-tensorflow/core/kernels/conv_grad_input_ops.cc
-tensorflow/core/kernels/conv_grad_ops.cc
-tensorflow/core/kernels/control_flow_ops.cc
-tensorflow/core/kernels/constant_op.cc
-tensorflow/core/kernels/concat_op.cc
-tensorflow/core/kernels/concat_lib_cpu.cc
-tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/argmax_op.cc
+tensorflow/core/kernels/avgpooling_op.cc
+tensorflow/core/kernels/batch_matmul_op_real.cc
+tensorflow/core/kernels/batch_norm_op.cc
+tensorflow/core/kernels/batchtospace_op.cc
+tensorflow/core/kernels/bcast_ops.cc
+tensorflow/core/kernels/bias_op.cc
+tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+tensorflow/core/kernels/boosted_trees/resource_ops.cc
+tensorflow/core/kernels/boosted_trees/resources.cc
+tensorflow/core/kernels/boosted_trees/stats_ops.cc
+tensorflow/core/kernels/boosted_trees/training_ops.cc
tensorflow/core/kernels/cast_op.cc
tensorflow/core/kernels/cast_op_impl_bfloat.cc
tensorflow/core/kernels/cast_op_impl_bool.cc
@@ -232,20 +33,131 @@ tensorflow/core/kernels/cast_op_impl_uint16.cc
tensorflow/core/kernels/cast_op_impl_uint32.cc
tensorflow/core/kernels/cast_op_impl_uint64.cc
tensorflow/core/kernels/cast_op_impl_uint8.cc
-tensorflow/core/kernels/boosted_trees/prediction_ops.cc
-tensorflow/core/kernels/boosted_trees/resource_ops.cc
-tensorflow/core/kernels/boosted_trees/resources.cc
-tensorflow/core/kernels/boosted_trees/stats_ops.cc
-tensorflow/core/kernels/boosted_trees/training_ops.cc
-tensorflow/core/kernels/bias_op.cc
-tensorflow/core/kernels/bcast_ops.cc
-tensorflow/core/kernels/batch_norm_op.cc
-tensorflow/core/kernels/avgpooling_op.cc
-tensorflow/core/kernels/argmax_op.cc
-tensorflow/core/kernels/aggregate_ops.cc
+tensorflow/core/kernels/check_numerics_op.cc
+tensorflow/core/kernels/concat_lib_cpu.cc
+tensorflow/core/kernels/concat_op.cc
+tensorflow/core/kernels/constant_op.cc
+tensorflow/core/kernels/control_flow_ops.cc
+tensorflow/core/kernels/conv_grad_filter_ops.cc
+tensorflow/core/kernels/conv_grad_input_ops.cc
+tensorflow/core/kernels/conv_grad_ops.cc
+tensorflow/core/kernels/conv_ops.cc
+tensorflow/core/kernels/conv_ops_fused.cc
+tensorflow/core/kernels/conv_ops_using_gemm.cc
+tensorflow/core/kernels/crop_and_resize_op.cc
+tensorflow/core/kernels/ctc_decoder_ops.cc
+tensorflow/core/kernels/cwise_op_abs.cc
+tensorflow/core/kernels/cwise_op_add_1.cc
+tensorflow/core/kernels/cwise_op_add_2.cc
+tensorflow/core/kernels/cwise_op_bitwise_and.cc
+tensorflow/core/kernels/cwise_op_bitwise_or.cc
+tensorflow/core/kernels/cwise_op_bitwise_xor.cc
+tensorflow/core/kernels/cwise_op_div.cc
+tensorflow/core/kernels/cwise_op_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_exp.cc
+tensorflow/core/kernels/cwise_op_floor.cc
+tensorflow/core/kernels/cwise_op_floor_div.cc
+tensorflow/core/kernels/cwise_op_floor_mod.cc
+tensorflow/core/kernels/cwise_op_greater.cc
+tensorflow/core/kernels/cwise_op_greater_equal.cc
+tensorflow/core/kernels/cwise_op_invert.cc
+tensorflow/core/kernels/cwise_op_isfinite.cc
+tensorflow/core/kernels/cwise_op_isnan.cc
+tensorflow/core/kernels/cwise_op_left_shift.cc
+tensorflow/core/kernels/cwise_op_less.cc
+tensorflow/core/kernels/cwise_op_less_equal.cc
+tensorflow/core/kernels/cwise_op_log.cc
+tensorflow/core/kernels/cwise_op_logical_and.cc
+tensorflow/core/kernels/cwise_op_logical_not.cc
+tensorflow/core/kernels/cwise_op_logical_or.cc
+tensorflow/core/kernels/cwise_op_maximum.cc
+tensorflow/core/kernels/cwise_op_minimum.cc
+tensorflow/core/kernels/cwise_op_mul_1.cc
+tensorflow/core/kernels/cwise_op_mul_2.cc
+tensorflow/core/kernels/cwise_op_neg.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_1.cc
+tensorflow/core/kernels/cwise_op_not_equal_to_2.cc
+tensorflow/core/kernels/cwise_op_pow.cc
+tensorflow/core/kernels/cwise_op_reciprocal.cc
+tensorflow/core/kernels/cwise_op_right_shift.cc
+tensorflow/core/kernels/cwise_op_round.cc
+tensorflow/core/kernels/cwise_op_rsqrt.cc
+tensorflow/core/kernels/cwise_op_select.cc
+tensorflow/core/kernels/cwise_op_sigmoid.cc
+tensorflow/core/kernels/cwise_op_sign.cc
+tensorflow/core/kernels/cwise_op_sqrt.cc
+tensorflow/core/kernels/cwise_op_square.cc
+tensorflow/core/kernels/cwise_op_squared_difference.cc
+tensorflow/core/kernels/cwise_op_sub.cc
+tensorflow/core/kernels/cwise_op_tanh.cc
+tensorflow/core/kernels/cwise_ops_common.cc
+tensorflow/core/kernels/data_format_ops.cc
+tensorflow/core/kernels/decode_bmp_op.cc
+tensorflow/core/kernels/decode_proto_op.cc
+tensorflow/core/kernels/decode_wav_op.cc
+tensorflow/core/kernels/deep_conv2d.cc
+tensorflow/core/kernels/dense_update_functor.cc
+tensorflow/core/kernels/dense_update_ops.cc
+tensorflow/core/kernels/depthtospace_op.cc
tensorflow/core/kernels/depthwise_conv_op.cc
tensorflow/core/kernels/dequantize_op.cc
+tensorflow/core/kernels/dynamic_partition_op.cc
+tensorflow/core/kernels/dynamic_stitch_op.cc
+tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/kernels/encode_wav_op.cc
+tensorflow/core/kernels/example_parsing_ops.cc
+tensorflow/core/kernels/fake_quant_ops.cc
+tensorflow/core/kernels/fifo_queue.cc
+tensorflow/core/kernels/fifo_queue_op.cc
+tensorflow/core/kernels/fill_functor.cc
+tensorflow/core/kernels/function_ops.cc
+tensorflow/core/kernels/fused_batch_norm_op.cc
+tensorflow/core/kernels/gather_functor.cc
+tensorflow/core/kernels/gather_nd_op.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/gather_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/gather_op.cc
+tensorflow/core/kernels/identity_n_op.cc
+tensorflow/core/kernels/identity_op.cc
+tensorflow/core/kernels/immutable_constant_op.cc
+tensorflow/core/kernels/in_topk_op.cc
+tensorflow/core/kernels/initializable_lookup_table.c
+tensorflow/core/kernels/inplace_ops.cc
+tensorflow/core/kernels/listdiff_op.cc
+tensorflow/core/kernels/logging_ops.cc
+tensorflow/core/kernels/lookup_table_init_op.cc
+tensorflow/core/kernels/lookup_table_op.cc
+tensorflow/core/kernels/lookup_util.cc
+tensorflow/core/kernels/lrn_op.cc
+tensorflow/core/kernels/matmul_op.cc
+tensorflow/core/kernels/maxpooling_op.cc
tensorflow/core/kernels/meta_support.cc
+tensorflow/core/kernels/mfcc.cc
+tensorflow/core/kernels/mfcc_dct.cc
+tensorflow/core/kernels/mfcc_mel_filterbank.cc
+tensorflow/core/kernels/mfcc_op.cc
+tensorflow/core/kernels/mirror_pad_op.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_1.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_2.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_3.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_4.cc
+tensorflow/core/kernels/mirror_pad_op_cpu_impl_5.cc
+tensorflow/core/kernels/no_op.cc
+tensorflow/core/kernels/non_max_suppression_op.cc
+tensorflow/core/kernels/one_hot_op.cc
+tensorflow/core/kernels/ops_util.cc
+tensorflow/core/kernels/pack_op.cc
+tensorflow/core/kernels/pad_op.cc
+tensorflow/core/kernels/padding_fifo_queue.cc
+tensorflow/core/kernels/padding_fifo_queue_op.cc
+tensorflow/core/kernels/pooling_ops_common.cc
tensorflow/core/kernels/population_count_op.cc
tensorflow/core/kernels/quantization_utils.cc
tensorflow/core/kernels/quantize_down_and_shrink_range.cc
@@ -262,46 +174,135 @@ tensorflow/core/kernels/quantized_mul_op.cc
tensorflow/core/kernels/quantized_pooling_ops.cc
tensorflow/core/kernels/quantized_reshape_op.cc
tensorflow/core/kernels/quantized_resize_bilinear_op.cc
-tensorflow/core/kernels/requantization_range_op.cc
-tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/queue_base.cc
+tensorflow/core/kernels/queue_op.cc
+tensorflow/core/kernels/queue_ops.cc
+tensorflow/core/kernels/random_op.cc
+tensorflow/core/kernels/reduction_ops_all.cc
+tensorflow/core/kernels/reduction_ops_any.cc
+tensorflow/core/kernels/reduction_ops_common.cc
+tensorflow/core/kernels/reduction_ops_max.cc
+tensorflow/core/kernels/reduction_ops_mean.cc
+tensorflow/core/kernels/reduction_ops_min.cc
+tensorflow/core/kernels/reduction_ops_prod.cc
+tensorflow/core/kernels/reduction_ops_sum.cc
+tensorflow/core/kernels/relu_op.cc
tensorflow/core/kernels/remote_fused_graph_execute_op.cc
tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
-tensorflow/core/kernels/batch_matmul_op_real.cc
-tensorflow/core/kernels/random_op.cc
-tensorflow/core/ops/training_ops.cc
-tensorflow/core/ops/string_ops.cc
-tensorflow/core/ops/state_ops.cc
-tensorflow/core/ops/sparse_ops.cc
-tensorflow/core/ops/sendrecv_ops.cc
-tensorflow/core/ops/script_ops.cc
-tensorflow/core/ops/remote_fused_graph_ops.cc
-tensorflow/core/ops/random_ops.cc
-tensorflow/core/ops/random_grad.cc
-tensorflow/core/ops/parsing_ops.cc
-tensorflow/core/ops/no_op.cc
-tensorflow/core/ops/nn_ops.cc
-tensorflow/core/ops/nn_grad.cc
-tensorflow/core/ops/manip_ops.cc
-tensorflow/core/ops/math_ops.cc
-tensorflow/core/ops/math_grad.cc
-tensorflow/core/ops/logging_ops.cc
-tensorflow/core/ops/linalg_ops.cc
-tensorflow/core/ops/io_ops.cc
-tensorflow/core/ops/image_ops.cc
-tensorflow/core/ops/functional_ops.cc
-tensorflow/core/ops/functional_grad.cc
-tensorflow/core/ops/function_ops.cc
-tensorflow/core/ops/data_flow_ops.cc
-tensorflow/core/ops/ctc_ops.cc
-tensorflow/core/ops/control_flow_ops.cc
-tensorflow/core/ops/candidate_sampling_ops.cc
-tensorflow/core/ops/boosted_trees_ops.cc
-tensorflow/core/ops/array_ops.cc
-tensorflow/core/ops/array_grad.cc
+tensorflow/core/kernels/requantization_range_op.cc
+tensorflow/core/kernels/requantize.cc
+tensorflow/core/kernels/reshape_op.cc
+tensorflow/core/kernels/reshape_util.cc
+tensorflow/core/kernels/resize_bilinear_op.cc
+tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+tensorflow/core/kernels/restore_op.cc
+tensorflow/core/kernels/reverse_op.cc
+tensorflow/core/kernels/reverse_sequence_op.cc
+tensorflow/core/kernels/roll_op.cc
+tensorflow/core/kernels/save_op.cc
+tensorflow/core/kernels/save_restore_tensor.cc
+tensorflow/core/kernels/save_restore_v2_ops.cc
+tensorflow/core/kernels/scatter_functor.cc
+tensorflow/core/kernels/scatter_nd_op.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_0.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_1.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_2.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_3.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_4.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_5.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_6.cc
+tensorflow/core/kernels/scatter_nd_op_cpu_impl_7.cc
+tensorflow/core/kernels/scatter_op.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/sendrecv_ops.cc
+tensorflow/core/kernels/sequence_ops.cc
+tensorflow/core/kernels/session_ops.cc
+tensorflow/core/kernels/shape_ops.cc
+tensorflow/core/kernels/slice_op.cc
+tensorflow/core/kernels/slice_op_cpu_impl_1.cc
+tensorflow/core/kernels/slice_op_cpu_impl_2.cc
+tensorflow/core/kernels/slice_op_cpu_impl_3.cc
+tensorflow/core/kernels/slice_op_cpu_impl_4.cc
+tensorflow/core/kernels/slice_op_cpu_impl_5.cc
+tensorflow/core/kernels/slice_op_cpu_impl_6.cc
+tensorflow/core/kernels/slice_op_cpu_impl_7.cc
+tensorflow/core/kernels/softmax_op.cc
+tensorflow/core/kernels/softplus_op.cc
+tensorflow/core/kernels/softsign_op.cc
tensorflow/core/kernels/spacetobatch_functor.cc
tensorflow/core/kernels/spacetobatch_op.cc
-tensorflow/core/kernels/batchtospace_op.cc
-tensorflow/core/kernels/segment_reduction_ops.cc
+tensorflow/core/kernels/spacetodepth_op.cc
+tensorflow/core/kernels/sparse_fill_empty_rows_op.cc
+tensorflow/core/kernels/sparse_matmul_op.cc
+tensorflow/core/kernels/sparse_reshape_op.c
+tensorflow/core/kernels/sparse_to_dense_op.cc
+tensorflow/core/kernels/spectrogram.cc
+tensorflow/core/kernels/spectrogram_op.cc
+tensorflow/core/kernels/split_lib_cpu.cc
+tensorflow/core/kernels/split_op.cc
+tensorflow/core/kernels/split_v_op.cc
+tensorflow/core/kernels/stack_ops.cc
+tensorflow/core/kernels/strided_slice_op.cc
+tensorflow/core/kernels/strided_slice_op_inst_0.cc
+tensorflow/core/kernels/strided_slice_op_inst_1.cc
+tensorflow/core/kernels/strided_slice_op_inst_2.cc
+tensorflow/core/kernels/strided_slice_op_inst_3.cc
+tensorflow/core/kernels/strided_slice_op_inst_4.cc
+tensorflow/core/kernels/strided_slice_op_inst_5.cc
+tensorflow/core/kernels/strided_slice_op_inst_6.cc
+tensorflow/core/kernels/strided_slice_op_inst_7.cc
+tensorflow/core/kernels/string_join_op.cc
+tensorflow/core/kernels/tensor_array.cc
+tensorflow/core/kernels/tensor_array_ops.cc
+tensorflow/core/kernels/tile_functor_cpu.cc
+tensorflow/core/kernels/tile_ops.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_1.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_2.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_3.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_4.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_5.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_6.cc
+tensorflow/core/kernels/tile_ops_cpu_impl_7.cc
+tensorflow/core/kernels/topk_op.cc
+tensorflow/core/kernels/training_op_helpers.cc
+tensorflow/core/kernels/training_ops.cc
+tensorflow/core/kernels/transpose_functor_cpu.cc
+tensorflow/core/kernels/transpose_op.cc
+tensorflow/core/kernels/unique_op.cc
+tensorflow/core/kernels/unpack_op.cc
+tensorflow/core/kernels/variable_ops.cc
+tensorflow/core/kernels/where_op.cc
+tensorflow/core/kernels/xent_op.cc
+tensorflow/core/kernels/xsmm_conv2d.cc
+tensorflow/core/ops/array_grad.cc
+tensorflow/core/ops/array_ops.cc
tensorflow/core/ops/audio_ops.cc
-tensorflow/core/kernels/decode_proto_op.cc
-tensorflow/core/kernels/encode_proto_op.cc
+tensorflow/core/ops/boosted_trees_ops.cc
+tensorflow/core/ops/candidate_sampling_ops.cc
+tensorflow/core/ops/control_flow_ops.cc
+tensorflow/core/ops/ctc_ops.cc
+tensorflow/core/ops/data_flow_ops.cc
+tensorflow/core/ops/function_ops.cc
+tensorflow/core/ops/functional_grad.cc
+tensorflow/core/ops/functional_ops.cc
+tensorflow/core/ops/image_ops.cc
+tensorflow/core/ops/io_ops.cc
+tensorflow/core/ops/linalg_ops.cc
+tensorflow/core/ops/logging_ops.cc
+tensorflow/core/ops/manip_ops.cc
+tensorflow/core/ops/math_grad.cc
+tensorflow/core/ops/math_ops.cc
+tensorflow/core/ops/nn_grad.cc
+tensorflow/core/ops/nn_ops.cc
+tensorflow/core/ops/no_op.cc
+tensorflow/core/ops/parsing_ops.cc
+tensorflow/core/ops/random_grad.cc
+tensorflow/core/ops/random_ops.cc
+tensorflow/core/ops/remote_fused_graph_ops.cc
+tensorflow/core/ops/script_ops.cc
+tensorflow/core/ops/sendrecv_ops.cc
+tensorflow/core/ops/sparse_ops.cc
+tensorflow/core/ops/state_ops.cc
+tensorflow/core/ops/string_ops.cc
+tensorflow/core/ops/training_ops.cc
diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt
index b5431df2eb..f94d70db90 100644
--- a/tensorflow/contrib/makefile/tf_pb_text_files.txt
+++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt
@@ -1,33 +1,33 @@
-tensorflow/core/util/saved_tensor_slice.pb_text.cc
-tensorflow/core/util/memmapped_file_system.pb_text.cc
-tensorflow/core/protobuf/saver.pb_text.cc
+tensorflow/core/example/example.pb_text.cc
+tensorflow/core/example/feature.pb_text.cc
+tensorflow/core/framework/allocation_description.pb_text.cc
+tensorflow/core/framework/api_def.pb_text.cc
+tensorflow/core/framework/attr_value.pb_text.cc
+tensorflow/core/framework/cost_graph.pb_text.cc
+tensorflow/core/framework/device_attributes.pb_text.cc
+tensorflow/core/framework/function.pb_text.cc
+tensorflow/core/framework/graph.pb_text.cc
+tensorflow/core/framework/graph_transfer_info.pb_text.cc
+tensorflow/core/framework/kernel_def.pb_text.cc
+tensorflow/core/framework/log_memory.pb_text.cc
+tensorflow/core/framework/node_def.pb_text.cc
+tensorflow/core/framework/op_def.pb_text.cc
+tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
+tensorflow/core/framework/resource_handle.pb_text.cc
+tensorflow/core/framework/step_stats.pb_text.cc
+tensorflow/core/framework/summary.pb_text.cc
+tensorflow/core/framework/tensor.pb_text.cc
+tensorflow/core/framework/tensor_description.pb_text.cc
+tensorflow/core/framework/tensor_shape.pb_text.cc
+tensorflow/core/framework/tensor_slice.pb_text.cc
+tensorflow/core/framework/types.pb_text.cc
+tensorflow/core/framework/versions.pb_text.cc
+tensorflow/core/lib/core/error_codes.pb_text.cc
tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc
+tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/tensor_bundle.pb_text.cc
-tensorflow/core/lib/core/error_codes.pb_text.cc
-tensorflow/core/framework/versions.pb_text.cc
-tensorflow/core/framework/types.pb_text.cc
-tensorflow/core/framework/tensor_slice.pb_text.cc
-tensorflow/core/framework/tensor_shape.pb_text.cc
-tensorflow/core/framework/tensor_description.pb_text.cc
-tensorflow/core/framework/tensor.pb_text.cc
-tensorflow/core/framework/summary.pb_text.cc
-tensorflow/core/framework/step_stats.pb_text.cc
-tensorflow/core/framework/resource_handle.pb_text.cc
-tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
-tensorflow/core/framework/api_def.pb_text.cc
-tensorflow/core/framework/op_def.pb_text.cc
-tensorflow/core/framework/node_def.pb_text.cc
-tensorflow/core/framework/log_memory.pb_text.cc
-tensorflow/core/framework/kernel_def.pb_text.cc
-tensorflow/core/framework/graph_transfer_info.pb_text.cc
-tensorflow/core/framework/graph.pb_text.cc
-tensorflow/core/framework/function.pb_text.cc
-tensorflow/core/framework/device_attributes.pb_text.cc
-tensorflow/core/framework/cost_graph.pb_text.cc
-tensorflow/core/framework/attr_value.pb_text.cc
-tensorflow/core/framework/allocation_description.pb_text.cc
-tensorflow/core/example/feature.pb_text.cc
-tensorflow/core/example/example.pb_text.cc
+tensorflow/core/util/memmapped_file_system.pb_text.cc
+tensorflow/core/util/saved_tensor_slice.pb_text.cc
diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt
index 1f254692d7..8bec3e3e01 100644
--- a/tensorflow/contrib/makefile/tf_proto_files.txt
+++ b/tensorflow/contrib/makefile/tf_proto_files.txt
@@ -2,47 +2,47 @@ tensorflow/contrib/boosted_trees/proto/learner.proto
tensorflow/contrib/boosted_trees/proto/quantiles.proto
tensorflow/contrib/boosted_trees/proto/split_info.proto
tensorflow/contrib/boosted_trees/proto/tree_config.proto
-tensorflow/core/util/test_log.proto
-tensorflow/core/util/saved_tensor_slice.proto
-tensorflow/core/util/memmapped_file_system.proto
-tensorflow/core/util/event.proto
-tensorflow/core/protobuf/tensorflow_server.proto
-tensorflow/core/protobuf/saver.proto
-tensorflow/core/protobuf/queue_runner.proto
-tensorflow/core/protobuf/named_tensor.proto
-tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/example/example.proto
+tensorflow/core/example/feature.proto
+tensorflow/core/framework/allocation_description.proto
+tensorflow/core/framework/api_def.proto
+tensorflow/core/framework/attr_value.proto
+tensorflow/core/framework/cost_graph.proto
+tensorflow/core/framework/device_attributes.proto
+tensorflow/core/framework/function.proto
+tensorflow/core/framework/graph.proto
+tensorflow/core/framework/graph_transfer_info.proto
+tensorflow/core/framework/kernel_def.proto
+tensorflow/core/framework/log_memory.proto
+tensorflow/core/framework/node_def.proto
+tensorflow/core/framework/op_def.proto
+tensorflow/core/framework/reader_base.proto
+tensorflow/core/framework/remote_fused_graph_execute_info.proto
+tensorflow/core/framework/resource_handle.proto
+tensorflow/core/framework/step_stats.proto
+tensorflow/core/framework/summary.proto
+tensorflow/core/framework/tensor.proto
+tensorflow/core/framework/tensor_description.proto
+tensorflow/core/framework/tensor_shape.proto
+tensorflow/core/framework/tensor_slice.proto
+tensorflow/core/framework/types.proto
+tensorflow/core/framework/variable.proto
+tensorflow/core/framework/versions.proto
+tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+tensorflow/core/lib/core/error_codes.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/device_properties.proto
+tensorflow/core/protobuf/meta_graph.proto
+tensorflow/core/protobuf/named_tensor.proto
+tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/rewriter_config.proto
+tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/tensor_bundle.proto
-tensorflow/core/lib/core/error_codes.proto
-tensorflow/core/kernels/boosted_trees/boosted_trees.proto
-tensorflow/core/framework/versions.proto
-tensorflow/core/framework/variable.proto
-tensorflow/core/framework/types.proto
-tensorflow/core/framework/tensor_slice.proto
-tensorflow/core/framework/tensor_shape.proto
-tensorflow/core/framework/tensor_description.proto
-tensorflow/core/framework/tensor.proto
-tensorflow/core/framework/summary.proto
-tensorflow/core/framework/step_stats.proto
-tensorflow/core/framework/resource_handle.proto
-tensorflow/core/framework/remote_fused_graph_execute_info.proto
-tensorflow/core/framework/reader_base.proto
-tensorflow/core/framework/api_def.proto
-tensorflow/core/framework/op_def.proto
-tensorflow/core/framework/node_def.proto
-tensorflow/core/framework/log_memory.proto
-tensorflow/core/framework/kernel_def.proto
-tensorflow/core/framework/graph_transfer_info.proto
-tensorflow/core/framework/graph.proto
-tensorflow/core/framework/function.proto
-tensorflow/core/framework/device_attributes.proto
-tensorflow/core/framework/cost_graph.proto
-tensorflow/core/framework/attr_value.proto
-tensorflow/core/framework/allocation_description.proto
-tensorflow/core/example/feature.proto
-tensorflow/core/example/example.proto
-tensorflow/core/grappler/costs/op_performance_data.proto
+tensorflow/core/protobuf/tensorflow_server.proto
+tensorflow/core/util/event.proto
+tensorflow/core/util/memmapped_file_system.proto
+tensorflow/core/util/saved_tensor_slice.proto
+tensorflow/core/util/test_log.proto
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index bfef0816aa..1ddd7e521b 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -2514,7 +2514,8 @@ def sparse_recall_at_top_k(labels,
name=name_scope)
-def _compute_recall_at_precision(tp, fp, fn, precision, name):
+def _compute_recall_at_precision(tp, fp, fn, precision, name,
+ strict_mode=False):
"""Helper function to compute recall at a given `precision`.
Args:
@@ -2523,17 +2524,42 @@ def _compute_recall_at_precision(tp, fp, fn, precision, name):
fn: The number of false negatives.
precision: The precision for which the recall will be calculated.
name: An optional variable_scope name.
+ strict_mode: If true and there exists a threshold where the precision is
+ no smaller than the target precision, return the corresponding recall at
+ the threshold. Otherwise, return 0. If false, find the threshold where the
+ precision is closest to the target precision and return the recall at the
+ threshold.
Returns:
The recall at a given `precision`.
"""
precisions = math_ops.div(tp, tp + fp + _EPSILON)
- tf_index = math_ops.argmin(
- math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ if not strict_mode:
+ tf_index = math_ops.argmin(
+ math_ops.abs(precisions - precision), 0, output_type=dtypes.int32)
+ # Now, we have the implicit threshold, so compute the recall:
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
+ else:
+ # We aim to find the threshold where the precision is minimum but no smaller
+ # than the target precision.
+ # The rationale:
+ # 1. Compute the difference between precisions (by different thresholds) and
+ # the target precision.
+ # 2. Take the reciprocal of the values by the above step. The intention is
+ # to make the positive values rank before negative values and also the
+ # smaller positives rank before larger positives.
+ tf_index = math_ops.argmax(
+ math_ops.div(1.0, precisions - precision + _EPSILON),
+ 0,
+ output_type=dtypes.int32)
+
+ def _return_good_recall():
+ return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
+ name)
- # Now, we have the implicit threshold, so compute the recall:
- return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON,
- name)
+ return control_flow_ops.cond(precisions[tf_index] >= precision,
+ _return_good_recall, lambda: .0)
def recall_at_precision(labels,
@@ -2543,7 +2569,8 @@ def recall_at_precision(labels,
num_thresholds=200,
metrics_collections=None,
updates_collections=None,
- name=None):
+ name=None,
+ strict_mode=False):
"""Computes `recall` at `precision`.
The `recall_at_precision` function creates four local variables,
@@ -2575,6 +2602,11 @@ def recall_at_precision(labels,
updates_collections: An optional list of collections that `update_op` should
be added to.
name: An optional variable_scope name.
+ strict_mode: If true and there exists a threshold where the precision is
+ above the target precision, return the corresponding recall at the
+ threshold. Otherwise, return 0. If false, find the threshold where the
+ precision is closest to the target precision and return the recall at the
+ threshold.
Returns:
recall: A scalar `Tensor` representing the recall at the given
@@ -2603,10 +2635,11 @@ def recall_at_precision(labels,
predictions, labels, thresholds, weights)
recall = _compute_recall_at_precision(values['tp'], values['fp'],
- values['fn'], precision, 'value')
+ values['fn'], precision, 'value',
+ strict_mode)
update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'],
update_ops['fn'], precision,
- 'update_op')
+ 'update_op', strict_mode)
if metrics_collections:
ops.add_to_collections(metrics_collections, recall)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
index 7acfc383eb..5777e64c29 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_large_test.py
@@ -47,7 +47,7 @@ class StreamingPrecisionRecallAtEqualThresholdsLargeTest(test.TestCase):
# code used float32 for accumulation.
num_updates = 71
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_updates):
sess.run(update_op)
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
index 1c2c17960a..955b83b44d 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py
@@ -178,7 +178,7 @@ class StreamingMeanTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -195,7 +195,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -216,7 +216,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual(1.65, sess.run(mean), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -243,7 +243,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -265,7 +265,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 3.2 + 4.0) / 4.0, mean.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -292,7 +292,7 @@ class StreamingMeanTest(test.TestCase):
self.assertAlmostEqual((0 + 1 - 4.2 + 0) / 4.0, mean.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -337,7 +337,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -354,7 +354,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
def testMultiDimensional(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
_enqueue_vector(
@@ -375,7 +375,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -396,7 +396,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -423,7 +423,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
def testWeighted2d_1(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -450,7 +450,7 @@ class StreamingMeanTensorTest(test.TestCase):
self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
def testWeighted2d_2(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -526,7 +526,7 @@ class StreamingAccuracyTest(test.TestCase):
(10, 3), maxval=3, dtype=dtypes_lib.int64, seed=2)
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -539,7 +539,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertEqual(initial_accuracy, accuracy.eval())
def testMultipleUpdates(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -569,7 +569,7 @@ class StreamingAccuracyTest(test.TestCase):
def testEffectivelyEquivalentSizes(self):
predictions = array_ops.ones((40, 1))
labels = array_ops.ones((40,))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels)
sess.run(variables.local_variables_initializer())
@@ -583,7 +583,7 @@ class StreamingAccuracyTest(test.TestCase):
weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
1) # shape 3, 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights)
@@ -604,7 +604,7 @@ class StreamingAccuracyTest(test.TestCase):
dtype=dtypes_lib.int32, name='weights')
feed_dict = {weights_placeholder: weights}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
accuracy, update_op = metrics.streaming_accuracy(predictions, labels,
weights_placeholder)
@@ -616,7 +616,7 @@ class StreamingAccuracyTest(test.TestCase):
self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
def testMultipleUpdatesWithWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 1))
@@ -681,7 +681,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(1, tp_update_op.eval())
@@ -698,7 +698,7 @@ class StreamingTruePositivesTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives(
predictions, labels, weights=37.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tp.eval())
self.assertEqual(37.0, tp_update_op.eval())
@@ -732,7 +732,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(2, fn_update_op.eval())
@@ -749,7 +749,7 @@ class StreamingFalseNegativesTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives(
predictions, labels, weights=((3.0,), (5.0,), (7.0,)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fn.eval())
self.assertEqual(8.0, fn_update_op.eval())
@@ -783,7 +783,7 @@ class StreamingFalsePositivesTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(4, fp_update_op.eval())
@@ -803,7 +803,7 @@ class StreamingFalsePositivesTest(test.TestCase):
weights=((1.0, 2.0, 3.0, 5.0), (7.0, 11.0, 13.0, 17.0), (19.0, 23.0,
29.0, 31.0)))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, fp.eval())
self.assertEqual(42.0, fp_update_op.eval())
@@ -837,7 +837,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(5, tn_update_op.eval())
@@ -854,7 +854,7 @@ class StreamingTrueNegativesTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives(
predictions, labels, weights=((0.0, 2.0, 3.0, 5.0),))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, tn.eval())
self.assertEqual(15.0, tn_update_op.eval())
@@ -879,7 +879,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tp.eval())
self.assertAllEqual((3, 1, 0), tp_update_op.eval())
@@ -892,7 +892,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase):
tp, tp_update_op = metrics.streaming_true_positives_at_thresholds(
predictions, labels, weights=37.0, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
@@ -921,7 +921,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
fn, fn_update_op = metrics.streaming_false_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fn.eval())
self.assertAllEqual((0, 2, 3), fn_update_op.eval())
@@ -937,7 +937,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase):
weights=((3.0,), (5.0,), (7.0,)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
@@ -962,7 +962,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
fp, fp_update_op = metrics.streaming_false_positives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), fp.eval())
self.assertAllEqual((7, 4, 2), fp_update_op.eval())
@@ -979,7 +979,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase):
29.0, 31.0)),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
@@ -1004,7 +1004,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
tn, tn_update_op = metrics.streaming_true_negatives_at_thresholds(
predictions, labels, thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0, 0, 0), tn.eval())
self.assertAllEqual((2, 5, 7), tn_update_op.eval())
@@ -1020,7 +1020,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase):
weights=((0.0, 2.0, 3.0, 5.0),),
thresholds=(0.15, 0.5, 0.85))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
@@ -1062,7 +1062,7 @@ class StreamingPrecisionTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1081,7 +1081,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op))
self.assertAlmostEqual(1, precision.eval())
@@ -1091,7 +1091,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, precision.eval())
@@ -1102,7 +1102,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1120,7 +1120,7 @@ class StreamingPrecisionTest(test.TestCase):
precision, update_op = metrics.streaming_precision(
predictions, labels, weights=constant_op.constant([[2], [5]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 2.0 + 5.0
weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1138,7 +1138,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1158,7 +1158,7 @@ class StreamingPrecisionTest(test.TestCase):
labels,
weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
- with self.test_session():
+ with self.cached_session():
variables.local_variables_initializer().run()
weighted_tp = 3.0 + 4.0
weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
@@ -1175,7 +1175,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant(1 - inputs)
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertAlmostEqual(0, precision.eval())
@@ -1185,7 +1185,7 @@ class StreamingPrecisionTest(test.TestCase):
labels = constant_op.constant([0, 0, 0, 0])
precision, update_op = metrics.streaming_precision(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0.0, precision.eval())
@@ -1227,7 +1227,7 @@ class StreamingRecallTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1246,7 +1246,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, recall.eval())
@@ -1256,7 +1256,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, recall.eval())
@@ -1268,7 +1268,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1283,7 +1283,7 @@ class StreamingRecallTest(test.TestCase):
recall, update_op = metrics.streaming_recall(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_tp = 3.0 + 1.0
weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
@@ -1298,7 +1298,7 @@ class StreamingRecallTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1308,7 +1308,7 @@ class StreamingRecallTest(test.TestCase):
labels = array_ops.zeros((1, 4))
recall, update_op = metrics.streaming_recall(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, recall.eval())
@@ -1350,7 +1350,7 @@ class StreamingFPRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1369,7 +1369,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1379,7 +1379,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fpr.eval())
@@ -1391,7 +1391,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 2.0 + 5.0
weighted_f = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1406,7 +1406,7 @@ class StreamingFPRTest(test.TestCase):
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fp = 1.0 + 3.0
weighted_f = (1.0 + 4.0) + (2.0 + 3.0)
@@ -1421,7 +1421,7 @@ class StreamingFPRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fpr.eval())
@@ -1431,7 +1431,7 @@ class StreamingFPRTest(test.TestCase):
labels = array_ops.ones((1, 4))
fpr, update_op = metrics.streaming_false_positive_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fpr.eval())
@@ -1473,7 +1473,7 @@ class StreamingFNRTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1492,7 +1492,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1502,7 +1502,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.5, update_op.eval())
self.assertAlmostEqual(0.5, fnr.eval())
@@ -1514,7 +1514,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 5.0
weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
@@ -1529,7 +1529,7 @@ class StreamingFNRTest(test.TestCase):
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
weighted_fn = 2.0 + 4.0
weighted_t = (2.0 + 3.0) + (1.0 + 4.0)
@@ -1544,7 +1544,7 @@ class StreamingFNRTest(test.TestCase):
labels = constant_op.constant(1 - np_inputs)
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(1, fnr.eval())
@@ -1554,7 +1554,7 @@ class StreamingFNRTest(test.TestCase):
labels = array_ops.zeros((1, 4))
fnr, update_op = metrics.streaming_false_negative_rate(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
self.assertEqual(0, fnr.eval())
@@ -1599,7 +1599,7 @@ class StreamingCurvePointsTest(test.TestCase):
points, update_op = metric_ops.streaming_curve_points(
labels, predictions=predictions, curve=curve)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
sess.run(update_op)
@@ -1615,7 +1615,7 @@ class StreamingCurvePointsTest(test.TestCase):
self._testValueTensorIsIdempotent(curve='PR')
def _testCase(self, labels, predictions, curve, expected_points):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.float32)
@@ -1717,7 +1717,7 @@ class StreamingAUCTest(test.TestCase):
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -1730,7 +1730,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testPredictionsOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, -1, 1, -1], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1744,7 +1744,7 @@ class StreamingAUCTest(test.TestCase):
def allCorrectAsExpected(self, curve):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_auc(predictions, labels, curve=curve)
@@ -1755,7 +1755,7 @@ class StreamingAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1767,7 +1767,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval())
def testWeighted1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1781,7 +1781,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.5, auc.eval(), 5)
def testWeighted2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -1795,7 +1795,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.7, auc.eval(), 5)
def testAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
@@ -1807,7 +1807,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3)
def testAnotherAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
shape=(1, 7),
@@ -1821,7 +1821,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3)
def testThirdAUCPRSpecialCase(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -1837,7 +1837,7 @@ class StreamingAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1848,7 +1848,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
auc, update_op = metrics.streaming_auc(predictions, labels)
@@ -1859,7 +1859,7 @@ class StreamingAUCTest(test.TestCase):
self.assertAlmostEqual(1, auc.eval(), 6)
def testRecallOneAndPrecisionOneGivesOnePRAUC(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
labels = array_ops.ones([4])
auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR')
@@ -1893,7 +1893,7 @@ class StreamingAUCTest(test.TestCase):
np.random.exponential(scale=1.0, size=num_samples)):
expected_auc = _np_auc(predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
enqueue_ops = [[] for i in range(num_batches)]
tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
tf_labels = _enqueue_as_batches(labels, enqueue_ops)
@@ -1966,7 +1966,7 @@ class StreamingDynamicAUCTest(test.TestCase):
labels = random_ops.random_uniform(
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
for _ in xrange(10):
@@ -1977,7 +1977,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(initial_auc, auc.eval(), 5)
def testAllLabelsOnes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([1, 1, 1])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1986,7 +1986,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testAllLabelsZeros(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1., 1., 1.])
labels = constant_op.constant([0, 0, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -1995,7 +1995,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(0, auc.eval())
def testNonZeroOnePredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2.5, -2.5, 2.5, -2.5], dtype=dtypes_lib.float32)
labels = constant_op.constant([1, 0, 1, 0])
@@ -2006,7 +2006,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs)
labels = constant_op.constant(inputs)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2015,7 +2015,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertEqual(1, auc.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0, 1, 0])
labels = constant_op.constant([0, 1, 1, 0])
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2025,7 +2025,7 @@ class StreamingDynamicAUCTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2034,7 +2034,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0, auc.eval())
def testExceptionOnIncompatibleShapes(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.ones([5])
labels = array_ops.zeros([6])
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
@@ -2043,7 +2043,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2054,7 +2054,7 @@ class StreamingDynamicAUCTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1])
_, update_op = metrics.streaming_dynamic_auc(labels, predictions)
@@ -2078,7 +2078,7 @@ class StreamingDynamicAUCTest(test.TestCase):
collections=[ops.GraphKeys.LOCAL_VARIABLES],
dtype=dtypes_lib.float32)
auc, update_op = metrics.streaming_dynamic_auc(tf_labels, tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2093,7 +2093,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(expected_auc, auc.eval())
def testAUCPRReverseIncreasingPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8], dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 1])
@@ -2104,7 +2104,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-5)
def testAUCPRJumbledPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81], dtypes_lib.float32)
labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1])
@@ -2115,7 +2115,7 @@ class StreamingDynamicAUCTest(test.TestCase):
self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-6)
def testAUCPRPredictionsLessThanHalf(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
shape=(1, 7),
@@ -2148,7 +2148,7 @@ class StreamingDynamicAUCTest(test.TestCase):
auc, update_op = metrics.streaming_dynamic_auc(tf_labels,
tf_predictions,
weights=tf_weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2196,7 +2196,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
expected_result: The expected result (dict) that maps to tensors.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
predictions, dtype=dtypes_lib.float32)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64)
@@ -2320,7 +2320,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
dtype=dtypes_lib.float32)
auc, update_op = metrics.auc_with_confidence_intervals(tf_labels,
tf_predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for _ in xrange(num_batches):
new_labels = np.random.randint(0, 2, size=batch_size)
@@ -2335,7 +2335,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertAllClose(expected_auc, auc.auc.eval())
def testExceptionOnFloatLabels(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([0.7, 0, 1, 0, 1])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2343,7 +2343,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
self.assertRaises(TypeError, sess.run(update_op))
def testExceptionOnGreaterThanOneLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([2, 1, 0, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2354,7 +2354,7 @@ class AucWithConfidenceIntervalsTest(test.TestCase):
sess.run(update_op)
def testExceptionOnNegativeLabel(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32)
labels = constant_op.constant([1, 0, -1, 1, 0])
_, update_op = metrics.auc_with_confidence_intervals(labels, predictions)
@@ -2415,7 +2415,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
result, update_op = metric_ops.precision_recall_at_equal_thresholds(
labels=labels, predictions=predictions)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Run several updates.
sess.run(variables.local_variables_initializer())
for _ in range(3):
@@ -2448,7 +2448,7 @@ class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase):
default from assertAllClose.
weights: Optional weights tensor.
"""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(predictions, dtype=dtype)
labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool)
weights_tensor = None
@@ -2621,7 +2621,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2641,7 +2641,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2656,7 +2656,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op))
self.assertAlmostEqual(1.0, specificity.eval())
@@ -2671,7 +2671,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2689,7 +2689,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
@@ -2707,7 +2707,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, weights=weights, sensitivity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
@@ -2757,7 +2757,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
sensitivity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2777,7 +2777,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, specificity.eval())
@@ -2792,7 +2792,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, specificity.eval())
@@ -2807,7 +2807,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.6, sess.run(update_op))
self.assertAlmostEqual(0.6, specificity.eval())
@@ -2824,7 +2824,7 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase):
specificity, update_op = metrics.streaming_sensitivity_at_specificity(
predictions, labels, weights=weights, specificity=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.675, sess.run(update_op))
self.assertAlmostEqual(0.675, specificity.eval())
@@ -2887,7 +2887,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -2905,7 +2905,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -2921,7 +2921,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertEqual(1, rec.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -2940,7 +2940,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -2956,7 +2956,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0, rec.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -2982,7 +2982,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3008,7 +3008,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3032,7 +3032,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, rec_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3082,7 +3082,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3162,7 +3162,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3177,7 +3177,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3190,7 +3190,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertEqual(0, fpr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3206,7 +3206,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3219,7 +3219,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fpr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3239,7 +3239,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3259,7 +3259,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3277,7 +3277,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
self.assertAlmostEqual(0.0, fpr_high.eval(), places=5)
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3317,7 +3317,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3393,7 +3393,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3413,7 +3413,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=1.0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, recall.eval())
@@ -3428,7 +3428,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, recall.eval())
@@ -3443,7 +3443,7 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
@@ -3461,12 +3461,66 @@ class RecallAtPrecisionTest(test.TestCase):
recall, update_op = metrics.recall_at_precision(
labels, predictions, weights=weights, precision=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
target_recall = 2.0 / 3.0
self.assertAlmostEqual(target_recall, sess.run(update_op))
self.assertAlmostEqual(target_recall, recall.eval())
+ def _test_strict_mode(self, strict_mode, target_precision, expected_recall):
+ num_thresholds = 11
+ predictions_values = [.2, .3, .5, .6, .7, .8, .9, .9, .9, .1]
+ labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]
+ # Resulting thresholds and the corresponding precision and recall values at
+ # each threshold:
+ # Thresholds [0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]
+ # precisions: [0.3 0.2 0.1 0 0 0 0 0 0]
+ # recalls: [1.0 0.7 0.3 0 0 0 0 0 0]
+ predictions = constant_op.constant(
+ predictions_values, dtype=dtypes_lib.float32)
+ labels = constant_op.constant(labels_values)
+ recall, update_op = metrics.recall_at_precision(
+ labels,
+ predictions,
+ num_thresholds=num_thresholds,
+ precision=target_precision,
+ strict_mode=strict_mode)
+
+ with self.cached_session() as sess:
+ sess.run(variables.local_variables_initializer())
+ self.assertAlmostEqual(expected_recall, sess.run(update_op))
+ self.assertAlmostEqual(expected_recall, recall.eval())
+
+ def testStrictMode_Off(self):
+ # strict_mode is turned off and return the recall at the threshold where the
+ # precision (0.3) is closest to target precision (0.9). The recall
+ # corresponding to the threshold is 1.0.
+ self._test_strict_mode(
+ strict_mode=False, target_precision=0.9, expected_recall=1.0)
+
+ def testStrictMode_OnAndFail(self):
+ # strict_mode is turned on and we fail to reach the target precision at any
+ # threshold.
+ # Target precision: 0.9
+ # Diff: [-0.6 -0.7 -0.8 -0.9 -0.9 -0.9 -0.9 -0.9 -0.9]
+ # Reciprocal: [-1.6 -1.4 -1.3 -1.1 -1.1 -1.1 -1.1 -1.1 -1.1]
+ # Max index: 3 and corresponding precision is: 0 which is smaller than
+ # target precsion 0.9. As a result, the expected recall is 0.
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.9, expected_recall=.0)
+
+ def testStrictMode_OnAndSucceed(self):
+ # strict_mode is on and we can reach the target precision at certain
+ # threshold.
+ # Target precision: 0.2
+ # Diff: [0.1 0 -0.1 -0.2 -0.2 -0.2 -0.2 -0.2 -0.2]
+ # Reciprocal: [10 infty -10.0 -5.0 -5.0 -5.0 -5.0 -5.0 -5.0]
+ # Max index: 1 and corresponding precision is: 0.2 which is no smaller than
+ # target precsion 0.2. In this case, we return the recall at index 1, which
+ # is 2.0/3 (0.7).
+ self._test_strict_mode(
+ strict_mode=True, target_precision=0.2, expected_recall=2.0 / 3)
+
class PrecisionAtRecallTest(test.TestCase):
@@ -3511,7 +3565,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3531,7 +3585,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.7)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, sess.run(update_op))
self.assertEqual(1, precision.eval())
@@ -3545,7 +3599,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(sess.run(label_prior), sess.run(update_op))
self.assertEqual(sess.run(label_prior), precision.eval())
@@ -3560,7 +3614,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(0.8, sess.run(update_op))
self.assertAlmostEqual(0.8, precision.eval())
@@ -3575,7 +3629,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(2.0/3, sess.run(update_op))
self.assertAlmostEqual(2.0/3, precision.eval())
@@ -3594,7 +3648,7 @@ class PrecisionAtRecallTest(test.TestCase):
precision, update_op = metrics.precision_at_recall(
labels, predictions, target_recall=0.8, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(34.0/43, sess.run(update_op))
self.assertAlmostEqual(34.0/43, precision.eval())
@@ -3643,7 +3697,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -3658,7 +3712,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllCorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
thresholds = [0.5]
@@ -3671,7 +3725,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertEqual(0, fnr.eval())
def testSomeCorrect(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
@@ -3687,7 +3741,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
def testAllIncorrect(self):
inputs = np.random.randint(0, 2, size=(100, 1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
thresholds = [0.5]
@@ -3700,7 +3754,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1, fnr.eval())
def testWeights1d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3720,7 +3774,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testWeights2d(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
@@ -3740,7 +3794,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval(), places=5)
def testExtremeThresholds(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
@@ -3758,7 +3812,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
self.assertAlmostEqual(1.0, fnr_high.eval())
def testZeroLabelsPredictions(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
labels = array_ops.zeros([4])
thresholds = [0.5]
@@ -3798,7 +3852,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
labels = labels.astype(np.float32)
predictions = predictions.astype(np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Reshape the data so its easy to queue up:
predictions_batches = predictions.reshape((batch_size, num_batches))
labels_batches = labels.reshape((batch_size, num_batches))
@@ -3886,7 +3940,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.25, sess.run(update_op))
self.assertEqual(0.25, recall.eval())
@@ -3904,7 +3958,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.5, sess.run(update_op))
self.assertEqual(0.5, recall.eval())
@@ -3922,7 +3976,7 @@ class StreamingRecallAtKTest(test.TestCase):
sp_recall, sp_update_op = metrics.streaming_sparse_recall_at_k(
predictions, array_ops.reshape(labels, (self._batch_size, 1)), k=3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -3946,7 +4000,7 @@ class StreamingRecallAtKTest(test.TestCase):
k=2,
weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.0, sess.run(update_op))
self.assertEqual(1.0, recall.eval())
@@ -4068,7 +4122,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
self.assertAlmostEqual(expected, metric.eval())
def test_top_k_rank_invalid(self):
- with self.test_session():
+ with self.cached_session():
# top_k_predictions has rank < 2.
top_k_predictions = [9, 4, 6, 2, 0]
sp_labels = sparse_tensor.SparseTensorValue(
@@ -4615,7 +4669,7 @@ class StreamingSparsePrecisionTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 0, 1], [0, 0, 1, 0]]
expected_precision = 0.5
- with self.test_session():
+ with self.cached_session():
_, precision = metrics.streaming_sparse_precision_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5320,7 +5374,7 @@ class StreamingSparseRecallTest(test.TestCase):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
labels = [[0, 0, 1, 0], [0, 0, 0, 1]]
expected_recall = 0.5
- with self.test_session():
+ with self.cached_session():
_, recall = metrics.streaming_sparse_recall_at_k(
predictions=constant_op.constant(predictions, dtypes_lib.float32),
labels=_binary_2d_label_to_sparse_value(labels),
@@ -5364,7 +5418,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5386,7 +5440,7 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_absolute_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(3, sess.run(update_op))
self.assertEqual(3, error.eval())
@@ -5430,7 +5484,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5455,7 +5509,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(expected_error, sess.run(update_op))
self.assertEqual(expected_error, error.eval())
@@ -5471,7 +5525,7 @@ class StreamingMeanRelativeErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_relative_error(
predictions, labels, normalizer=array_ops.zeros_like(labels))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0.0, sess.run(update_op))
self.assertEqual(0.0, error.eval())
@@ -5509,7 +5563,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
labels = random_ops.random_normal((10, 3), seed=2)
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5527,7 +5581,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -5540,7 +5594,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(6, sess.run(update_op))
self.assertEqual(6, error.eval())
@@ -5555,13 +5609,13 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_mean_squared_error(
predictions, labels, weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(13, sess.run(update_op))
self.assertEqual(13, error.eval())
def testMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5586,7 +5640,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
def testMetricsComputedConcurrently(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates one set of predictions.
preds_queue0 = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5629,7 +5683,7 @@ class StreamingMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(79.0 / 6, mse1, 5)
def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
2, dtypes=dtypes_lib.float32, shapes=(1, 3))
@@ -5691,7 +5745,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
error, update_op = metrics.streaming_root_mean_squared_error(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5704,7 +5758,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(initial_error, error.eval())
def testSingleUpdateZeroError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
0.0, shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
@@ -5718,7 +5772,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertEqual(0, rmse.eval())
def testSingleUpdateWithError(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5732,7 +5786,7 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase):
self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5788,7 +5842,7 @@ class StreamingCovarianceTest(test.TestCase):
predictions = labels * 0.5 + random_ops.random_normal((10, 3), seed=1) * 0.5
cov, update_op = metrics.streaming_covariance(predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5801,7 +5855,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertEqual(initial_cov, cov.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5813,7 +5867,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5827,7 +5881,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -5845,7 +5899,7 @@ class StreamingCovarianceTest(test.TestCase):
self.assertAlmostEqual(expected_cov, cov.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5879,7 +5933,7 @@ class StreamingCovarianceTest(test.TestCase):
prev_expected_cov = expected_cov
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -5969,7 +6023,7 @@ class StreamingPearsonRTest(test.TestCase):
pearson_r, update_op = metrics.streaming_pearson_correlation(
predictions, labels)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -5982,7 +6036,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertEqual(initial_r, pearson_r.eval())
def testSingleUpdateIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = math_ops.to_float(math_ops.range(10))
labels = math_ops.to_float(math_ops.range(10))
@@ -5995,7 +6049,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval(), 5)
def testSingleUpdateNonIdentical(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(
[2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
labels = constant_op.constant(
@@ -6010,7 +6064,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testSingleUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = np.array([2, 4, 6, 8])
labels = np.array([1, 3, 2, 7])
weights = np.array([0, 1, 3, 1])
@@ -6031,7 +6085,7 @@ class StreamingPearsonRTest(test.TestCase):
self.assertAlmostEqual(expected_r, pearson_r.eval())
def testMultiUpdateWithErrorNoWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6066,7 +6120,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndWeights(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6108,7 +6162,7 @@ class StreamingPearsonRTest(test.TestCase):
prev_expected_r = expected_r
def testMultiUpdateWithErrorAndSingletonBatches(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
np.random.seed(123)
n = 100
predictions = np.random.randn(n)
@@ -6189,7 +6243,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6212,7 +6266,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6229,7 +6283,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1, sess.run(update_op), 5)
self.assertAlmostEqual(1, error.eval(), 5)
@@ -6251,7 +6305,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertAlmostEqual(1.0, sess.run(update_op), 5)
self.assertAlmostEqual(1.0, error.eval(), 5)
@@ -6270,7 +6324,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(0, sess.run(update_op))
self.assertEqual(0, error.eval())
@@ -6289,7 +6343,7 @@ class StreamingMeanCosineDistanceTest(test.TestCase):
error, update_op = metrics.streaming_mean_cosine_distance(
predictions, labels, dim=2, weights=weights)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1.5, update_op.eval())
self.assertEqual(1.5, error.eval())
@@ -6324,7 +6378,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
def testOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
@@ -6344,7 +6398,7 @@ class PcntBelowThreshTest(test.TestCase):
self.assertAlmostEqual(0.0, pcnt2, 5)
def testSomePresentOneUpdate(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = constant_op.constant(
[2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
weights = constant_op.constant(
@@ -6421,7 +6475,7 @@ class StreamingMeanIOUTest(test.TestCase):
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes=num_classes)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -6435,7 +6489,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdates(self):
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
5, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6467,7 +6521,7 @@ class StreamingMeanIOUTest(test.TestCase):
def testMultipleUpdatesWithWeights(self):
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
preds_queue = data_flow_ops.FIFOQueue(
6, dtypes=dtypes_lib.int32, shapes=(1, 1))
@@ -6515,7 +6569,7 @@ class StreamingMeanIOUTest(test.TestCase):
# one class, and thus there is one row and one column with
# zero entries in the confusion matrix.
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the predictions.
# There is no prediction for class 2.
preds_queue = data_flow_ops.FIFOQueue(
@@ -6557,7 +6611,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[7])
], 0)
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6570,7 +6624,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.zeros([40])
num_classes = 1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6581,7 +6635,7 @@ class StreamingMeanIOUTest(test.TestCase):
predictions = array_ops.zeros([40])
labels = array_ops.ones([40])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6603,7 +6657,7 @@ class StreamingMeanIOUTest(test.TestCase):
constant_op.constant(1, shape=[8]),
constant_op.constant(0, shape=[1])
], 0)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(
predictions, labels, num_classes, weights=weights)
sess.run(variables.local_variables_initializer())
@@ -6618,7 +6672,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 2, 1, 1, 0], [0, 1, 2, 2, 0, 1]], [[0, 0, 2, 1, 1, 1],
[1, 1, 2, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6630,7 +6684,7 @@ class StreamingMeanIOUTest(test.TestCase):
labels = constant_op.constant([0])
predictions = constant_op.constant([0])
num_classes = 2
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6644,7 +6698,7 @@ class StreamingMeanIOUTest(test.TestCase):
[[[0, 0, 1, 1, 0, 0], [1, 1, 0, 0, 1, 1]], [[0, 0, 0, 1, 1, 1],
[1, 1, 1, 0, 0, 0]]])
num_classes = 3
- with self.test_session() as sess:
+ with self.cached_session() as sess:
miou, update_op = metrics.streaming_mean_iou(predictions, labels,
num_classes)
sess.run(variables.local_variables_initializer())
@@ -6679,7 +6733,7 @@ class StreamingConcatTest(test.TestCase):
def testNextArraySize(self):
next_array_size = metric_ops._next_array_size # pylint: disable=protected-access
- with self.test_session():
+ with self.cached_session():
self.assertEqual(next_array_size(2, growth_factor=2).eval(), 2)
self.assertEqual(next_array_size(3, growth_factor=2).eval(), 4)
self.assertEqual(next_array_size(4, growth_factor=2).eval(), 4)
@@ -6687,7 +6741,7 @@ class StreamingConcatTest(test.TestCase):
self.assertEqual(next_array_size(6, growth_factor=2).eval(), 8)
def testStreamingConcat(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6704,7 +6758,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual(np.arange(10), concatenated.eval())
def testStreamingConcatStringValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.string, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6723,7 +6777,7 @@ class StreamingConcatTest(test.TestCase):
concatenated.eval())
def testStreamingConcatMaxSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = math_ops.range(3)
concatenated, update_op = metrics.streaming_concat(values, max_size=5)
sess.run(variables.local_variables_initializer())
@@ -6740,7 +6794,7 @@ class StreamingConcatTest(test.TestCase):
self.assertAllEqual([0, 1, 2, 0, 1], concatenated.eval())
def testStreamingConcat2D(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.reshape(math_ops.range(3), (3, 1))
concatenated, update_op = metrics.streaming_concat(values, axis=-1)
sess.run(variables.local_variables_initializer())
@@ -6763,7 +6817,7 @@ class StreamingConcatTest(test.TestCase):
array_ops.placeholder(dtypes_lib.float32, [None, None]))
def testStreamingConcatReset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values = array_ops.placeholder(dtypes_lib.int32, [None])
concatenated, update_op = metrics.streaming_concat(values)
sess.run(variables.local_variables_initializer())
@@ -6791,7 +6845,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean(values))
self.assertEqual(len(value_tensors), 1)
self.assertEqual(len(update_ops), 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(1, update_ops[0].eval())
self.assertEqual(1, value_tensors[0].eval())
@@ -6804,7 +6858,7 @@ class AggregateMetricsTest(test.TestCase):
metrics.streaming_mean_squared_error(predictions, labels))
self.assertEqual(len(value_tensors), 2)
self.assertEqual(len(update_ops), 2)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, update_ops[0].eval())
self.assertEqual(4, update_ops[1].eval())
@@ -6825,7 +6879,7 @@ class AggregateMetricMapTest(test.TestCase):
self.assertEqual(2, len(names_to_values))
self.assertEqual(2, len(names_to_updates))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
self.assertEqual(2, names_to_updates['m1'].eval())
self.assertEqual(4, names_to_updates['m2'].eval())
@@ -6860,7 +6914,7 @@ class CountTest(test.TestCase):
self.assertTrue(isinstance(op, ops.Operation) or isinstance(op, ops.Tensor))
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6877,7 +6931,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def testUpdateOpsReturnsCurrentValue(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
_enqueue_vector(sess, values_queue, [0, 1])
@@ -6898,7 +6952,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(8.0, sess.run(result), 5)
def test1dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6925,7 +6979,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test1dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -6947,7 +7001,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(3.4, result.eval(), 5)
def test2dWeightedValues(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
values_queue = data_flow_ops.FIFOQueue(
4, dtypes=dtypes_lib.float32, shapes=(1, 2))
@@ -6974,7 +7028,7 @@ class CountTest(test.TestCase):
self.assertAlmostEqual(4.1, result.eval(), 5)
def test2dWeightedValues_placeholders(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Create the queue that populates the values.
feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0))
values = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -7047,7 +7101,7 @@ class CohenKappaTest(test.TestCase):
(10, 1), maxval=3, dtype=dtypes_lib.int64, seed=2)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 3)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates.
@@ -7081,7 +7135,7 @@ class CohenKappaTest(test.TestCase):
for dtype in dtypes:
for shape in shapes:
for weight in weights:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions_tensor = constant_op.constant(
np.reshape(predictions, shape), dtype=dtype)
labels_tensor = constant_op.constant(
@@ -7102,7 +7156,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(inputs, inputs)
expect = 1.0
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
labels = constant_op.constant(inputs)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7121,7 +7175,7 @@ class CohenKappaTest(test.TestCase):
# Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions)
expect = -0.333333333333
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(labels, predictions, 4)
@@ -7139,7 +7193,7 @@ class CohenKappaTest(test.TestCase):
# labels, predictions, sample_weight=weights)
expect = 0.453466583385
- with self.test_session() as sess:
+ with self.cached_session() as sess:
predictions = constant_op.constant(predictions, dtype=dtypes_lib.float32)
labels = constant_op.constant(labels)
kappa, update_op = metrics.cohen_kappa(
@@ -7164,7 +7218,7 @@ class CohenKappaTest(test.TestCase):
weights_t = array_ops.placeholder(dtypes_lib.float32, shape=(batch_size,))
kappa, update_op = metrics.cohen_kappa(
labels_t, predictions_t, num_classes, weights=weights_t)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for idx in range(0, num_samples, batch_size):
@@ -7202,7 +7256,7 @@ class CohenKappaTest(test.TestCase):
def testConditionalPackingOptimization(self):
placeholder = array_ops.placeholder(dtypes_lib.float32, [None])
values, update_op = metric_ops.streaming_concat(placeholder)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.local_variables_initializer())
for feed in range(10):
sess.run(update_op, feed_dict={placeholder: [feed]})
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index e662b11be8..3cffd76a25 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -113,7 +113,7 @@ py_library(
py_test(
name = "pruning_utils_test",
- size = "small",
+ size = "medium",
srcs = ["python/pruning_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index a5267fd904..15d95896d9 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -53,7 +53,7 @@ The pruning library allows for specification of the following hyper parameters:
| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
-| nbins | integer | 256 | Number of bins to use for histogram computation |
+| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. |
| block_height|integer | 1 | Number of rows in a block for block sparse matrices|
| block_width |integer | 1 | Number of cols in a block for block sparse matrices|
| block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
index e85ae7b22a..586c6c7bfc 100644
--- a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
+++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py
@@ -37,7 +37,7 @@ class RnnCellsTest(test.TestCase):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
- with self.test_session():
+ with self.cached_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
@@ -61,7 +61,7 @@ class RnnCellsTest(test.TestCase):
expected_num_masks = 1
expected_num_rows = 2 * self.dim
expected_num_cols = 4 * self.dim
- with self.test_session():
+ with self.cached_session():
inputs = variables.Variable(
random_ops.random_normal([self.batch_size, self.dim]))
c = variables.Variable(
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index b50a372e9d..91b0bb7f60 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -235,19 +235,18 @@ def compute_cdf_from_histogram(values, value_range, **kwargs):
def compute_cdf(values, value_range, **kwargs):
"""Returns the normalized cumulative distribution of the given values tensor.
- Uses tf.while_loop to directly compute the cdf of the values. Number of bins
- for histogram is fixed at _NBINS=255
+ Uses tf.while_loop to directly compute the cdf of the values.
Args:
values: Numeric `Tensor`.
value_range: Shape [2] `Tensor` of same `dtype` as `values`
- **kwargs: keyword arguments: name
+ **kwargs: keyword arguments: nbins, name
Returns:
A 1-D `Tensor` holding normalized cdf of values.
"""
- nbins = _NBINS
+ nbins = kwargs.get('nbins', _NBINS)
name = kwargs.get('name', None)
with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
values = ops.convert_to_tensor(values, name='values')
@@ -281,7 +280,7 @@ def compute_cdf(values, value_range, **kwargs):
cdf = math_ops.add(
cdf,
array_ops.one_hot(
- loop_count, depth=_NBINS, on_value=temp, off_value=0.0))
+ loop_count, depth=nbins, on_value=temp, off_value=0.0))
return [loop_count + 1, cdf]
_, cdf = control_flow_ops.while_loop(
diff --git a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
index cb69c72970..d0955cbe11 100644
--- a/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
+++ b/tensorflow/contrib/nearest_neighbor/python/kernel_tests/hyperplane_lsh_probes_test.py
@@ -31,7 +31,7 @@ class HyperplaneLshProbesTest(test.TestCase):
# tests in hyperplane_lsh_probes_test.cc already cover most of the LSH
# functionality.
def simple_batch_test(self):
- with self.test_session():
+ with self.cached_session():
hyperplanes = np.eye(4)
points = np.array([[1.2, 0.5, -0.9, -1.0], [2.0, -3.0, 1.0, -1.5]])
product = np.dot(points, hyperplanes)
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 5319a8b655..2e4d61d931 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -22,6 +22,7 @@ py_library(
"python/training/ggt.py",
"python/training/lars_optimizer.py",
"python/training/lazy_adam_optimizer.py",
+ "python/training/matrix_functions.py",
"python/training/model_average_optimizer.py",
"python/training/moving_average_optimizer.py",
"python/training/multitask_optimizer_wrapper.py",
@@ -158,8 +159,10 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:variables",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -381,3 +384,18 @@ py_test(
"@six_archive//:six",
],
)
+
+py_test(
+ name = "matrix_functions_test",
+ srcs = ["python/training/matrix_functions_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
diff --git a/tensorflow/contrib/opt/__init__.py b/tensorflow/contrib/opt/__init__.py
index 781621dba0..ad7d7cfa6e 100644
--- a/tensorflow/contrib/opt/__init__.py
+++ b/tensorflow/contrib/opt/__init__.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.opt.python.training.model_average_optimizer import *
from tensorflow.contrib.opt.python.training.moving_average_optimizer import *
from tensorflow.contrib.opt.python.training.multitask_optimizer_wrapper import *
from tensorflow.contrib.opt.python.training.nadam_optimizer import *
+from tensorflow.contrib.opt.python.training.reg_adagrad_optimizer import *
from tensorflow.contrib.opt.python.training.shampoo import *
from tensorflow.contrib.opt.python.training.weight_decay_optimizers import *
from tensorflow.contrib.opt.python.training.powersign import *
@@ -65,6 +66,7 @@ _allowed_symbols = [
'ModelAverageCustomGetter',
'GGTOptimizer',
'ShampooOptimizer',
+ 'RegAdagradOptimizer',
]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/opt/python/training/adamax_test.py b/tensorflow/contrib/opt/python/training/adamax_test.py
index 5790d8a3f1..61d8b94eca 100644
--- a/tensorflow/contrib/opt/python/training/adamax_test.py
+++ b/tensorflow/contrib/opt/python/training/adamax_test.py
@@ -74,7 +74,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)
m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
@@ -142,7 +142,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable(
@@ -233,7 +233,7 @@ class AdaMaxOptimizerTest(test.TestCase):
opt.get_slot(var=var0, name="m").name)
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
@@ -242,7 +242,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -278,7 +278,7 @@ class AdaMaxOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
index bbafd59aae..6c203e5519 100644
--- a/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/elastic_average_optimizer.py
@@ -128,12 +128,14 @@ class ElasticAverageCustomGetter(object):
= list(global_center_variable)[i]
return local_var
else:
- return getter(
- name,
- trainable=trainable,
- collections=collections,
- *args,
- **kwargs)
+ kwargs['trainable'] = trainable
+ kwargs['collections'] = collections
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
+
class ElasticAverageOptimizer(optimizer.Optimizer):
diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
index 953586ee70..9997103016 100644
--- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py
@@ -85,7 +85,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
optimizer = MockOptimizerInterface(loss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -107,7 +107,7 @@ class ExternalOptimizerInterfaceTest(TestCase):
optimizer = MockOptimizerInterface(loss)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
initial_vector_val = sess.run(vector)
@@ -164,7 +164,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
self._objective(x), method=method, options=options)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -176,7 +176,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
x = variables.Variable(array_ops.zeros(dimension))
optimizer = external_optimizer.ScipyOptimizerInterface(self._objective(x))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
@@ -242,7 +242,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, equalities=equalities, inequalities=inequalities, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose(np.ones(2), sess.run(vector))
@@ -260,7 +260,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, var_to_bounds=var_to_bounds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose(np.ones(2), sess.run(vector))
@@ -277,7 +277,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, var_to_bounds=var_to_bounds)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
self.assertAllClose([0., 2.], sess.run(vector))
@@ -293,7 +293,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(
loss, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
optimizer.minimize(sess)
method = optimizer.optimizer_kwargs.get('method')
@@ -312,7 +312,7 @@ class ScipyOptimizerInterfaceTest(TestCase):
optimizer = external_optimizer.ScipyOptimizerInterface(loss, method='SLSQP')
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables.global_variables_initializer())
initial_vector_val = sess.run(vector)
diff --git a/tensorflow/contrib/opt/python/training/ggt_test.py b/tensorflow/contrib/opt/python/training/ggt_test.py
index 1d2a79957b..1775edabb3 100644
--- a/tensorflow/contrib/opt/python/training/ggt_test.py
+++ b/tensorflow/contrib/opt/python/training/ggt_test.py
@@ -171,7 +171,7 @@ class GGTOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
self.doTestBasic(use_resource=False)
@test_util.run_in_graph_and_eager_modes(reset_test=True)
diff --git a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
index d94249b994..b76db763da 100644
--- a/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lars_optimizer_test.py
@@ -31,7 +31,7 @@ class LARSOptimizerTest(test.TestCase):
def testLARSGradientOneStep(self):
for _ in range(10):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [3, 3]
var_np = np.ones(shape)
grad_np = np.ones(shape)
@@ -77,7 +77,7 @@ class LARSOptimizerTest(test.TestCase):
def testLARSGradientMultiStep(self):
for _ in range(10):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [3, 3]
var_np = np.ones(shape)
grad_np = np.ones(shape)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
index 72117c1e81..f55209ec49 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer.py
@@ -28,6 +28,7 @@ from __future__ import print_function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import adam
@@ -78,3 +79,36 @@ class LazyAdamOptimizer(adam.AdamOptimizer):
lr * m_t_slice / denominator_slice,
use_locking=self._use_locking)
return control_flow_ops.group(var_update, m_t, v_t)
+
+ def _resource_apply_sparse(self, grad, var, indices):
+ beta1_power, beta2_power = self._get_beta_accumulators()
+ beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
+ beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
+ lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
+ beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
+ beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
+ epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
+
+ # \\(m := beta1 * m + (1 - beta1) * g_t\\)
+ m = self.get_slot(var, "m")
+ m_t_slice = beta1_t * array_ops.gather(m, indices) + (1 - beta1_t) * grad
+ m_update_op = resource_variable_ops.resource_scatter_update(m.handle,
+ indices,
+ m_t_slice)
+
+ # \\(v := beta2 * v + (1 - beta2) * (g_t * g_t)\\)
+ v = self.get_slot(var, "v")
+ v_t_slice = (beta2_t * array_ops.gather(v, indices) +
+ (1 - beta2_t) * math_ops.square(grad))
+ v_update_op = resource_variable_ops.resource_scatter_update(v.handle,
+ indices,
+ v_t_slice)
+
+ # \\(variable -= learning_rate * m_t / (epsilon_t + sqrt(v_t))\\)
+ var_slice = lr * m_t_slice / (math_ops.sqrt(v_t_slice) + epsilon_t)
+ var_update_op = resource_variable_ops.resource_scatter_sub(var.handle,
+ indices,
+ var_slice)
+
+ return control_flow_ops.group(var_update_op, m_update_op, v_update_op)
diff --git a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
index a16857db7d..f08ffaa36f 100644
--- a/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/lazy_adam_optimizer_test.py
@@ -19,14 +19,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.opt.python.training import lazy_adam_optimizer
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -49,11 +53,12 @@ def adam_update_numpy(param,
return param_t, m_t, v_t
-class AdamOptimizerTest(test.TestCase):
+class AdamOptimizerTest(test.TestCase, parameterized.TestCase):
- def testSparse(self):
+ @parameterized.parameters([False, True])
+ def testSparse(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -61,8 +66,13 @@ class AdamOptimizerTest(test.TestCase):
var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
- var0 = variables.Variable(var0_np)
- var1 = variables.Variable(var1_np)
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+
grads0_np_indices = np.array([0, 1], dtype=np.int32)
grads0 = ops.IndexedSlices(
constant_op.constant(grads0_np),
@@ -94,12 +104,17 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllCloseAccordingToType(var0_np, var0.eval())
self.assertAllCloseAccordingToType(var1_np, var1.eval())
- def testSparseDevicePlacement(self):
+ @parameterized.parameters([False, True])
+ def testSparseDevicePlacement(self, use_resource):
for index_dtype in [dtypes.int32, dtypes.int64]:
with self.test_session(force_gpu=test.is_gpu_available()):
# If a GPU is available, tests that all optimizer ops can be placed on
# it (i.e. they have GPU kernels).
- var = variables.Variable([[1.0], [2.0]])
+ if use_resource:
+ var = resource_variable_ops.ResourceVariable([[1.0], [2.0]])
+ else:
+ var = variables.Variable([[1.0], [2.0]])
+
indices = constant_op.constant([0, 1], dtype=index_dtype)
gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
optimizer = lazy_adam_optimizer.LazyAdamOptimizer(3.0)
@@ -107,13 +122,21 @@ class AdamOptimizerTest(test.TestCase):
variables.global_variables_initializer().run()
minimize_op.run()
- def testSparseRepeatedIndices(self):
+ @parameterized.parameters([False, True])
+ def testSparseRepeatedIndices(self, use_resource):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- repeated_index_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
- aggregated_update_var = variables.Variable(
- [[1.0], [2.0]], dtype=dtype)
+ with self.cached_session():
+ if use_resource:
+ repeated_index_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = resource_variable_ops.ResourceVariable(
+ [[1.0], [2.0]], dtype=dtype)
+ else:
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+
grad_repeated_index = ops.IndexedSlices(
constant_op.constant(
[0.1, 0.1], shape=[2, 1], dtype=dtype),
@@ -139,6 +162,204 @@ class AdamOptimizerTest(test.TestCase):
self.assertAllClose(aggregated_update_var.eval(),
repeated_index_update_var.eval())
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ learning_rate = lambda: 0.001
+ beta1 = lambda: 0.9
+ beta2 = lambda: 0.999
+ epsilon = lambda: 1e-8
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ beta1 = beta1()
+ beta2 = beta2()
+ epsilon = epsilon()
+
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(learning_rate=learning_rate)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertIsNotNone(beta1_power)
+ self.assertIsNotNone(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ if not context.executing_eagerly():
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.test_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = lazy_adam_optimizer.LazyAdamOptimizer()
+
+ with context.eager_mode():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ g = ops.Graph()
+ with g.as_default():
+ with self.session(graph=g):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with self.session(graph=gg):
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = lazy_adam_optimizer.LazyAdamOptimizer(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables, and two unique slot variables
+ # for v1 and v2 respectively.
+ self.assertEqual(6, len(set(opt.variables())))
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions.py b/tensorflow/contrib/opt/python/training/matrix_functions.py
new file mode 100644
index 0000000000..baab577638
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions.py
@@ -0,0 +1,155 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Matrix functions contains iterative methods for M^p."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+
+def matrix_square_root(mat_a, mat_a_size, iter_count=100, ridge_epsilon=1e-4):
+ """Iterative method to get matrix square root.
+
+ Stable iterations for the matrix square root, Nicholas J. Higham
+
+ Page 231, Eq 2.6b
+ http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.6.8799&rep=rep1&type=pdf
+
+ Args:
+ mat_a: the symmetric PSD matrix whose matrix square root be computed
+ mat_a_size: size of mat_a.
+ iter_count: Maximum number of iterations.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_a^0.5
+ """
+
+ def _iter_condition(i, unused_mat_y, unused_old_mat_y, unused_mat_z,
+ unused_old_mat_z, err, old_err):
+ # This method require that we check for divergence every step.
+ return math_ops.logical_and(i < iter_count, err < old_err)
+
+ def _iter_body(i, mat_y, unused_old_mat_y, mat_z, unused_old_mat_z, err,
+ unused_old_err):
+ current_iterate = 0.5 * (3.0 * identity - math_ops.matmul(mat_z, mat_y))
+ current_mat_y = math_ops.matmul(mat_y, current_iterate)
+ current_mat_z = math_ops.matmul(current_iterate, mat_z)
+ # Compute the error in approximation.
+ mat_sqrt_a = current_mat_y * math_ops.sqrt(norm)
+ mat_a_approx = math_ops.matmul(mat_sqrt_a, mat_sqrt_a)
+ residual = mat_a - mat_a_approx
+ current_err = math_ops.sqrt(math_ops.reduce_sum(residual * residual)) / norm
+ return i + 1, current_mat_y, mat_y, current_mat_z, mat_z, current_err, err
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_a_size))
+ mat_a = mat_a + ridge_epsilon * identity
+ norm = math_ops.sqrt(math_ops.reduce_sum(mat_a * mat_a))
+ mat_init_y = mat_a / norm
+ mat_init_z = identity
+ init_err = norm
+
+ _, _, prev_mat_y, _, _, _, _ = control_flow_ops.while_loop(
+ _iter_condition, _iter_body, [
+ 0, mat_init_y, mat_init_y, mat_init_z, mat_init_z, init_err,
+ init_err + 1.0
+ ])
+ return prev_mat_y * math_ops.sqrt(norm)
+
+
+def matrix_inverse_pth_root(mat_g,
+ mat_g_size,
+ alpha,
+ iter_count=100,
+ epsilon=1e-6,
+ ridge_epsilon=1e-6):
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+
+ We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
+
+ A Schur-Newton Method for the Matrix p-th Root and its Inverse
+ by Chun-Hua Guo and Nicholas J. Higham
+ SIAM Journal on Matrix Analysis and Applications,
+ 2006, Vol. 28, No. 3 : pp. 788-804
+ https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
+
+ Args:
+ mat_g: the symmetric PSD matrix whose power it to be computed
+ mat_g_size: size of mat_g.
+ alpha: exponent, must be -1/p for p a positive integer.
+ iter_count: Maximum number of iterations.
+ epsilon: accuracy indicator, useful for early termination.
+ ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
+
+ Returns:
+ mat_g^alpha
+ """
+
+ identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
+
+ def mat_power(mat_m, p):
+ """Computes mat_m^p, for p a positive integer.
+
+ Power p is known at graph compile time, so no need for loop and cond.
+ Args:
+ mat_m: a square matrix
+ p: a positive integer
+
+ Returns:
+ mat_m^p
+ """
+ assert p == int(p) and p > 0
+ power = None
+ while p > 0:
+ if p % 2 == 1:
+ power = math_ops.matmul(mat_m, power) if power is not None else mat_m
+ p //= 2
+ mat_m = math_ops.matmul(mat_m, mat_m)
+ return power
+
+ def _iter_condition(i, mat_m, _):
+ return math_ops.logical_and(
+ i < iter_count,
+ math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
+
+ def _iter_body(i, mat_m, mat_x):
+ mat_m_i = (1 - alpha) * identity + alpha * mat_m
+ return (i + 1, math_ops.matmul(mat_power(mat_m_i, -1.0 / alpha), mat_m),
+ math_ops.matmul(mat_x, mat_m_i))
+
+ if mat_g_size == 1:
+ mat_h = math_ops.pow(mat_g + ridge_epsilon, alpha)
+ else:
+ damped_mat_g = mat_g + ridge_epsilon * identity
+ z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
+ # The best value for z is
+ # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
+ # (c_max^{1-alpha} - c_min^{1-alpha})
+ # where c_max and c_min are the largest and smallest singular values of
+ # damped_mat_g.
+ # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
+ # Can replace above line by the one below, but it is less accurate,
+ # hence needs more iterations to converge.
+ # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
+ # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
+ # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
+ # extra iterations.
+ _, _, mat_h = control_flow_ops.while_loop(
+ _iter_condition, _iter_body,
+ [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
+ return mat_h
diff --git a/tensorflow/contrib/opt/python/training/matrix_functions_test.py b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
new file mode 100644
index 0000000000..518fa38233
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/matrix_functions_test.py
@@ -0,0 +1,63 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Matrix functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import matrix_functions
+from tensorflow.python.platform import test
+
+TOLERANCE = 1e-3
+
+
+def np_power(mat_g, alpha):
+ """Computes mat_g^alpha for a square symmetric matrix mat_g."""
+
+ mat_u, diag_d, mat_v = np.linalg.svd(mat_g)
+ diag_d = np.power(diag_d, alpha)
+ return np.dot(np.dot(mat_u, np.diag(diag_d)), mat_v)
+
+
+class MatrixFunctionTests(test.TestCase):
+
+ def testMatrixSquareRootFunction(self):
+ """Tests for matrix square roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, 0.5)
+ mat_root = matrix_functions.matrix_square_root(mat, size)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+ def testMatrixInversePthRootFunction(self):
+ """Tests for matrix inverse pth roots."""
+
+ size = 20
+ mat_a = np.random.rand(size, size)
+ mat = np.dot(mat_a, mat_a.T)
+ expected_mat = np_power(mat, -0.125)
+ mat_root = matrix_functions.matrix_inverse_pth_root(mat, size, -0.125)
+ self.assertAllCloseAccordingToType(
+ expected_mat, mat_root, atol=TOLERANCE, rtol=TOLERANCE)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer.py b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
index b6b10e500b..746df77ba2 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer.py
@@ -89,7 +89,13 @@ class ModelAverageCustomGetter(object):
self._local_2_global[local_var] = global_variable
return local_var
else:
- return getter(name, trainable, collections, *args, **kwargs)
+ kwargs['trainable'] = trainable
+ kwargs['collections'] = collections
+ if ops.GraphKeys.LOCAL_VARIABLES in collections:
+ with ops.device(self._worker_device):
+ return getter(name, *args, **kwargs)
+ else:
+ return getter(name, *args, **kwargs)
class ModelAverageOptimizer(optimizer.Optimizer):
diff --git a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
index 3acd940268..b1fc50a21f 100644
--- a/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/model_average_optimizer_test.py
@@ -80,28 +80,28 @@ def _get_workers(num_workers, steps, workers):
var_0 = variable_scope.get_variable(initializer=0.0, name="v0")
var_1 = variable_scope.get_variable(initializer=1.0, name="v1")
- with ops.device("/job:worker/task:" + str(worker_id)):
- if worker_id == 0:
- grads_0 = constant_op.constant(-1.0)
- grads_1 = constant_op.constant(-1.0)
- else:
- grads_0 = constant_op.constant(-2.0)
- grads_1 = constant_op.constant(-2.0)
- sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
- opt = model_average_optimizer.ModelAverageOptimizer(
- opt=sgd_opt,
- num_worker=num_workers,
- ma_custom_getter=ma_coustom,
- is_chief=is_chief,
- interval_steps=steps)
- train_op = [
- opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
- global_step)
- ]
- easgd_hook = opt.make_session_run_hook()
+ with ops.device("/job:worker/task:" + str(worker_id)):
+ if worker_id == 0:
+ grads_0 = constant_op.constant(-1.0)
+ grads_1 = constant_op.constant(-1.0)
+ else:
+ grads_0 = constant_op.constant(-2.0)
+ grads_1 = constant_op.constant(-2.0)
+ sgd_opt = gradient_descent.GradientDescentOptimizer(1.0)
+ opt = model_average_optimizer.ModelAverageOptimizer(
+ opt=sgd_opt,
+ num_worker=num_workers,
+ ma_custom_getter=ma_coustom,
+ is_chief=is_chief,
+ interval_steps=steps)
+ train_op = [
+ opt.apply_gradients([[grads_0, var_0], [grads_1, var_1]],
+ global_step)
+ ]
+ ma_hook = opt.make_session_run_hook()
# Creates MonitoredSession
sess = training.MonitoredTrainingSession(
- workers[worker_id].target, hooks=[easgd_hook])
+ workers[worker_id].target, hooks=[ma_hook])
sessions.append(sess)
graphs.append(graph)
diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
index d15716f6f6..f22e724528 100644
--- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py
@@ -165,7 +165,7 @@ class MovingAverageOptimizerTest(test.TestCase):
self.assertLess(avg_val1[i], orig_val1[i])
def testFailWhenSaverCreatedBeforeInitialized(self):
- with self.test_session():
+ with self.cached_session():
var = variables.Variable([1.0], name='var', dtype=dtypes.float32)
opt = moving_average_optimizer.MovingAverageOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=2.0))
@@ -187,7 +187,7 @@ class MovingAverageOptimizerTest(test.TestCase):
self.apply_gradients_called = True
return super(WrapperOptimizer, self).apply_gradients(*args, **kwargs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
var = variables.Variable([1.2], name='var', dtype=dtypes.float32)
loss = var ** 2
wrapper_opt = WrapperOptimizer(learning_rate=2.0)
diff --git a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
index 618d8eb18d..904aa9ab13 100644
--- a/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
+++ b/tensorflow/contrib/opt/python/training/multitask_optimizer_wrapper_test.py
@@ -34,7 +34,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
"""
def testWrapper(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtypes.float32)
@@ -92,7 +92,7 @@ class MultitaskOptimizerWrapperTest(test.TestCase):
self.evaluate(slot1))
def testGradientClipping(self):
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtypes.float32)
var1 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
var2 = variables.Variable([3.0, 4.0], dtype=dtypes.float32)
diff --git a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
index 825c08a09a..85e05ce71c 100644
--- a/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/nadam_optimizer_test.py
@@ -53,7 +53,7 @@ class NadamOptimizerTest(test.TestCase):
def doTestSparse(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
@@ -106,7 +106,7 @@ class NadamOptimizerTest(test.TestCase):
def doTestBasic(self, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
# Initialize variables for numpy implementation.
m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
index ea56e1646a..c09e2ac76d 100644
--- a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
@@ -36,7 +36,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def doTestBasic(self, use_locking=False, use_resource=False):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
if use_resource:
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
@@ -73,7 +73,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = resource_variable_ops.ResourceVariable(
[[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
@@ -92,7 +92,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testTensorLearningRate(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -116,7 +116,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseBasic(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
@@ -144,7 +144,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndices(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
repeated_index_update_var = variables.Variable(
[[1.0], [2.0]], dtype=dtype)
aggregated_update_var = variables.Variable([[1.0], [2.0]], dtype=dtype)
@@ -170,7 +170,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseRepeatedIndicesResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var_repeated = resource_variable_ops.ResourceVariable(
[1.0, 2.0], dtype=dtype)
loss_repeated = math_ops.reduce_sum(
@@ -194,7 +194,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseStability(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
shape = [1, 6]
var0 = variables.Variable(
[[
@@ -230,7 +230,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSharing(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -263,7 +263,7 @@ class RegAdagradOptimizerTest(test.TestCase):
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testDynamicShapeVariable_Ok(self):
- with self.test_session():
+ with self.cached_session():
v = variable_scope.get_variable(
"v", initializer=constant_op.constant(1.), validate_shape=False)
self.assertFalse(v.shape.is_fully_defined())
@@ -274,7 +274,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSkipUpdatingSlots(self):
iav = 0.130005 # A value that works with float16
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
@@ -306,7 +306,7 @@ class RegAdagradOptimizerTest(test.TestCase):
def testSparseSkipUpdatingSlots(self):
iav = 0.130005 # A value that works with float16
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
grads0 = ops.IndexedSlices(
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index 294627f42a..f161521b97 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -23,6 +23,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
+from tensorflow.contrib.opt.python.training import matrix_functions
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -76,7 +77,7 @@ class ShampooOptimizer(optimizer.Optimizer):
learning_rate=1.0,
svd_interval=1,
precond_update_interval=1,
- epsilon=0.1,
+ epsilon=1e-4,
alpha=0.5,
use_iterative_root=False,
use_locking=False,
@@ -255,81 +256,18 @@ class ShampooOptimizer(optimizer.Optimizer):
def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
iter_count=100, epsilon=1e-6):
- """Computes mat_g^alpha, where alpha = -1/p, p a positive integer.
+ """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
+
+ mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
+ iter_count, self._epsilon)
+ mat_h = matrix_functions.matrix_inverse_pth_root(
+ mat_g_sqrt,
+ mat_g_size,
+ 2 * alpha,
+ iter_count,
+ epsilon,
+ ridge_epsilon=0.0)
- We use an iterative Schur-Newton method from equation 3.2 on page 9 of:
-
- A Schur-Newton Method for the Matrix p-th Root and its Inverse
- by Chun-Hua Guo and Nicholas J. Higham
- SIAM Journal on Matrix Analysis and Applications,
- 2006, Vol. 28, No. 3 : pp. 788-804
- https://pdfs.semanticscholar.org/0abe/7f77433cf5908bfe2b79aa91af881da83858.pdf
-
- Args:
- var: the variable we are updating.
- mat_g: the symmetric PSD matrix whose power it to be computed
- mat_g_size: size of mat_g.
- alpha: exponent, must be -1/p for p a positive integer.
- mat_h_slot_name: name of slot to store the power, if needed.
- iter_count: Maximum number of iterations.
- epsilon: accuracy indicator, useful for early termination.
-
- Returns:
- mat_g^alpha
- """
-
- identity = linalg_ops.eye(math_ops.to_int32(mat_g_size))
-
- def MatPower(mat_m, p):
- """Computes mat_m^p, for p a positive integer.
-
- Power p is known at graph compile time, so no need for loop and cond.
- Args:
- mat_m: a square matrix
- p: a positive integer
-
- Returns:
- mat_m^p
- """
- assert p == int(p) and p > 0
- power = None
- while p > 0:
- if p % 2 == 1:
- power = math_ops.matmul(mat_m, power) if power is not None else mat_m
- p //= 2
- mat_m = math_ops.matmul(mat_m, mat_m)
- return power
-
- def IterCondition(i, mat_m, _):
- return math_ops.logical_and(
- i < iter_count,
- math_ops.reduce_max(math_ops.abs(mat_m - identity)) > epsilon)
-
- def IterBody(i, mat_m, mat_x):
- mat_m_i = (1 - alpha) * identity + alpha * mat_m
- return (i + 1, math_ops.matmul(MatPower(mat_m_i, -1.0/alpha), mat_m),
- math_ops.matmul(mat_x, mat_m_i))
-
- if mat_g_size == 1:
- mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
- else:
- damped_mat_g = mat_g + self._epsilon * identity
- z = (1 - 1 / alpha) / (2 * linalg_ops.norm(damped_mat_g))
- # The best value for z is
- # (1 - 1/alpha) * (c_max^{-alpha} - c_min^{-alpha}) /
- # (c_max^{1-alpha} - c_min^{1-alpha})
- # where c_max and c_min are the largest and smallest singular values of
- # damped_mat_g.
- # The above estimate assumes that c_max > c_min * 2^p. (p = -1/alpha)
- # Can replace above line by the one below, but it is less accurate,
- # hence needs more iterations to converge.
- # z = (1 - 1/alpha) / math_ops.trace(damped_mat_g)
- # If we want the method to always converge, use z = 1 / norm(damped_mat_g)
- # or z = 1 / math_ops.trace(damped_mat_g), but these can result in many
- # extra iterations.
- _, _, mat_h = control_flow_ops.while_loop(
- IterCondition, IterBody,
- [0, damped_mat_g * z, identity * math_ops.pow(z, -alpha)])
if mat_h_slot_name is not None:
return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
return mat_h
@@ -422,6 +360,8 @@ class ShampooOptimizer(optimizer.Optimizer):
mat_gbar_weight_t * precond_update_interval, i),
lambda: mat_g)
+ mat_g_updated = mat_g_updated / float(shape[i].value)
+
if self._svd_interval == 1:
mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
else:
@@ -443,7 +383,13 @@ class ShampooOptimizer(optimizer.Optimizer):
name="precond_" + str(i))
else:
# Tensor size is too large -- perform diagonal Shampoo update
- grad_outer = math_ops.reduce_sum(grad * grad, axis=axes)
+ # Only normalize non-vector cases.
+ if axes:
+ normalizer = 1.0 if indices is not None else float(shape[i].value)
+ grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
+ else:
+ grad_outer = grad * grad
+
if i == 0 and indices is not None:
assert self._mat_gbar_decay == 1.0
mat_g_updated = state_ops.scatter_add(mat_g, indices,
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index 2e0a202ae2..05bcf2cfa3 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
TOLERANCE = 1e-3
+RIDGE_EPSILON = 1e-4
def np_power(mat_g, alpha):
@@ -52,7 +53,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size)
grad_np_2 = np.random.rand(size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -77,8 +78,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * mat_g^{-0.5} * grad
# lr = 1
- mat_g = np.outer(grad_np, grad_np)
- mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ mat_g = np.outer(grad_np, grad_np) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
new_val_np = init_var_np - np.dot(mat_h, grad_np)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -88,8 +89,8 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g += np.outer(grad_np_2, grad_np_2)
- mat_h = np_power(mat_g + 0.1 * np.eye(size), -0.5)
+ mat_g += np.outer(grad_np_2, grad_np_2) / grad_np.shape[0]
+ mat_h = np_power(mat_g + RIDGE_EPSILON * np.eye(size), -0.5)
new_val_np -= np.dot(mat_h, grad_np_2)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -103,7 +104,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size[0], size[1])
grad_np_2 = np.random.rand(size[0], size[1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -128,10 +129,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * mat_g1^{-0.25} * grad * mat_g2^{-0.25}
# lr = 1
- mat_g1 = np.dot(grad_np, grad_np.transpose())
- mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 = np.dot(grad_np, grad_np.transpose()) / grad_np.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(np.dot(mat_left, grad_np), mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -141,10 +142,10 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.dot(grad_np_2, grad_np_2.transpose())
- mat_left = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 += np.dot(grad_np_2, grad_np_2.transpose()) / grad_np_2.shape[0]
+ mat_left = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(np.dot(mat_left, grad_np_2), mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -162,7 +163,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size[0], size[1], size[2])
grad_np_2 = np.random.rand(size[0], size[1], size[2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -188,12 +189,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = (
+ np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2])) /
+ grad_np.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2])) /
+ grad_np.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = (
+ np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1])) /
+ grad_np.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_np, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -207,12 +214,18 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) /
+ grad_np_2.shape[0])
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) /
+ grad_np_2.shape[1])
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += (
+ np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) /
+ grad_np_2.shape[2])
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_np_2, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -240,7 +253,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size)
grad_np_2 = np.random.rand(size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -265,19 +278,21 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * gg^{-0.5} * grad
# lr = 1
- mat_g = grad_np * grad_np + 0.1
- new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
-
- self.assertAllCloseAccordingToType(new_val_np, new_val)
+ mat_g = (grad_np * grad_np)
+ new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
# Run another step of Shampoo
update_2.run()
new_val = sess.run(var)
- mat_g += grad_np_2 * grad_np_2
- new_val_np -= np.power(mat_g, -0.5) * grad_np_2
+ mat_g += (grad_np_2 * grad_np_2)
+ new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+
+ self.assertAllCloseAccordingToType(
+ new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
- self.assertAllCloseAccordingToType(new_val_np, new_val)
@parameterized.named_parameters(('Var', False), ('ResourceVar', True))
def testLargeMatrix(self, use_resource_var):
@@ -294,7 +309,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(size[0], size[1])
grad_np_2 = np.random.rand(size[0], size[1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -322,10 +337,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# with broadcasting
# lr = 1
- mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 = np.sum(
+ grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -335,10 +351,11 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_g1 += np.sum(
+ grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -365,7 +382,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
replace=False))
grad_np_2 = np.random.rand(sample_size_2, size[1])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -405,9 +422,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(grad_np * grad_np, axis=1, keepdims=True)
mat_g1_acc = np.zeros((size[0], 1))
mat_g1_acc[grad_indices] += mat_g1
- mat_left = np.power(mat_g1 + 0.1, -0.25)
- mat_g2 = np.dot(grad_np.transpose(), grad_np)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np
new_val_np[grad_indices, :] -= np.dot(grad_np * mat_left, mat_right)
@@ -420,9 +437,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(grad_np_2 * grad_np_2, axis=1, keepdims=True)
mat_g1_acc[grad_indices_2] += mat_g1
- mat_left = np.power(mat_g1_acc[grad_indices_2] + 0.1, -0.25)
- mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2)
- mat_right = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.25)
+ mat_left = np.power(mat_g1_acc[grad_indices_2] + RIDGE_EPSILON, -0.25)
+ mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
+ mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np[grad_indices_2, :] -= np.dot(grad_np_2 * mat_left, mat_right)
self.assertAllCloseAccordingToType(new_val_np, new_val,
@@ -445,7 +462,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
replace=False))
grad_np = np.random.rand(sample_size, size[1], size[2])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -474,12 +491,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_dense = np.zeros_like(init_var_np)
grad_dense[grad_indices] = grad_np
- mat_g1 = np.tensordot(grad_dense, grad_dense, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_dense, grad_dense, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_dense, grad_dense, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = np.tensordot(
+ grad_dense, grad_dense, axes=([1, 2], [1, 2])) / grad_dense.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 2], [0, 2])) / grad_dense.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_dense, grad_dense, axes=([0, 1], [0, 1])) / grad_dense.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
precond_grad = np.tensordot(grad_dense, mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -512,7 +532,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
gbar_decay = 0.9
gbar_weight = 0.1
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -536,12 +556,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 = np.tensordot(grad_np, grad_np, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 = np.tensordot(grad_np, grad_np, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 = np.tensordot(grad_np, grad_np, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 = np.tensordot(
+ grad_np, grad_np, axes=([1, 2], [1, 2])) / grad_np.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 = np.tensordot(
+ grad_np, grad_np, axes=([0, 2], [0, 2])) / grad_np.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 = np.tensordot(
+ grad_np, grad_np, axes=([0, 1], [0, 1])) / grad_np.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
gbar_np = gbar_weight * grad_np
precond_grad = np.tensordot(gbar_np, mat_g1_a, axes=([0], [0]))
@@ -556,12 +579,15 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
update_2.run()
new_val = sess.run(var)
- mat_g1 += np.tensordot(grad_np_2, grad_np_2, axes=([1, 2], [1, 2]))
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 2], [0, 2]))
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3 += np.tensordot(grad_np_2, grad_np_2, axes=([0, 1], [0, 1]))
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([1, 2], [1, 2])) / grad_np_2.shape[0]
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]), -0.5 / 3.0)
+ mat_g2 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 2], [0, 2])) / grad_np_2.shape[1]
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.5 / 3.0)
+ mat_g3 += np.tensordot(
+ grad_np_2, grad_np_2, axes=([0, 1], [0, 1])) / grad_np_2.shape[2]
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]), -0.5 / 3.0)
gbar_np_2 = gbar_decay * gbar_np + gbar_weight * grad_np_2
precond_grad = np.tensordot(gbar_np_2, mat_g1_a, axes=([0], [0]))
@@ -601,7 +627,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3_a = np.eye(size[2])
mat_g3 = np.zeros_like(mat_g3_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -626,13 +652,19 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# let up compute this in numpy
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
- mat_g1 += np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
- mat_g2 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
- mat_g3 += np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
+ mat_g1 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) / grad_np[i].shape[0]
+ mat_g2 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) / grad_np[i].shape[1]
+ mat_g3 += np.tensordot(
+ grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) / grad_np[i].shape[2]
if (i + 1) % svd_interval == 0:
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
@@ -672,7 +704,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3_a = np.eye(size[2])
mat_g3 = np.zeros_like(mat_g3_a)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
global_step = variables.Variable(
0, dtype=dtypes.int64, use_resource=use_resource_var)
var = variables.Variable(
@@ -700,17 +732,23 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# Update rule is var = var - lr * Prod_i mat_g_i^{-0.5/3} grad
# lr = 1
if (i + 1) % precond_update_interval == 0:
- mat_g1 += (np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2]))
- * precond_update_interval)
- mat_g2 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2]))
- * precond_update_interval)
- mat_g3 += (np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1]))
- * precond_update_interval)
+ mat_g1 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([1, 2], [1, 2])) /
+ grad_np[i].shape[0] * precond_update_interval)
+ mat_g2 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 2], [0, 2])) /
+ grad_np[i].shape[1] * precond_update_interval)
+ mat_g3 += (
+ np.tensordot(grad_np[i], grad_np[i], axes=([0, 1], [0, 1])) /
+ grad_np[i].shape[2] * precond_update_interval)
if (i + 1) % svd_interval == 0:
- mat_g1_a = np_power(mat_g1 + 0.1 * np.eye(size[0]), -0.5/3.0)
- mat_g2_a = np_power(mat_g2 + 0.1 * np.eye(size[1]), -0.5/3.0)
- mat_g3_a = np_power(mat_g3 + 0.1 * np.eye(size[2]), -0.5/3.0)
+ mat_g1_a = np_power(mat_g1 + RIDGE_EPSILON * np.eye(size[0]),
+ -0.5 / 3.0)
+ mat_g2_a = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]),
+ -0.5 / 3.0)
+ mat_g3_a = np_power(mat_g3 + RIDGE_EPSILON * np.eye(size[2]),
+ -0.5 / 3.0)
precond_grad = np.tensordot(grad_np[i], mat_g1_a, axes=([0], [0]))
precond_grad = np.tensordot(precond_grad, mat_g2_a, axes=([0], [0]))
diff --git a/tensorflow/contrib/opt/python/training/sign_decay_test.py b/tensorflow/contrib/opt/python/training/sign_decay_test.py
index c31cb924ea..3a84789afd 100644
--- a/tensorflow/contrib/opt/python/training/sign_decay_test.py
+++ b/tensorflow/contrib/opt/python/training/sign_decay_test.py
@@ -66,7 +66,7 @@ class SignDecaysTest(test.TestCase):
linear_decay_fn = sign_decay.get_linear_decay_fn(num_training_steps)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = linear_decay_fn(step).eval()
py_decayed = py_linear_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
@@ -78,7 +78,7 @@ class SignDecaysTest(test.TestCase):
num_training_steps, num_periods=5, zero_after=2)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = cosine_decay_fn(step).eval()
py_decayed = py_cosine_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
@@ -95,7 +95,7 @@ class SignDecaysTest(test.TestCase):
num_training_steps, num_periods=5, zero_after=2)
for step in range(0, 1000, 100):
- with self.test_session():
+ with self.cached_session():
tf_decayed = restart_decay_fn(step).eval()
py_decayed = py_restart_decay_fn(num_training_steps)(step)
self.assertAlmostEqual(tf_decayed, py_decayed, places=4)
diff --git a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
index fdda86b0b5..ff0ea8d766 100644
--- a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
+++ b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer_test.py
@@ -158,7 +158,7 @@ class VariableClippingOptimizerTest(test.TestCase):
def testDenseLocal(self):
for dtype in [dtypes.float32, dtypes.float64, dtypes.half]:
- with self.test_session():
+ with self.cached_session():
var0, var1, update_op = self._setupDense(False, dtype)
self._assertDenseCorrect(var0, var1, update_op)
@@ -171,7 +171,7 @@ class VariableClippingOptimizerTest(test.TestCase):
def testSparseLocal(self):
for dtype in [dtypes.float64, dtypes.float32, dtypes.half]:
- with self.test_session():
+ with self.cached_session():
var0, var1, update_op = self._setupSparse(False, dtype)
self._assertSparseCorrect(var0, var1, update_op)
diff --git a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
index b9cf40eb7b..200b0d2008 100644
--- a/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
+++ b/tensorflow/contrib/opt/python/training/weight_decay_optimizers.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.opt.python.training import shampoo
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
@@ -26,6 +27,7 @@ from tensorflow.python.training import adam
from tensorflow.python.training import momentum as momentum_opt
from tensorflow.python.training import optimizer
from tensorflow.python.util.tf_export import tf_export
+from tensorflow.python.ops import array_ops
class DecoupledWeightDecayExtension(object):
@@ -159,8 +161,8 @@ class DecoupledWeightDecayExtension(object):
def _decay_weights_sparse_op(self, var, indices, scatter_add):
if not self._decay_var_list or var in self._decay_var_list:
- return scatter_add(var, indices, -self._weight_decay * var,
- self._use_locking)
+ update = -self._weight_decay * array_ops.gather(var, indices)
+ return scatter_add(var, indices, update, self._use_locking)
return control_flow_ops.no_op()
# Here, we overwrite the apply functions that the base optimizer calls.
@@ -360,3 +362,74 @@ class AdamWOptimizer(DecoupledWeightDecayExtension, adam.AdamOptimizer):
super(AdamWOptimizer, self).__init__(
weight_decay, learning_rate=learning_rate, beta1=beta1, beta2=beta2,
epsilon=epsilon, use_locking=use_locking, name=name)
+
+
+@tf_export("contrib.opt.ShampooWOptimizer")
+class ShampooWOptimizer(DecoupledWeightDecayExtension,
+ shampoo.ShampooOptimizer):
+ """Optimizer that implements the Shampoo algorithm with weight decay.
+
+ For further information see the documentation of the Shampoo Optimizer.
+ """
+
+ def __init__(self,
+ weight_decay,
+ global_step,
+ max_matrix_size=768,
+ gbar_decay=0.0,
+ gbar_weight=1.0,
+ mat_gbar_decay=1.0,
+ mat_gbar_weight=1.0,
+ learning_rate=1.0,
+ svd_interval=1,
+ precond_update_interval=1,
+ epsilon=1e-4,
+ alpha=0.5,
+ use_iterative_root=False,
+ use_locking=False,
+ name="ShampooW"):
+ """Construct a new ShampooW optimizer.
+
+ For further information see the documentation of the Shampoo Optimizer.
+
+ Args:
+ weight_decay: A `Tensor` or a floating point value. The weight decay.
+ global_step: tensorflow variable indicating the step.
+ max_matrix_size: We do not perform SVD for matrices larger than this.
+ gbar_decay:
+ gbar_weight: Used to update gbar: gbar[t] = gbar_decay[t] * gbar[t-1] +
+ gbar_weight[t] * g[t]
+ mat_gbar_decay:
+ mat_gbar_weight: Used to update mat_gbar: mat_gbar_j[t] =
+ mat_gbar_decay[t] * mat_gbar_j[t-1] + mat_gbar_weight[t] * gg_j[t]
+ learning_rate: Similar to SGD
+ svd_interval: We should do SVD after this many steps. Default = 1, i.e.
+ every step. Usually 20 leads to no loss of accuracy, and 50 or 100 is
+ also OK. May also want more often early,
+ and less often later - set in caller as for example:
+ "svd_interval = lambda(T): tf.cond(
+ T < 2000, lambda: 20.0, lambda: 1000.0)"
+ precond_update_interval: We should update the preconditioners after this
+ many steps. Default = 1. Usually less than svd_interval.
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ alpha: total power of the preconditioners.
+ use_iterative_root: should the optimizer use SVD (faster) or the iterative
+ root method (for TPU) for finding the roots of PSD matrices.
+ use_locking: If `True` use locks for update operations.
+ name: name of optimizer.
+ """
+ super(ShampooWOptimizer, self).__init__(
+ weight_decay,
+ global_step=global_step,
+ max_matrix_size=max_matrix_size,
+ gbar_decay=gbar_decay,
+ gbar_weight=gbar_weight,
+ mat_gbar_decay=mat_gbar_weight,
+ learning_rate=learning_rate,
+ svd_interval=svd_interval,
+ precond_update_interval=precond_update_interval,
+ epsilon=epsilon,
+ alpha=alpha,
+ use_iterative_root=use_iterative_root,
+ use_locking=use_locking,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index 631d4f44df..04b1552b61 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -40,15 +40,14 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
Initialization:
- $$m_0 := 0 (Initialize initial 1st moment vector)$$
- $$v_0 := 0 (Initialize initial 2nd moment vector)$$
- $$t := 0 (Initialize timestep)$$
-
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
The update rule for `variable` with gradient `g` uses an optimization
described at the end of section2 of the paper:
$$t := t + 1$$
- $$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index f6ecaba834..6af59dcfbf 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -214,7 +214,8 @@ class _OptimizerV2State(object):
# with that Tensor cast to that dtype.
with ops.init_scope():
self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in hyper.items() if not dynamic}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if not dynamic}
self._slots = {}
self._non_slot_dict = {}
# Extra state to help Optimizers implement Checkpointable. Holds information
@@ -231,7 +232,8 @@ class _OptimizerV2State(object):
ret._deferred_dependencies = self._deferred_dependencies
ret._deferred_slot_restorations = self._deferred_slot_restorations
ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in hyper.items() if dynamic}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if dynamic}
ret._hyper.update(self._hyper)
ret._non_slot_devices = non_slot_devices
ret._distribution = distribution
diff --git a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
index 31a6fe1d94..9a19502276 100644
--- a/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
+++ b/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py
@@ -38,7 +38,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([6, None])
output_tensor = input_tensor.reshape((6, 2))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
@@ -49,7 +49,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([5, None])
output_tensor = input_tensor.reshape((6, 2))[:-1]
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
@@ -63,7 +63,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
[15]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, -1))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
# input_tensor[0, 0, 0] == result[0, 0, 0]
@@ -88,14 +88,14 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
[[49], [53], [57], [61]], [[51], [55], [59], [63]]]])
# NOTE: output_tensor != input_tensor.reshape((4, 4, 4, -1))
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
result = periodic_resample(input_tensor, desired_shape).eval()
self.assertAllEqual(result, output_tensor)
def testPeriodicResampleErrors(self):
input_tensor = numpy.zeros(shape=[1, 2, 2, 4])
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError,
'Dimension 3 input tensor has size 4, desired shape has size 1'):
@@ -109,7 +109,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
desired_shape = numpy.array([4, 4, None])
result_shape = (4, 4, 1)
input_shape = (2, 2, 4)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32, shape=input_shape)
output = periodic_resample(x, desired_shape)
error = gradient_checker.compute_gradient_error(
@@ -117,7 +117,7 @@ class PeriodicResampleTest(test_util.TensorFlowTestCase):
self.assertLess(error, 1e-4)
def testPeriodicResampleShapeInference(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Case 1: output shape can be fully inferreed.
x = array_ops.placeholder(dtypes.float32, shape=(2, 2, 4))
output = periodic_resample(x, [4, 4, None])
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 499fec4ffa..c59f667f6a 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -22,6 +22,7 @@ py_test(
":common",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
@@ -89,7 +90,6 @@ py_library(
":common",
":graph_matcher",
":input_to_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@@ -171,7 +171,6 @@ py_library(
":graph_matcher",
":input_to_ops",
":quant_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index bf648e158e..b27117dd48 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix):
return s[len(prefix):]
else:
return s
+
+
+def RerouteTensor(t0, t1, can_modify=None):
+ """Reroute the end of the tensor t0 to the ends of the tensor t1.
+
+ Args:
+ t0: a tf.Tensor.
+ t1: a tf.Tensor.
+ can_modify: iterable of operations which can be modified. Any operation
+ outside within_ops will be left untouched by this function.
+
+ Returns:
+ The number of individual modifications made by the function.
+ """
+ nb_update_inputs = 0
+ consumers = t1.consumers()
+ if can_modify is not None:
+ consumers = [c for c in consumers if c in can_modify]
+ consumers_indices = {}
+ for c in consumers:
+ consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
+ for c in consumers:
+ for i in consumers_indices[c]:
+ c._update_input(i, t0) # pylint: disable=protected-access
+ nb_update_inputs += 1
+ return nb_update_inputs
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 06c62f2d26..2b26302f8a 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase):
_, step_val = sess.run([b, quantization_step_tensor])
self.assertEqual(step_val, 2)
+ def testRerouteTensor(self):
+ a = constant_op.constant(1, name='a')
+ b = constant_op.constant(2, name='b')
+ c = constant_op.constant(3, name='c')
+ d = constant_op.constant(4, name='d')
+
+ add_ac = math_ops.add(a, c)
+ add_ad = math_ops.add(a, d)
+
+ # Ensure that before rerouting the inputs are what we think.
+ self._CheckOpHasInputs(add_ac.op, [a, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ # references to tensor a should be replaced with b for all ops in
+ # can_modify. This means add_ac will be changed but add_ad will not.
+ common.RerouteTensor(b, a, can_modify=[add_ac.op])
+ self._CheckOpHasInputs(add_ac.op, [b, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ def _CheckOpHasInputs(self, op, inputs):
+ for i in inputs:
+ self.assertIn(i, op.inputs)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index d9f179bee4..2971b28f45 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
- nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
- match.output_tensor)
+ nodes_modified_count = common.RerouteTensor(bias_add_tensor,
+ match.output_tensor)
if nodes_modified_count == 0:
raise ValueError('Folding batch norms failed, %s had no outputs.' %
match.output_tensor.name)
@@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: match.bn_decay_mean_tensor,
name='freeze_moving_mean')
- graph_editor.reroute_ts(
- [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ common.RerouteTensor(
+ bn_decay_mean_out,
+ match.bn_decay_mean_tensor,
can_modify=bn_decay_mean_consumers)
bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
@@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: bn_decay_zero,
lambda: match.bn_decay_var_tensor,
name='freeze_moving_var')
- graph_editor.reroute_ts(
- [bn_decay_var_out], [match.bn_decay_var_tensor],
+ common.RerouteTensor(
+ bn_decay_var_out,
+ match.bn_decay_var_tensor,
can_modify=bn_decay_var_consumers)
correction_recip = utils.smart_cond(
@@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
activation = common.GetEndpointActivationOp(graph, bn)
if activation:
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[activation])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % activation.name)
continue
@@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# operations instead of Relu* above.
add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[add_bypass])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2ddbd73ea6..e88db0acd5 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -592,8 +591,8 @@ def _InsertQuantOp(context,
name=name_prefix + '/delayed_quant')
if consumers:
- tensors_modified_count = graph_editor.reroute_ts(
- [quant], [inputs], can_modify=consumers)
+ tensors_modified_count = common.RerouteTensor(
+ quant, inputs, can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set
diff --git a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
index 00fbd4fbb8..aea80a5256 100644
--- a/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
+++ b/tensorflow/contrib/recurrent/python/kernel_tests/recurrent_test.py
@@ -56,7 +56,7 @@ class RecurrentTest(test_util.TensorFlowTestCase):
x_power=state.x_power * theta.x)
return next_state, []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
theta = _PolyTheta(x=array_ops.constant(2.0))
state = _PolyState(
value=array_ops.constant(0.0),
@@ -142,7 +142,7 @@ class RecurrentTest(test_util.TensorFlowTestCase):
def _ParameterizedTestElman(self, seqlen, use_grad):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
random_seed.set_random_seed(342462)
batch = 3
diff --git a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
index 67a8f59c3c..c3db71359c 100644
--- a/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
+++ b/tensorflow/contrib/recurrent/python/ops/functional_rnn.py
@@ -178,7 +178,8 @@ def _ApplyLengthsToBatch(sequence_lengths, tf_output):
# TODO(drpng): just use Update so that we don't carry over the gradients?
"""Sets the output to be zero at the end of the sequence."""
# output is batch major.
- batch_size, max_time, vector_size = tf_output.shape
+ shape = array_ops.shape(tf_output)
+ batch_size, max_time, vector_size = shape[0], shape[1], shape[2]
output_time = array_ops.tile(math_ops.range(0, max_time), [batch_size])
output_time = array_ops.reshape(output_time, [batch_size, max_time])
lengths = array_ops.tile(
@@ -278,11 +279,16 @@ def functional_rnn(cell, inputs, sequence_length=None,
if initial_state is None:
initial_state = cell.zero_state(batch_size, dtype)
func_cell = _FunctionalRnnCell(cell, inputs, initial_state)
+ if sequence_length is not None:
+ max_length = math_ops.reduce_max(sequence_length)
+ else:
+ max_length = None
extended_acc_state, extended_final_state = recurrent.Recurrent(
theta=func_cell.theta,
state0=func_cell.extended_initial_state,
inputs=inputs,
cell_fn=func_cell.cell_step,
+ max_input_length=max_length,
use_tpu=use_tpu)
tf_output, tf_state = _PostProcessOutput(
extended_acc_state, extended_final_state, func_cell,
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 5874245d58..4e67d80558 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -212,6 +212,7 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
+ tags = ["noasan"],
)
tf_custom_op_library(
@@ -279,7 +280,10 @@ cuda_py_tests(
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
],
- tags = ["no_oss"],
+ tags = [
+ "no_oss",
+ "noasan",
+ ],
)
tf_cc_test(
@@ -287,6 +291,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/gru_ops_test.cc"],
data = [":python/ops/_gru_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
@@ -306,6 +311,7 @@ tf_cc_test(
size = "small",
srcs = ["ops/lstm_ops_test.cc"],
data = [":python/ops/_lstm_ops.so"],
+ tags = ["noasan"],
# We must ensure that the dependencies can be dynamically linked since
# the shared library must be able to use core:framework.
# linkstatic = tf_kernel_tests_linkstatic(),
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 15ce9d1ce7..be0306cb07 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
class RNNCellTest(test.TestCase):
def testLinear(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(1.0)):
x = array_ops.zeros([1, 2])
@@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(variables_lib.trainable_variables()), 2)
def testBasicRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
@@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testIndRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(res[0].shape, (1, 2))
def testGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.156736, 0.156736]])
def testIndyGRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.155127, 0.157328]])
def testSRUCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.509682, 0.509682]])
def testSRUCellWithDiffSize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellDimension0Error(self):
"""Tests that dimension 0 in both(x and m) shape must be equal."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase):
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
@@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase):
})
def testBasicLSTMCellStateTupleType(self):
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple))
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(len(res), 2)
def testLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testLSTMCellVariables(self):
- with self.test_session():
+ with self.cached_session():
num_units = 8
num_proj = 6
state_size = num_units + num_proj
@@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase):
"root/lstm_cell/projection/kernel")
def testLSTMCellLayerNorm(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
num_proj = 3
batch_size = 1
@@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase):
rnn_cell_impl.DropoutWrapper,
rnn_cell_impl.ResidualWrapper,
lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
- with self.test_session():
- cell = rnn_cell_impl.BasicRNNCell(1)
- wrapper = wrapper_type(cell)
- wrapper(array_ops.ones([1, 1]),
- state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
- self.evaluate([v.initializer for v in cell.variables])
- checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
- prefix = os.path.join(self.get_temp_dir(), "ckpt")
- self.evaluate(cell._bias.assign([40.]))
- save_path = checkpoint.save(prefix)
- self.evaluate(cell._bias.assign([0.]))
- checkpoint.restore(save_path).assert_consumed().run_restore_ops()
- self.assertAllEqual([40.], self.evaluate(cell._bias))
+ cell = rnn_cell_impl.BasicRNNCell(1)
+ wrapper = wrapper_type(cell)
+ wrapper(array_ops.ones([1, 1]),
+ state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
+ self.evaluate([v.initializer for v in cell.variables])
+ checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(cell._bias.assign([40.]))
+ save_path = checkpoint.save(prefix)
+ self.evaluate(cell._bias.assign([0.]))
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([40.], self.evaluate(cell._bias))
def testOutputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.231907, 0.231907]])
def testInputProjectionWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
def testResidualWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[2], res[3])
def testResidualWrapperWithSlice(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 5])
@@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
def testEmbeddingWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 1], dtype=dtypes.int32)
@@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[0], [[0.17139, 0.17139]])
def testEmbeddingWrapperWithDynamicRnn(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope("root"):
inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64)
input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64)
@@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase):
sess.run(outputs)
def testMultiRNNCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase):
time_steps=None,
parallel_iterations=None,
**kwargs):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
if batch_size is None and time_steps is None:
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index aa4562be7c..bf699db3ed 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -1906,7 +1906,7 @@ class StateSaverRNNTest(test.TestCase):
state_saver = TestStateSaverWithCounters(batch_size, 2 * num_units)
out, state, state_saver = self._factory(scope=None, state_saver=state_saver)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(variables_lib.global_variables_initializer())
sess.run(variables_lib.local_variables_initializer())
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
index f2a032e41e..8d34b9e852 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/fused_rnn_cell_test.py
@@ -38,7 +38,7 @@ class FusedRnnCellTest(test.TestCase):
def testBasicRNNFusedWrapper(self):
"""This test checks that using a wrapper for BasicRNN works as expected."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890212)
cell = rnn_cell.BasicRNNCell(10)
@@ -106,7 +106,7 @@ class FusedRnnCellTest(test.TestCase):
self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2)
def testTimeReversedFusedRNN(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=19890213)
fw_cell = rnn_cell.BasicRNNCell(10)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 2df8f0ec05..6689664fb9 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -47,7 +47,7 @@ from tensorflow.python.util import nest
class RNNCellTest(test.TestCase):
def testCoupledInputForgetGateLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
state_size = num_units * 2
batch_size = 3
@@ -81,7 +81,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1], expected_state)
def testTimeFreqLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
state_size = num_units * 2
batch_size = 3
@@ -120,7 +120,7 @@ class RNNCellTest(test.TestCase):
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6)
def testGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
@@ -166,7 +166,7 @@ class RNNCellTest(test.TestCase):
.state_f00_b00_c[i, :]))) > 1e-6)
def testGridLSTMCellWithFrequencyBlocks(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 8
batch_size = 3
feature_size = 2
@@ -248,7 +248,7 @@ class RNNCellTest(test.TestCase):
]],
dtype=np.float32)
for state_is_tuple in [False, True]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple" + str(state_is_tuple),
initializer=init_ops.constant_initializer(0.5)):
@@ -294,7 +294,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -374,7 +374,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
def testBidirectionGridLSTMCellWithSliceOffset(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -487,7 +487,7 @@ class RNNCellTest(test.TestCase):
input_size = 4
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -538,7 +538,7 @@ class RNNCellTest(test.TestCase):
batch_size = 3
for state_is_tuple in [False, True]:
with ops.Graph().as_default():
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"state_is_tuple_" + str(state_is_tuple)):
lstm_cell = rnn_cell.BasicLSTMCell(
@@ -677,7 +677,7 @@ class RNNCellTest(test.TestCase):
0.79457647, 0.79457647, 0.79457647, 0.79457647, 0.79457653, 0.79457653,
0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348, 0.62739348
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units)
@@ -725,7 +725,7 @@ class RNNCellTest(test.TestCase):
0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997, 0.78973997,
1.87398517, 1.87398517, 1.87398517, 1.87398517, 1.87398517
]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"nas_proj_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.NASCell(num_units=num_units, num_proj=num_proj)
@@ -765,7 +765,7 @@ class RNNCellTest(test.TestCase):
[[0.13752282, 0.13752282], [0.10545051, 0.10545051],
[0.10074195, 0.10074195]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"ugrnn_cell_test", initializer=init_ops.constant_initializer(0.5)):
cell = contrib_rnn_cell.UGRNNCell(num_units=num_units)
@@ -796,7 +796,7 @@ class RNNCellTest(test.TestCase):
[[2.00431061, 2.00431061], [4.00060606, 4.00060606],
[6.00008249, 6.00008249]],
dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"intersection_rnn_cell_test",
initializer=init_ops.constant_initializer(0.5)):
@@ -837,7 +837,7 @@ class RNNCellTest(test.TestCase):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
@@ -874,7 +874,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv1DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 1]
filter_size = [3]
num_features = 1
@@ -907,7 +907,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv2DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 1]
filter_size = [3, 3]
num_features = 1
@@ -948,7 +948,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testConv3DLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = [2, 2, 2, 1]
filter_size = [3, 3, 3]
num_features = 1
@@ -999,7 +999,7 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
@@ -1030,7 +1030,7 @@ class RNNCellTest(test.TestCase):
# Try with input dimension equal to num_units or not.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root1_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1059,7 +1059,7 @@ class RNNCellTest(test.TestCase):
# Try with num_inputs equal to or not equal to num_units.
for num_inputs in [num_units, num_units + number_of_groups]:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root2_%d" % num_inputs,
initializer=init_ops.constant_initializer(0.5)):
@@ -1092,7 +1092,7 @@ class RNNCellTest(test.TestCase):
batch_size = 2
num_units = 4
number_of_groups = 2
- with self.test_session():
+ with self.cached_session():
with variable_scope.variable_scope(
"glstm_failure", initializer=init_ops.constant_initializer(0.5)):
gcell = contrib_rnn_cell.GLSTMCell(
@@ -1121,7 +1121,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
# NOTE: all the values in the current test case have been calculated.
def testBasicLSTMCell(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1189,7 +1189,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithoutNorm(self):
"""Tests that BasicLSTMCell with layer_norm=False."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1256,7 +1256,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[1].h, expected_h, 1e-5)
def testBasicLSTMCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1294,7 +1294,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
def testBasicLSTMCellWithStateTupleLayerNorm(self):
"""The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1353,7 +1353,7 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
num_units = 5
allowed_low = [1, 2, 3]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"other", initializer=init_ops.constant_initializer(1)):
x = array_ops.zeros([1, 5])
@@ -1479,7 +1479,7 @@ class CompiledWrapperTest(test.TestCase):
self.assertAllClose(xla_g, non_xla_g, atol=atol)
def testMultiRNNCellWithStateTuple(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
@@ -1583,7 +1583,7 @@ class WeightNormLSTMCellTest(test.TestCase):
def _cell_output(self, cell):
"""Calculates cell output."""
- with self.test_session() as sess:
+ with self.cached_session() as sess:
init = init_ops.constant_initializer(0.5)
with variable_scope.variable_scope("root",
initializer=init):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index f74c95f962..06c481672c 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -97,10 +97,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2448,10 +2448,10 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
The default non-peephole implementation is based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The peephole implementation is based on:
@@ -2802,9 +2802,11 @@ class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
Training of Deep Neural Networks
The default LSTM implementation based on:
- http://www.bioinf.jku.at/publications/older/2604.pdf
- S. Hochreiter and J. Schmidhuber.
- "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ https://pdfs.semanticscholar.org/1154/0131eae85b2e11d53df7f1360eeb6476e7f4.pdf
+
+ Felix Gers, Jurgen Schmidhuber, and Fred Cummins.
+ "Learning to forget: Continual prediction with LSTM." IET, 850-855, 1999.
The class uses optional peephole connections, optional cell clipping
and an optional projection layer.
diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD
index e7eb4ac563..b897224c6d 100644
--- a/tensorflow/contrib/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/BUILD
@@ -36,6 +36,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":keras_saved_model",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
@@ -101,23 +102,33 @@ py_library(
tags = ["no_windows"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
+ "//tensorflow/python:metrics",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:saver",
"//tensorflow/python:util",
+ "//tensorflow/python/estimator",
+ "//tensorflow/python/estimator:export",
+ "//tensorflow/python/estimator:keras",
+ "//tensorflow/python/estimator:model_fn",
"//tensorflow/python/keras:engine",
- "//tensorflow/python/saved_model:constants",
+ "//tensorflow/python/saved_model",
],
)
py_test(
name = "keras_saved_model_test",
- size = "small",
+ size = "medium",
srcs = ["python/saved_model/keras_saved_model_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":saved_model_py",
+ ":keras_saved_model",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/keras",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/saved_model/__init__.py b/tensorflow/contrib/saved_model/__init__.py
index 95e1a8967b..074dc655ac 100644
--- a/tensorflow/contrib/saved_model/__init__.py
+++ b/tensorflow/contrib/saved_model/__init__.py
@@ -26,10 +26,13 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.saved_model.python.saved_model.keras_saved_model import *
from tensorflow.contrib.saved_model.python.saved_model.signature_def_utils import *
-# pylint: enable=unused-import,widcard-import,line-too-long
+# pylint: enable=unused-import,wildcard-import,line-too-long
from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = ["get_signature_def_by_key", "load_model", "save_model"]
+_allowed_symbols = [
+ "get_signature_def_by_key",
+ "load_keras_model",
+ "save_keras_model"]
remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/BUILD b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
index 3c616c555b..ea4d41d43b 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/BUILD
+++ b/tensorflow/contrib/saved_model/cc/saved_model/BUILD
@@ -30,6 +30,7 @@ cc_library(
hdrs = ["signature_def_utils.h"],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
@@ -42,6 +43,7 @@ tf_cc_test(
srcs = ["signature_def_utils_test.cc"],
deps = [
":signature_def_utils",
+ "//tensorflow/cc/saved_model:signature_constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
index a45908d272..e87e497e5f 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -33,6 +35,79 @@ Status FindInProtobufMap(StringPiece description,
*value = &it->second;
return Status::OK();
}
+
+// Looks up the TensorInfo for the given key in the given map and verifies that
+// its datatype matches the given correct datatype.
+bool VerifyTensorInfoForKeyInMap(const protobuf::Map<string, TensorInfo>& map,
+ const string& key, DataType correct_dtype) {
+ const TensorInfo* tensor_info;
+ const Status& status = FindInProtobufMap("", map, key, &tensor_info);
+ if (!status.ok()) {
+ return false;
+ }
+ if (tensor_info->dtype() != correct_dtype) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidPredictSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kPredictMethodName) {
+ return false;
+ }
+ if (signature_def.inputs().empty()) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidRegressionSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kRegressMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kRegressInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.outputs(), kRegressOutputs,
+ DT_FLOAT)) {
+ return false;
+ }
+ return true;
+}
+
+bool IsValidClassificationSignature(const SignatureDef& signature_def) {
+ if (signature_def.method_name() != kClassifyMethodName) {
+ return false;
+ }
+ if (!VerifyTensorInfoForKeyInMap(signature_def.inputs(), kClassifyInputs,
+ DT_STRING)) {
+ return false;
+ }
+ if (signature_def.outputs().empty()) {
+ return false;
+ }
+ for (auto const& output : signature_def.outputs()) {
+ const string& key = output.first;
+ const TensorInfo& tensor_info = output.second;
+ if (key == kClassifyOutputClasses) {
+ if (tensor_info.dtype() != DT_STRING) {
+ return false;
+ }
+ } else if (key == kClassifyOutputScores) {
+ if (tensor_info.dtype() != DT_FLOAT) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace
Status FindSignatureDefByKey(const MetaGraphDef& meta_graph_def,
@@ -74,4 +149,10 @@ Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
return Status::OK();
}
+bool IsValidSignature(const SignatureDef& signature_def) {
+ return IsValidClassificationSignature(signature_def) ||
+ IsValidRegressionSignature(signature_def) ||
+ IsValidPredictSignature(signature_def);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
index b732cdd41e..bb24faa989 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h
@@ -64,6 +64,9 @@ Status FindInputTensorNameByKey(const SignatureDef& signature_def,
Status FindOutputTensorNameByKey(const SignatureDef& signature_def,
const string& tensor_info_key, string* name);
+// Determine whether a SignatureDef can be served by TensorFlow Serving.
+bool IsValidSignature(const SignatureDef& signature_def);
+
} // namespace tensorflow
#endif // TENSORFLOW_CONTRIB_SAVED_MODEL_CC_SAVED_MODEL_SIGNATURE_DEF_UTILS_H_
diff --git a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
index a063e95696..c743112ce0 100644
--- a/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
+++ b/tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/contrib/saved_model/cc/saved_model/signature_def_utils.h"
+#include "tensorflow/cc/saved_model/signature_constants.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -22,7 +23,7 @@ limitations under the License.
namespace tensorflow {
-class SignatureDefUtilsTest : public ::testing::Test {
+class FindByKeyTest : public ::testing::Test {
protected:
MetaGraphDef MakeSampleMetaGraphDef() {
MetaGraphDef result;
@@ -32,13 +33,23 @@ class SignatureDefUtilsTest : public ::testing::Test {
return result;
}
+ void SetInputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_inputs())[key].set_name(name);
+ }
+
+ void SetOutputNameForKey(const string& key, const string& name,
+ SignatureDef* signature_def) {
+ (*signature_def->mutable_outputs())[key].set_name(name);
+ }
+
SignatureDef MakeSampleSignatureDef() {
SignatureDef result;
result.set_method_name(kMethodName);
- (*result.mutable_inputs())[kInput1Key].set_name(kInput1Name);
- (*result.mutable_inputs())[kInput2Key].set_name(kInput2Name);
- (*result.mutable_outputs())[kOutput1Key].set_name(kOutput1Name);
- (*result.mutable_outputs())[kOutput2Key].set_name(kOutput2Name);
+ SetInputNameForKey(kInput1Key, kInput1Name, &result);
+ SetInputNameForKey(kInput2Key, kInput2Name, &result);
+ SetOutputNameForKey(kOutput1Key, kOutput1Name, &result);
+ SetOutputNameForKey(kOutput2Key, kOutput2Name, &result);
return result;
}
@@ -54,7 +65,7 @@ class SignatureDefUtilsTest : public ::testing::Test {
const string kOutput2Name = "output_two";
};
-TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
+TEST_F(FindByKeyTest, FindSignatureDefByKey) {
const MetaGraphDef meta_graph_def = MakeSampleMetaGraphDef();
const SignatureDef* signature_def;
// Succeeds for an existing signature.
@@ -67,7 +78,7 @@ TEST_F(SignatureDefUtilsTest, FindSignatureDefByKey) {
.ok());
}
-TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindInputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing input.
@@ -78,7 +89,7 @@ TEST_F(SignatureDefUtilsTest, FindInputTensorNameByKey) {
FindInputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
-TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
+TEST_F(FindByKeyTest, FindOutputTensorNameByKey) {
const SignatureDef signature_def = MakeSampleSignatureDef();
string name;
// Succeeds for an existing output.
@@ -89,4 +100,100 @@ TEST_F(SignatureDefUtilsTest, FindOutputTensorNameByKey) {
FindOutputTensorNameByKey(signature_def, "nonexistent", &name).ok());
}
+class IsValidSignatureTest : public ::testing::Test {
+ protected:
+ void SetInputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_inputs())[key].set_dtype(dtype);
+ }
+
+ void SetOutputDataTypeForKey(const string& key, DataType dtype) {
+ (*signature_def_.mutable_outputs())[key].set_dtype(dtype);
+ }
+
+ void EraseOutputKey(const string& key) {
+ (*signature_def_.mutable_outputs()).erase(key);
+ }
+
+ void ExpectInvalidSignature() {
+ EXPECT_FALSE(IsValidSignature(signature_def_));
+ }
+
+ void ExpectValidSignature() { EXPECT_TRUE(IsValidSignature(signature_def_)); }
+
+ SignatureDef signature_def_;
+};
+
+TEST_F(IsValidSignatureTest, IsValidPredictSignature) {
+ signature_def_.set_method_name("not_kPredictMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kPredictMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kPredictInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kPredictOutputs, DT_STRING);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidRegressionSignature) {
+ signature_def_.set_method_name("not_kRegressMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kRegressMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kRegressInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_STRING);
+ // Incorrect data type
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kRegressOutputs, DT_FLOAT);
+ ExpectValidSignature();
+}
+
+TEST_F(IsValidSignatureTest, IsValidClassificationSignature) {
+ signature_def_.set_method_name("not_kClassifyMethodName");
+ // Incorrect method name
+ ExpectInvalidSignature();
+
+ signature_def_.set_method_name(kClassifyMethodName);
+ // No inputs
+ ExpectInvalidSignature();
+
+ SetInputDataTypeForKey(kClassifyInputs, DT_STRING);
+ // No outputs
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey("invalidKey", DT_FLOAT);
+ // Invalid key
+ ExpectInvalidSignature();
+
+ EraseOutputKey("invalidKey");
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_FLOAT);
+ // Invalid dtype for classes
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputClasses, DT_STRING);
+ // Valid without scores
+ ExpectValidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_STRING);
+ // Invalid dtype for scores
+ ExpectInvalidSignature();
+
+ SetOutputDataTypeForKey(kClassifyOutputScores, DT_FLOAT);
+ // Valid with both classes and scores
+ ExpectValidSignature();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
index e2a969f053..2c5c8c4afd 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model.py
@@ -20,28 +20,69 @@ from __future__ import print_function
import os
+from tensorflow.python.client import session
+from tensorflow.python.estimator import keras as estimator_keras_util
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.estimator.export import export as export_helpers
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import models as models_lib
+from tensorflow.python.keras import optimizers
from tensorflow.python.keras.models import model_from_json
from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
from tensorflow.python.util import compat
-def save_model(model, saved_model_path):
+def save_keras_model(
+ model, saved_model_path, custom_objects=None, as_text=None):
"""Save a `tf.keras.Model` into Tensorflow SavedModel format.
- `save_model` generates such files/folders under the `saved_model_path` folder:
+ `save_model` generates new files/folders under the `saved_model_path` folder:
1) an asset folder containing the json string of the model's
- configuration(topology).
+ configuration (topology).
2) a checkpoint containing the model weights.
+ 3) a saved_model.pb file containing the model's MetaGraphs. The prediction
+ graph is always exported. The evaluaton and training graphs are exported
+ if the following conditions are met:
+ - Evaluation: model loss is defined.
+ - Training: model is compiled with an optimizer defined under `tf.train`.
+ This is because `tf.keras.optimizers.Optimizer` instances cannot be
+ saved to checkpoints.
- Note that subclassed models can not be saved via this function, unless you
- provide an implementation for get_config() and from_config().
- Also note that `tf.keras.optimizers.Optimizer` instances can not currently be
- saved to checkpoints. Use optimizers from `tf.train`.
+ Model Requirements:
+ - Model must be a sequential model or functional model. Subclassed models can
+ not be saved via this function, unless you provide an implementation for
+ get_config() and from_config().
+ - All variables must be saveable by the model. In general, this condition is
+ met through the use of layers defined in the keras library. However,
+ there is currently a bug with variables created in Lambda layer functions
+ not being saved correctly (see
+ https://github.com/keras-team/keras/issues/9740).
+
+ Note that each mode is exported in separate graphs, so different modes do not
+ share variables. To use the train graph with evaluation or prediction graphs,
+ create a new checkpoint if variable values have been updated.
Args:
model: A `tf.keras.Model` to be saved.
saved_model_path: a string specifying the path to the SavedModel directory.
+ The SavedModel will be saved to a timestamped folder created within this
+ directory.
+ custom_objects: Optional dictionary mapping string names to custom classes
+ or functions (e.g. custom loss functions).
+ as_text: whether to write the `SavedModel` proto in text format.
+
+ Returns:
+ String path to the SavedModel folder, a subdirectory of `saved_model_path`.
Raises:
NotImplementedError: If the passed in model is a subclassed model.
@@ -49,35 +90,200 @@ def save_model(model, saved_model_path):
if not model._is_graph_network:
raise NotImplementedError
- # save model configuration as a json string under assets folder.
- model_json = model.to_json()
- assets_destination_dir = os.path.join(
- compat.as_bytes(saved_model_path),
- compat.as_bytes(constants.ASSETS_DIRECTORY))
+ export_dir = export_helpers.get_timestamped_export_dir(saved_model_path)
+ temp_export_dir = export_helpers.get_temp_export_dir(export_dir)
+
+ builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
+
+ # Manually save variables to export them in an object-based checkpoint. This
+ # skips the `builder.add_meta_graph_and_variables()` step, which saves a
+ # named-based checkpoint.
+ # TODO(b/113134168): Add fn to Builder to save with object-based saver.
+ # TODO(b/113178242): This should only export the model json structure. Only
+ # one save is needed once the weights can be copied from the model to clone.
+ checkpoint_path = _export_model_json_and_variables(model, temp_export_dir)
+
+ # Export each mode. Use ModeKeys enums defined for `Estimator` to ensure that
+ # Keras models and `Estimator`s are exported with the same format.
+ # Every time a mode is exported, the code checks to see if new variables have
+ # been created (e.g. optimizer slot variables). If that is the case, the
+ # checkpoint is re-saved to include the new variables.
+ export_args = {'builder': builder,
+ 'model': model,
+ 'custom_objects': custom_objects,
+ 'checkpoint_path': checkpoint_path}
+
+ has_saved_vars = False
+ if model.optimizer:
+ if isinstance(model.optimizer, optimizers.TFOptimizer):
+ _export_mode(model_fn_lib.ModeKeys.TRAIN, has_saved_vars, **export_args)
+ has_saved_vars = True
+ _export_mode(model_fn_lib.ModeKeys.EVAL, has_saved_vars, **export_args)
+ else:
+ logging.warning(
+ 'Model was compiled with an optimizer, but the optimizer is not from '
+ '`tf.train` (e.g. `tf.train.AdagradOptimizer`). Only the serving '
+ 'graph was exported. The train and evaluate graphs were not added to '
+ 'the SavedModel.')
+ _export_mode(model_fn_lib.ModeKeys.PREDICT, has_saved_vars, **export_args)
+
+ builder.save(as_text)
+
+ gfile.Rename(temp_export_dir, export_dir)
+ return export_dir
- if not file_io.file_exists(assets_destination_dir):
- file_io.recursive_create_dir(assets_destination_dir)
+def _export_model_json_and_variables(model, saved_model_path):
+ """Save model variables and json structure into SavedModel subdirectories."""
+ # Save model configuration as a json string under assets folder.
+ model_json = model.to_json()
model_json_filepath = os.path.join(
- compat.as_bytes(assets_destination_dir),
- compat.as_bytes(constants.SAVED_MODEL_FILENAME_JSON))
+ saved_model_utils.get_or_create_assets_dir(saved_model_path),
+ compat.as_text(constants.SAVED_MODEL_FILENAME_JSON))
file_io.write_string_to_file(model_json_filepath, model_json)
- # save model weights in checkpoint format.
- checkpoint_destination_dir = os.path.join(
- compat.as_bytes(saved_model_path),
- compat.as_bytes(constants.VARIABLES_DIRECTORY))
+ # Save model weights in checkpoint format under variables folder.
+ saved_model_utils.get_or_create_variables_dir(saved_model_path)
+ checkpoint_prefix = saved_model_utils.get_variables_path(saved_model_path)
+ model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+ return checkpoint_prefix
- if not file_io.file_exists(checkpoint_destination_dir):
- file_io.recursive_create_dir(checkpoint_destination_dir)
- checkpoint_prefix = os.path.join(
- compat.as_text(checkpoint_destination_dir),
- compat.as_text(constants.VARIABLES_FILENAME))
- model.save_weights(checkpoint_prefix, save_format='tf', overwrite=True)
+def _get_var_list(model):
+ """Return list of all checkpointed saveable objects in the model."""
+ return checkpointable_utils.named_saveables(model)
+
+
+def _export_mode(
+ mode, has_saved_vars, builder, model, custom_objects, checkpoint_path):
+ """Export a model, and optionally save new vars from the clone model.
+
+ Args:
+ mode: A `tf.estimator.ModeKeys` string.
+ has_saved_vars: A `boolean` indicating whether the SavedModel has already
+ exported variables.
+ builder: A `SavedModelBuilder` object.
+ model: A `tf.keras.Model` object.
+ custom_objects: A dictionary mapping string names to custom classes
+ or functions.
+ checkpoint_path: String path to checkpoint.
+
+ Raises:
+ ValueError: If the train/eval mode is being exported, but the model does
+ not have an optimizer.
+ """
+ compile_clone = (mode != model_fn_lib.ModeKeys.PREDICT)
+ if compile_clone and not model.optimizer:
+ raise ValueError(
+ 'Model does not have an optimizer. Cannot export mode %s' % mode)
+
+ model_graph = ops.get_default_graph()
+ with ops.Graph().as_default() as g:
+
+ K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)
+
+ # Clone the model into blank graph. This will create placeholders for inputs
+ # and targets.
+ clone = models_lib.clone_and_build_model(
+ model, custom_objects=custom_objects, compile_clone=compile_clone)
+
+ # Make sure that iterations variable is added to the global step collection,
+ # to ensure that, when the SavedModel graph is loaded, the iterations
+ # variable is returned by `tf.train.get_global_step()`. This is required for
+ # compatibility with the SavedModelEstimator.
+ if compile_clone:
+ g.add_to_collection(ops.GraphKeys.GLOBAL_STEP, clone.optimizer.iterations)
+
+ # Extract update and train ops from train/test/predict functions.
+ if mode == model_fn_lib.ModeKeys.TRAIN:
+ clone._make_train_function()
+ builder._add_train_op(clone.train_function.updates_op)
+ elif mode == model_fn_lib.ModeKeys.EVAL:
+ clone._make_test_function()
+ else:
+ clone._make_predict_function()
+ g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
+
+ clone_var_list = checkpointable_utils.named_saveables(clone)
+
+ with session.Session().as_default():
+ if has_saved_vars:
+ # Confirm all variables in the clone have an entry in the checkpoint.
+ status = clone.load_weights(checkpoint_path)
+ status.assert_existing_objects_matched()
+ else:
+ # Confirm that variables between the clone and model match up exactly,
+ # not counting optimizer objects. Optimizer objects are ignored because
+ # if the model has not trained, the slot variables will not have been
+ # created yet.
+ # TODO(b/113179535): Replace with checkpointable equivalence.
+ _assert_same_non_optimizer_objects(model, model_graph, clone, g)
+
+ # TODO(b/113178242): Use value transfer for checkpointable objects.
+ clone.load_weights(checkpoint_path)
+
+ # Add graph and variables to SavedModel.
+ # TODO(b/113134168): Switch to add_meta_graph_and_variables.
+ clone.save_weights(checkpoint_path, save_format='tf', overwrite=True)
+ builder._has_saved_variables = True
+
+ # Add graph to the SavedModel builder.
+ builder.add_meta_graph(
+ model_fn_lib.EXPORT_TAG_MAP[mode],
+ signature_def_map=_create_signature_def_map(clone, mode),
+ saver=saver_lib.Saver(clone_var_list),
+ main_op=variables.local_variables_initializer())
+ return None
+
+
+def _create_signature_def_map(model, mode):
+ """Create a SignatureDef map from a Keras model."""
+ inputs_dict = {name: x for name, x in zip(model.input_names, model.inputs)}
+ if model.optimizer:
+ targets_dict = {x.name.split(':')[0]: x
+ for x in model.targets if x is not None}
+ inputs_dict.update(targets_dict)
+ outputs_dict = {name: x
+ for name, x in zip(model.output_names, model.outputs)}
+ export_outputs = model_fn_lib.export_outputs_for_mode(
+ mode,
+ predictions=outputs_dict,
+ loss=model.total_loss if model.optimizer else None,
+ metrics=estimator_keras_util._convert_keras_metrics_to_estimator(model))
+ return export_helpers.build_all_signature_defs(
+ inputs_dict,
+ export_outputs=export_outputs,
+ serving_only=(mode == model_fn_lib.ModeKeys.PREDICT))
+
+
+def _assert_same_non_optimizer_objects(model, model_graph, clone, clone_graph):
+ """Assert model and clone contain the same checkpointable objects."""
+
+ def get_non_optimizer_objects(m, g):
+ """Gather set of model and optimizer checkpointable objects."""
+ # Set default graph because optimizer.variables() returns optimizer
+ # variables defined in the default graph.
+ with g.as_default():
+ all_objects = set(checkpointable_utils.list_objects(m))
+ optimizer_and_variables = set()
+ for obj in all_objects:
+ if isinstance(obj, optimizers.TFOptimizer):
+ optimizer_and_variables.update(checkpointable_utils.list_objects(obj))
+ optimizer_and_variables.update(set(obj.optimizer.variables()))
+ return all_objects - optimizer_and_variables
+
+ model_objects = get_non_optimizer_objects(model, model_graph)
+ clone_objects = get_non_optimizer_objects(clone, clone_graph)
+
+ if len(model_objects) != len(clone_objects):
+ raise errors.InternalError(
+ None, None,
+ 'Model and clone must use the same variables.'
+ '\n\tModel variables: %s\n\t Clone variables: %s'
+ % (model_objects, clone_objects))
-def load_model(saved_model_path):
+def load_keras_model(saved_model_path):
"""Load a keras.Model from SavedModel.
load_model reinstantiates model state by:
diff --git a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
index 107ae1b07b..12dd72a95b 100644
--- a/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
+++ b/tensorflow/contrib/saved_model/python/saved_model/keras_saved_model_test.py
@@ -20,20 +20,37 @@ from __future__ import print_function
import os
import shutil
+
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.saved_model.python.saved_model import keras_saved_model
from tensorflow.python import keras
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import training as training_module
class TestModelSavingandLoading(test.TestCase):
+ def _save_model_dir(self, dirname='saved_model'):
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+ return os.path.join(temp_dir, dirname)
+
def test_saving_sequential_model(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -48,19 +65,17 @@ class TestModelSavingandLoading(test.TestCase):
model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_saving_sequential_model_without_compile(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.RepeatVector(3))
@@ -69,18 +84,15 @@ class TestModelSavingandLoading(test.TestCase):
x = np.random.random((1, 3))
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
-
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
def test_saving_functional_model(self):
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -95,19 +107,17 @@ class TestModelSavingandLoading(test.TestCase):
model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_saving_functional_model_without_compile(self):
- with self.test_session():
+ with self.cached_session():
inputs = keras.layers.Input(shape=(3,))
x = keras.layers.Dense(2)(inputs)
output = keras.layers.Dense(3)(x)
@@ -118,19 +128,17 @@ class TestModelSavingandLoading(test.TestCase):
y = np.random.random((1, 3))
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@test_util.run_in_graph_and_eager_modes
def test_saving_with_tf_optimizer(self):
- with self.test_session():
+ with self.cached_session():
model = keras.models.Sequential()
model.add(keras.layers.Dense(2, input_shape=(3,)))
model.add(keras.layers.Dense(3))
@@ -142,14 +150,13 @@ class TestModelSavingandLoading(test.TestCase):
x = np.random.random((1, 3))
y = np.random.random((1, 3))
model.train_on_batch(x, y)
+ model.train_on_batch(x, y)
ref_y = model.predict(x)
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
- keras_saved_model.save_model(model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model = self._save_model_dir()
+ output_path = keras_saved_model.save_keras_model(model, temp_saved_model)
+ loaded_model = keras_saved_model.load_keras_model(output_path)
loaded_model.compile(
loss='mse',
optimizer=training_module.RMSPropOptimizer(0.1),
@@ -170,8 +177,10 @@ class TestModelSavingandLoading(test.TestCase):
self.assertAllClose(ref_y, y, atol=1e-05)
# test saving/loading again
- keras_saved_model.save_model(loaded_model, temp_saved_model)
- loaded_model = keras_saved_model.load_model(temp_saved_model)
+ temp_saved_model2 = self._save_model_dir('saved_model_2')
+ output_path2 = keras_saved_model.save_keras_model(
+ loaded_model, temp_saved_model2)
+ loaded_model = keras_saved_model.load_keras_model(output_path2)
y = loaded_model.predict(x)
self.assertAllClose(ref_y, y, atol=1e-05)
@@ -190,11 +199,231 @@ class TestModelSavingandLoading(test.TestCase):
return self.layer2(self.layer1(inp))
model = SubclassedModel()
- temp_dir = self.get_temp_dir()
- self.addCleanup(shutil.rmtree, temp_dir)
- temp_saved_model = os.path.join(temp_dir, 'saved_model')
+
+ temp_saved_model = self._save_model_dir()
with self.assertRaises(NotImplementedError):
- keras_saved_model.save_model(model, temp_saved_model)
+ keras_saved_model.save_keras_model(model, temp_saved_model)
+
+
+class LayerWithLearningPhase(keras.engine.base_layer.Layer):
+
+ def call(self, x):
+ phase = keras.backend.learning_phase()
+ output = tf_utils.smart_cond(
+ phase, lambda: x * 0, lambda: array_ops.identity(x))
+ if not context.executing_eagerly():
+ output._uses_learning_phase = True # pylint: disable=protected-access
+ return output
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+
+def functional_model(uses_learning_phase):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ if uses_learning_phase:
+ x = LayerWithLearningPhase()(x)
+ return keras.models.Model(inputs, x)
+
+
+def sequential_model(uses_learning_phase):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(2, input_shape=(3,)))
+ model.add(keras.layers.Dense(3))
+ if uses_learning_phase:
+ model.add(LayerWithLearningPhase())
+ return model
+
+
+def load_model(sess, path, mode):
+ tags = model_fn_lib.EXPORT_TAG_MAP[mode]
+ sig_def_key = (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ if mode == model_fn_lib.ModeKeys.PREDICT else mode)
+ meta_graph_def = loader_impl.load(sess, tags, path)
+ inputs = {
+ k: sess.graph.get_tensor_by_name(v.name)
+ for k, v in meta_graph_def.signature_def[sig_def_key].inputs.items()}
+ outputs = {
+ k: sess.graph.get_tensor_by_name(v.name)
+ for k, v in meta_graph_def.signature_def[sig_def_key].outputs.items()}
+ return inputs, outputs
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
+
+ def _save_model_dir(self, dirname='saved_model'):
+ temp_dir = self.get_temp_dir()
+ self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
+ return os.path.join(temp_dir, dirname)
+
+ @parameterized.parameters(
+ (functional_model, True, training_module.AdadeltaOptimizer(), True),
+ (functional_model, True, training_module.AdadeltaOptimizer(), False),
+ (functional_model, False, None, False),
+ (sequential_model, True, training_module.AdadeltaOptimizer(), True),
+ (sequential_model, True, training_module.AdadeltaOptimizer(), False),
+ (sequential_model, False, None, False))
+ def testSaveAndLoadSavedModelExport(
+ self, model_builder, uses_learning_phase, optimizer, train_before_export):
+ saved_model_path = self._save_model_dir()
+ with self.test_session(graph=ops.Graph()):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model = model_builder(uses_learning_phase)
+ if optimizer is not None:
+ model.compile(
+ loss='mse',
+ optimizer=optimizer,
+ metrics=['mae'])
+ if train_before_export:
+ model.train_on_batch(input_arr, target_arr)
+
+ ref_loss, ref_mae = model.evaluate(input_arr, target_arr)
+
+ ref_predict = model.predict(input_arr)
+
+ # Export SavedModel
+ output_path = keras_saved_model.save_keras_model(model, saved_model_path)
+
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+ target_name = output_name + '_target'
+
+ # Load predict graph, and test predictions
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.PREDICT)
+
+ predictions = sess.run(outputs[output_name],
+ {inputs[input_name]: input_arr})
+ self.assertAllClose(ref_predict, predictions, atol=1e-05)
+
+ if optimizer:
+ # Load eval graph, and test predictions, loss and metric values
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.EVAL)
+
+ eval_results = sess.run(outputs, {inputs[input_name]: input_arr,
+ inputs[target_name]: target_arr})
+
+ self.assertEqual(int(train_before_export),
+ sess.run(training_module.get_global_step()))
+ self.assertAllClose(ref_loss, eval_results['loss'], atol=1e-05)
+ self.assertAllClose(
+ ref_mae, eval_results['metrics/mae/update_op'], atol=1e-05)
+ self.assertAllClose(
+ ref_predict, eval_results['predictions/' + output_name], atol=1e-05)
+
+ # Load train graph, and check for the train op, and prediction values
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.TRAIN)
+ self.assertEqual(int(train_before_export),
+ sess.run(training_module.get_global_step()))
+ self.assertIn('loss', outputs)
+ self.assertIn('metrics/mae/update_op', outputs)
+ self.assertIn('metrics/mae/value', outputs)
+ self.assertIn('predictions/' + output_name, outputs)
+
+ # Train for a step
+ train_op = ops.get_collection(constants.TRAIN_OP_KEY)
+ train_outputs, _ = sess.run(
+ [outputs, train_op], {inputs[input_name]: input_arr,
+ inputs[target_name]: target_arr})
+ self.assertEqual(int(train_before_export) + 1,
+ sess.run(training_module.get_global_step()))
+
+ if uses_learning_phase:
+ self.assertAllClose(
+ [[0, 0, 0]], train_outputs['predictions/' + output_name],
+ atol=1e-05)
+ else:
+ self.assertNotAllClose(
+ [[0, 0, 0]], train_outputs['predictions/' + output_name],
+ atol=1e-05)
+
+ def testSaveAndLoadSavedModelWithCustomObject(self):
+ saved_model_path = self._save_model_dir()
+ with session.Session(graph=ops.Graph()) as sess:
+ def relu6(x):
+ return keras.backend.relu(x, max_value=6)
+ inputs = keras.layers.Input(shape=(1,))
+ outputs = keras.layers.Activation(relu6)(inputs)
+ model = keras.models.Model(inputs, outputs)
+ output_path = keras_saved_model.save_keras_model(
+ model, saved_model_path, custom_objects={'relu6': relu6})
+ with session.Session(graph=ops.Graph()) as sess:
+ inputs, outputs = load_model(sess, output_path,
+ model_fn_lib.ModeKeys.PREDICT)
+ input_name = model.input_names[0]
+ output_name = model.output_names[0]
+ predictions = sess.run(
+ outputs[output_name], {inputs[input_name]: [[7], [-3], [4]]})
+ self.assertAllEqual([[6], [0], [4]], predictions)
+
+ def testAssertModelCloneSameObjectsIgnoreOptimizer(self):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model_graph = ops.Graph()
+ clone_graph = ops.Graph()
+
+ # Create two models with the same layers but different optimizers.
+ with session.Session(graph=model_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ model = keras.models.Model(inputs, x)
+
+ model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
+ model.train_on_batch(input_arr, target_arr)
+
+ with session.Session(graph=clone_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ clone = keras.models.Model(inputs, x)
+ clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
+ clone.train_on_batch(input_arr, target_arr)
+
+ keras_saved_model._assert_same_non_optimizer_objects(
+ model, model_graph, clone, clone_graph)
+
+ def testAssertModelCloneSameObjectsThrowError(self):
+ input_arr = np.random.random((1, 3))
+ target_arr = np.random.random((1, 3))
+
+ model_graph = ops.Graph()
+ clone_graph = ops.Graph()
+
+ # Create two models with the same layers but different optimizers.
+ with session.Session(graph=model_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(3)(x)
+ model = keras.models.Model(inputs, x)
+
+ model.compile(loss='mse', optimizer=training_module.AdadeltaOptimizer())
+ model.train_on_batch(input_arr, target_arr)
+
+ with session.Session(graph=clone_graph):
+ inputs = keras.layers.Input(shape=(3,))
+ x = keras.layers.Dense(2)(inputs)
+ x = keras.layers.Dense(4)(x)
+ x = keras.layers.Dense(3)(x)
+ clone = keras.models.Model(inputs, x)
+ clone.compile(loss='mse', optimizer=keras.optimizers.RMSprop(lr=0.0001))
+ clone.train_on_batch(input_arr, target_arr)
+
+ with self.assertRaisesRegexp(
+ errors.InternalError, 'Model and clone must use the same variables.'):
+ keras_saved_model._assert_same_non_optimizer_objects(
+ model, model_graph, clone, clone_graph)
if __name__ == '__main__':
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc
index cf26e3cae7..a690d9b129 100644
--- a/tensorflow/contrib/session_bundle/session_bundle.cc
+++ b/tensorflow/contrib/session_bundle/session_bundle.cc
@@ -138,10 +138,10 @@ Status RunRestoreOp(const RunOptions& run_options, const StringPiece export_dir,
Tensor variables_tensor =
CreateStringTensor(GetVariablesFilename(export_dir));
std::vector<std::pair<string, Tensor>> inputs = {
- {variables_filename_const_op_name.ToString(), variables_tensor}};
+ {string(variables_filename_const_op_name), variables_tensor}};
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
RunMetadata run_metadata;
- return session->Run(run_options, inputs, {}, {restore_op_name.ToString()},
+ return session->Run(run_options, inputs, {}, {string(restore_op_name)},
nullptr /* outputs */, &run_metadata);
}
@@ -152,7 +152,7 @@ Status RunInitOp(const RunOptions& run_options, const StringPiece export_dir,
std::vector<std::pair<string, Tensor>> inputs;
AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
RunMetadata run_metadata;
- return session->Run(run_options, inputs, {}, {init_op_name.ToString()},
+ return session->Run(run_options, inputs, {}, {string(init_op_name)},
nullptr /* outputs */, &run_metadata);
}
@@ -251,15 +251,14 @@ Status LoadSessionBundleFromPathUsingRunOptions(const SessionOptions& options,
auto log_and_count = [&](const string& status_str) {
LOG(INFO) << "Loading SessionBundle: " << status_str << ". Took "
<< load_latency_microsecs << " microseconds.";
- load_attempt_count->GetCell(export_dir.ToString(), status_str)
- ->IncrementBy(1);
+ load_attempt_count->GetCell(string(export_dir), status_str)->IncrementBy(1);
};
if (status.ok()) {
log_and_count(kLoadAttemptSuccess);
} else {
log_and_count(kLoadAttemptFail);
}
- load_latency->GetCell(export_dir.ToString())
+ load_latency->GetCell(string(export_dir))
->IncrementBy(load_latency_microsecs);
return status;
}
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index d877831fce..a6ce45c203 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -416,12 +416,17 @@ class Image(ItemHandler):
def decode_image():
"""Decodes a image based on the headers."""
- return image_ops.decode_image(image_buffer, channels=self._channels)
+ return math_ops.cast(
+ image_ops.decode_image(image_buffer, channels=self._channels),
+ self._dtype)
def decode_jpeg():
"""Decodes a jpeg image with specified '_dct_method'."""
- return image_ops.decode_jpeg(
- image_buffer, channels=self._channels, dct_method=self._dct_method)
+ return math_ops.cast(
+ image_ops.decode_jpeg(
+ image_buffer,
+ channels=self._channels,
+ dct_method=self._dct_method), self._dtype)
def check_jpeg():
"""Checks if an image is jpeg."""
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
index d783d4fef4..826242c9d7 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py
@@ -37,12 +37,12 @@ from tensorflow.python.platform import test
class TFExampleDecoderTest(test.TestCase):
def _EncodedFloatFeature(self, ndarray):
- return feature_pb2.Feature(float_list=feature_pb2.FloatList(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ float_list=feature_pb2.FloatList(value=ndarray.flatten().tolist()))
def _EncodedInt64Feature(self, ndarray):
- return feature_pb2.Feature(int64_list=feature_pb2.Int64List(
- value=ndarray.flatten().tolist()))
+ return feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=ndarray.flatten().tolist()))
def _EncodedBytesFeature(self, tf_encoded):
with self.test_session():
@@ -74,12 +74,14 @@ class TFExampleDecoderTest(test.TestCase):
if image_format in ['raw', 'RAW']:
return constant_op.constant(image.tostring(), dtype=dtypes.string)
- def GenerateImage(self, image_format, image_shape):
+ def GenerateImage(self, image_format, image_shape, image_dtype=np.uint8):
"""Generates an image and an example containing the encoded image.
Args:
image_format: the encoding format of the image.
image_shape: the shape of the image to generate.
+ image_dtype: the dtype of values in the image. Only 'raw' image can have
+ type different than uint8.
Returns:
image: the generated image.
@@ -87,14 +89,18 @@ class TFExampleDecoderTest(test.TestCase):
serialized image and a feature key 'image/format' set to the image
encoding format ['jpeg', 'JPEG', 'png', 'PNG', 'raw'].
"""
+ assert image_format in ['raw', 'RAW'] or image_dtype == np.uint8
num_pixels = image_shape[0] * image_shape[1] * image_shape[2]
image = np.linspace(
- 0, num_pixels - 1, num=num_pixels).reshape(image_shape).astype(np.uint8)
+ 0, num_pixels - 1,
+ num=num_pixels).reshape(image_shape).astype(image_dtype)
tf_encoded = self._Encoder(image, image_format)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': self._EncodedBytesFeature(tf_encoded),
- 'image/format': self._StringFeature(image_format)
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded': self._EncodedBytesFeature(tf_encoded),
+ 'image/format': self._StringFeature(image_format)
+ }))
return image, example.SerializeToString()
@@ -168,8 +174,7 @@ class TFExampleDecoderTest(test.TestCase):
tf_decoded_image = self.DecodeExample(
serialized_example,
- tfexample_decoder.Image(
- shape=None, channels=channels),
+ tfexample_decoder.Image(shape=None, channels=channels),
image_format='jpeg')
self.assertEqual(tf_decoded_image.get_shape().ndims, 3)
@@ -225,27 +230,38 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(image, decoded_image, atol=0)
- def testDecodeExampleWithJpegEncodingAt16BitCausesError(self):
+ def testDecodeExampleWithRawEncodingFloatDtype(self):
image_shape = (2, 3, 3)
- unused_image, serialized_example = self.GenerateImage(
+ image, serialized_example = self.GenerateImage(
+ image_format='raw', image_shape=image_shape, image_dtype=np.float32)
+
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(shape=image_shape, dtype=dtypes.float32),
+ image_format='raw')
+
+ self.assertAllClose(image, decoded_image, atol=0)
+
+ def testDecodeExampleWithJpegEncodingAt16BitDoesNotCauseError(self):
+ image_shape = (2, 3, 3)
+ # Image has type uint8 but decoding at uint16 should not cause problems.
+ image, serialized_example = self.GenerateImage(
image_format='jpeg', image_shape=image_shape)
- # decode_raw support uint16 now so ValueError will be thrown instead.
- with self.assertRaisesRegexp(
- ValueError,
- 'true_fn and false_fn must have the same type: uint16, uint8'):
- unused_decoded_image = self.RunDecodeExample(
- serialized_example,
- tfexample_decoder.Image(dtype=dtypes.uint16),
- image_format='jpeg')
+ decoded_image = self.RunDecodeExample(
+ serialized_example,
+ tfexample_decoder.Image(dtype=dtypes.uint16),
+ image_format='jpeg')
+ self.assertAllClose(image, decoded_image, atol=1.001)
def testDecodeExampleWithStringTensor(self):
tensor_shape = (2, 3, 1)
np_array = np.array([[['ab'], ['cd'], ['ef']],
[['ghi'], ['jkl'], ['mnop']]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._BytesFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._BytesFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -259,7 +275,9 @@ class TFExampleDecoderTest(test.TestCase):
default_value=constant_op.constant(
'', shape=tensor_shape, dtype=dtypes.string))
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -271,9 +289,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFloatTensor(self):
np_array = np.random.rand(2, 3, 1).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -282,7 +301,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.float32)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -291,9 +312,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithInt64Tensor(self):
np_array = np.random.randint(1, 10, size=(2, 3, 1))
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'array': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'array': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -302,7 +324,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'array': parsing_ops.FixedLenFeature(np_array.shape, dtypes.int64)
}
- items_to_handlers = {'array': tfexample_decoder.Tensor('array'),}
+ items_to_handlers = {
+ 'array': tfexample_decoder.Tensor('array'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_array] = decoder.decode(serialized_example, ['array'])
@@ -311,9 +335,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensor(self):
np_array = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -322,7 +347,9 @@ class TFExampleDecoderTest(test.TestCase):
keys_to_features = {
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
- items_to_handlers = {'labels': tfexample_decoder.Tensor('labels'),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.Tensor('labels'),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -332,9 +359,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithFixLenTensorWithShape(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -342,12 +370,10 @@ class TFExampleDecoderTest(test.TestCase):
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
'labels':
- parsing_ops.FixedLenFeature(
- np_array.shape, dtype=dtypes.int64),
+ parsing_ops.FixedLenFeature(np_array.shape, dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -357,9 +383,10 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithVarLenTensorToDense(self):
np_array = np.array([[1, 2, 3], [4, 5, 6]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'labels': self._EncodedInt64Feature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'labels': self._EncodedInt64Feature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -369,8 +396,7 @@ class TFExampleDecoderTest(test.TestCase):
'labels': parsing_ops.VarLenFeature(dtype=dtypes.int64),
}
items_to_handlers = {
- 'labels': tfexample_decoder.Tensor(
- 'labels', shape=np_array.shape),
+ 'labels': tfexample_decoder.Tensor('labels', shape=np_array.shape),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -382,12 +408,18 @@ class TFExampleDecoderTest(test.TestCase):
np_image = np.random.rand(2, 3, 1).astype('f')
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/shape': self._EncodedInt64Feature(np.array(np_labels.shape)),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/shape':
+ self._EncodedInt64Feature(np.array(np_labels.shape)),
+ }))
serialized_example = example.SerializeToString()
@@ -401,11 +433,9 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
- tfexample_decoder.Tensor(
- 'labels', shape_keys='labels/shape'),
+ tfexample_decoder.Tensor('labels', shape_keys='labels/shape'),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -419,14 +449,22 @@ class TFExampleDecoderTest(test.TestCase):
np_labels = np.array([[[1], [2], [3]], [[4], [5], [6]]])
height, width, depth = np_labels.shape
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image': self._EncodedFloatFeature(np_image),
- 'image/shape': self._EncodedInt64Feature(np.array(np_image.shape)),
- 'labels': self._EncodedInt64Feature(np_labels),
- 'labels/height': self._EncodedInt64Feature(np.array([height])),
- 'labels/width': self._EncodedInt64Feature(np.array([width])),
- 'labels/depth': self._EncodedInt64Feature(np.array([depth])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image':
+ self._EncodedFloatFeature(np_image),
+ 'image/shape':
+ self._EncodedInt64Feature(np.array(np_image.shape)),
+ 'labels':
+ self._EncodedInt64Feature(np_labels),
+ 'labels/height':
+ self._EncodedInt64Feature(np.array([height])),
+ 'labels/width':
+ self._EncodedInt64Feature(np.array([width])),
+ 'labels/depth':
+ self._EncodedInt64Feature(np.array([depth])),
+ }))
serialized_example = example.SerializeToString()
@@ -442,8 +480,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'image':
- tfexample_decoder.Tensor(
- 'image', shape_keys='image/shape'),
+ tfexample_decoder.Tensor('image', shape_keys='image/shape'),
'labels':
tfexample_decoder.Tensor(
'labels',
@@ -459,10 +496,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithSparseTensor(self):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -472,7 +511,9 @@ class TFExampleDecoderTest(test.TestCase):
'indices': parsing_ops.VarLenFeature(dtype=dtypes.int64),
'values': parsing_ops.VarLenFeature(dtype=dtypes.float32),
}
- items_to_handlers = {'labels': tfexample_decoder.SparseTensor(),}
+ items_to_handlers = {
+ 'labels': tfexample_decoder.SparseTensor(),
+ }
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
[tf_labels] = decoder.decode(serialized_example, ['labels'])
@@ -485,11 +526,13 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- 'shape': self._EncodedInt64Feature(np_shape),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ 'shape': self._EncodedInt64Feature(np_shape),
+ }))
serialized_example = example.SerializeToString()
@@ -515,10 +558,12 @@ class TFExampleDecoderTest(test.TestCase):
np_indices = np.array([[1], [2], [5]])
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -544,10 +589,12 @@ class TFExampleDecoderTest(test.TestCase):
np_values = np.array([0.1, 0.2, 0.6]).astype('f')
np_shape = np.array([6])
np_dense = np.array([0.0, 0.1, 0.2, 0.0, 0.0, 0.6]).astype('f')
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'indices': self._EncodedInt64Feature(np_indices),
- 'values': self._EncodedFloatFeature(np_values),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'indices': self._EncodedInt64Feature(np_indices),
+ 'values': self._EncodedFloatFeature(np_values),
+ }))
serialized_example = example.SerializeToString()
@@ -559,8 +606,7 @@ class TFExampleDecoderTest(test.TestCase):
}
items_to_handlers = {
'labels':
- tfexample_decoder.SparseTensor(
- shape=np_shape, densify=True),
+ tfexample_decoder.SparseTensor(shape=np_shape, densify=True),
}
decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
items_to_handlers)
@@ -572,9 +618,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -603,9 +650,10 @@ class TFExampleDecoderTest(test.TestCase):
tensor_shape = (2, 3, 1)
np_array = np.random.rand(2, 3, 1)
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/depth_map': self._EncodedFloatFeature(np_array),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(feature={
+ 'image/depth_map': self._EncodedFloatFeature(np_array),
+ }))
serialized_example = example.SerializeToString()
@@ -701,12 +749,14 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -740,26 +790,32 @@ class TFExampleDecoderTest(test.TestCase):
np_xmax = np.random.rand(num_bboxes, 1)
np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax])
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
- 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
- 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
- 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin),
+ 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin),
+ 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax),
+ 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
serialized_example = array_ops.reshape(serialized_example, shape=[])
keys_to_features = {
- 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
- 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature(
- [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmin':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/ymax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
+ 'image/object/bbox/xmax':
+ parsing_ops.FixedLenSequenceFeature(
+ [], dtypes.float32, allow_missing=True),
}
items_to_handlers = {
@@ -784,11 +840,16 @@ class TFExampleDecoderTest(test.TestCase):
with self.test_session():
tf_string = tf_encoded.eval()
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/encoded': feature_pb2.Feature(bytes_list=feature_pb2.BytesList(
- value=[tf_string, tf_string])),
- 'image/format': self._StringFeature(image_format),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/encoded':
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[tf_string, tf_string])),
+ 'image/format':
+ self._StringFeature(image_format),
+ }))
serialized_example = example.SerializeToString()
with self.test_session():
@@ -797,8 +858,7 @@ class TFExampleDecoderTest(test.TestCase):
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features={
'image/encoded':
- parsing_ops.FixedLenFeature(
- (2,), dtypes.string),
+ parsing_ops.FixedLenFeature((2,), dtypes.string),
'image/format':
parsing_ops.FixedLenFeature(
(), dtypes.string, default_value=image_format),
@@ -814,10 +874,12 @@ class TFExampleDecoderTest(test.TestCase):
def testDecodeExampleWithLookup(self):
- example = example_pb2.Example(features=feature_pb2.Features(feature={
- 'image/object/class/text': self._BytesFeature(
- np.array(['cat', 'dog', 'guinea pig'])),
- }))
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ 'image/object/class/text':
+ self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])),
+ }))
serialized_example = example.SerializeToString()
# 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2
table = lookup_ops.index_table_from_tensor(
diff --git a/tensorflow/contrib/specs/python/specs_test.py b/tensorflow/contrib/specs/python/specs_test.py
index 9a4ad36793..b7ce6aa20a 100644
--- a/tensorflow/contrib/specs/python/specs_test.py
+++ b/tensorflow/contrib/specs/python/specs_test.py
@@ -38,7 +38,7 @@ def _rand(*size):
class SpecsTest(test.TestCase):
def testSimpleConv(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -53,7 +53,7 @@ class SpecsTest(test.TestCase):
def testUnary(self):
# This is just a quick and dirty check that these ops exist
# and work as unary ops.
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(17, 55))
spec = "net = Do(0.5) | Bn | Unit(1) | Relu | Sig | Tanh | Smax"
outputs = specs.create_net(spec, inputs)
@@ -63,7 +63,7 @@ class SpecsTest(test.TestCase):
self.assertEqual(tuple(result.shape), (17, 55))
def testAdd(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(17, 55))
spec = "net = Fs(10) + Fr(10)"
outputs = specs.create_net(spec, inputs)
@@ -77,7 +77,7 @@ class SpecsTest(test.TestCase):
"<> variablev2 dot variablev2 biasadd relu add")
def testMpPower(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "M2 = Mp([2, 2]); net = M2**3"
outputs = specs.create_net(spec, inputs)
@@ -90,7 +90,7 @@ class SpecsTest(test.TestCase):
"_ maxpool maxpool maxpool")
def testAbbrevPower(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "C3 = Cr([3, 3]); M2 = Mp([2, 2]); net = (C3(5) | M2)**3"
outputs = specs.create_net(spec, inputs)
@@ -106,7 +106,7 @@ class SpecsTest(test.TestCase):
" biasadd relu maxpool")
def testAbbrevPower2(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 64, 64, 5))
spec = "C3 = Cr(_1=[3, 3]); M2 = Mp([2, 2]);"
spec += "net = (C3(_0=5) | M2)**3"
@@ -123,7 +123,7 @@ class SpecsTest(test.TestCase):
" maxpool")
def testConc(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = "net = Conc(1, Fs(20), Fs(10))"
outputs = specs.create_net(spec, inputs)
@@ -137,7 +137,7 @@ class SpecsTest(test.TestCase):
"<> variablev2 dot variablev2 biasadd sig _ concatv2")
def testImport(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = ("S = Import('from tensorflow.python.ops" +
" import math_ops; f = math_ops.sigmoid')")
@@ -150,7 +150,7 @@ class SpecsTest(test.TestCase):
self.assertEqual(summaries.tf_spec_structure(spec, inputs), "_ sig sig")
def testKeywordRestriction(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(10, 20))
spec = "import re; net = Conc(1, Fs(20), Fs(10))"
self.assertRaises(ValueError, lambda: specs.create_net(spec, inputs))
@@ -179,7 +179,7 @@ class SpecsTest(test.TestCase):
# XXX: the cleverness of this code is over 9000
# TODO: original author please fix
def DISABLED_testVar(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with specs.ops:
# pylint: disable=undefined-variable
v = Var("test_var",
@@ -196,7 +196,7 @@ class SpecsTest(test.TestCase):
# XXX: the cleverness of this code is over 9000
# TODO: original author please fix
def DISABLED_testShared(self):
- with self.test_session():
+ with self.cached_session():
with specs.ops:
# pylint: disable=undefined-variable
f = Shared(Fr(100))
diff --git a/tensorflow/contrib/specs/python/summaries_test.py b/tensorflow/contrib/specs/python/summaries_test.py
index 34ff4bc8ca..b82ba06d3f 100644
--- a/tensorflow/contrib/specs/python/summaries_test.py
+++ b/tensorflow/contrib/specs/python/summaries_test.py
@@ -34,7 +34,7 @@ def _rand(*size):
class SummariesTest(test.TestCase):
def testStructure(self):
- with self.test_session():
+ with self.cached_session():
inputs_shape = (1, 18, 19, 5)
inputs = constant_op.constant(_rand(*inputs_shape))
spec = "net = Cr(64, [5, 5])"
@@ -48,7 +48,7 @@ class SummariesTest(test.TestCase):
"_ variablev2 conv variablev2 biasadd relu")
def testStructureFromTensor(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -60,7 +60,7 @@ class SummariesTest(test.TestCase):
"_ variablev2 conv variablev2 biasadd relu")
def testPrint(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
@@ -70,7 +70,7 @@ class SummariesTest(test.TestCase):
summaries.tf_spec_print(spec, inputs)
def testSummary(self):
- with self.test_session():
+ with self.cached_session():
inputs = constant_op.constant(_rand(1, 18, 19, 5))
spec = "net = Cr(64, [5, 5])"
outputs = specs.create_net(spec, inputs)
diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD
index 22d6e499d2..00c855daa3 100644
--- a/tensorflow/contrib/tensor_forest/BUILD
+++ b/tensorflow/contrib/tensor_forest/BUILD
@@ -462,7 +462,10 @@ py_test(
size = "small",
srcs = ["python/kernel_tests/scatter_add_ndim_op_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip_gpu"],
+ tags = [
+ "no_gpu",
+ "no_pip_gpu",
+ ],
deps = [
":tensor_forest_ops_py",
"//tensorflow/python:framework_test_lib",
@@ -534,10 +537,11 @@ py_library(
py_test(
name = "random_forest_test",
- size = "medium",
+ size = "large",
srcs = ["client/random_forest_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "noasan",
"nomac", # b/63258195
"notsan",
],
diff --git a/tensorflow/contrib/tensor_forest/client/random_forest.py b/tensorflow/contrib/tensor_forest/client/random_forest.py
index db970deff5..0042d37acd 100644
--- a/tensorflow/contrib/tensor_forest/client/random_forest.py
+++ b/tensorflow/contrib/tensor_forest/client/random_forest.py
@@ -134,19 +134,19 @@ def _get_default_head(params, weights_name, output_type, name=None):
weight_column=weights_name,
label_dimension=params.num_outputs,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
if params.num_classes == 2:
return core_head_lib.binary_classification_head(
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
else:
return core_head_lib.multi_class_head(
n_classes=params.num_classes,
weight_column=weights_name,
name=name,
- loss_reduction=losses.Reduction.SUM_OVER_NONZERO_WEIGHTS)
+ loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
def get_model_fn(params,
graph_builder_class,
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
index d43884481a..99c5800391 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc
@@ -130,7 +130,11 @@ void TensorDataSet::RandomSample(int example,
num_total_features += num_sparse;
}
}
- int rand_feature = rng_->Uniform(num_total_features);
+ int rand_feature = 0;
+ {
+ mutex_lock lock(mu_);
+ rand_feature = rng_->Uniform(num_total_features);
+ }
if (rand_feature < available_features_.size()) { // it's dense.
*feature_id = available_features_[rand_feature];
*type = input_spec_.GetDenseFeatureType(rand_feature);
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
index 95f75b4d7e..4945b53007 100644
--- a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h
@@ -25,6 +25,7 @@
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace tensorforest {
@@ -120,6 +121,8 @@ class TensorDataSet {
int32 split_sampling_random_seed_;
std::unique_ptr<random::PhiloxRandom> single_rand_;
std::unique_ptr<random::SimplePhilox> rng_;
+ // Mutex for using random number generator.
+ mutable mutex mu_;
};
} // namespace tensorforest
} // namespace tensorflow
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index a0fc3e43a9..122a67a407 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -279,6 +279,7 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:framework",
"//tensorflow/core:framework_lite",
"//tensorflow/core:gpu_runtime",
"//tensorflow/core:graph",
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 0f5abe6898..c98b07ad8b 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/resources/trt_resources.h"
#include "tensorflow/core/framework/node_def.pb.h" // NOLINT
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h" // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD
index 71b0d48798..21c0c30c19 100644
--- a/tensorflow/contrib/timeseries/examples/BUILD
+++ b/tensorflow/contrib/timeseries/examples/BUILD
@@ -32,6 +32,7 @@ py_test(
name = "predict_test",
timeout = "long", # Moderate but for asan
srcs = ["predict_test.py"],
+ data = ["data/period_trend.csv"],
srcs_version = "PY2AND3",
tags = [
"no_windows", # TODO: needs investigation on Windows
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py
index 71621abc71..1226433625 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py
@@ -41,7 +41,7 @@ _MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/changepoints.csv")
-def state_space_esitmator(exogenous_feature_columns):
+def state_space_estimator(exogenous_feature_columns):
"""Constructs a StructuralEnsembleRegressor."""
def _exogenous_update_condition(times, features):
@@ -68,7 +68,7 @@ def state_space_esitmator(exogenous_feature_columns):
4, 64)
-def autoregressive_esitmator(exogenous_feature_columns):
+def autoregressive_estimator(exogenous_feature_columns):
input_window_size = 8
output_window_size = 2
return (
@@ -169,10 +169,10 @@ def main(unused_argv):
"Please install matplotlib to generate a plot from this example.")
make_plot("Ignoring a known anomaly (state space)",
*train_and_evaluate_exogenous(
- estimator_fn=state_space_esitmator))
+ estimator_fn=state_space_estimator))
make_plot("Ignoring a known anomaly (autoregressive)",
*train_and_evaluate_exogenous(
- estimator_fn=autoregressive_esitmator, train_steps=3000))
+ estimator_fn=autoregressive_estimator, train_steps=3000))
pyplot.show()
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
index 8c64f2e186..57ccf8f260 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
@@ -28,7 +28,7 @@ class KnownAnomalyExampleTest(test.TestCase):
def test_shapes_and_variance_structural_ar(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=1, estimator_fn=known_anomaly.autoregressive_esitmator)
+ train_steps=1, estimator_fn=known_anomaly.autoregressive_estimator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
@@ -40,7 +40,7 @@ class KnownAnomalyExampleTest(test.TestCase):
def test_shapes_and_variance_structural_ssm(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=50, estimator_fn=known_anomaly.state_space_esitmator)
+ train_steps=50, estimator_fn=known_anomaly.state_space_estimator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index e65e7b74d4..647455ae42 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -122,7 +122,7 @@ class EvaluationMetricsTests(test.TestCase):
metric[1] for metric in outputs.eval_metric_ops.values()]
loss_mean, loss_update = metrics.mean(outputs.loss)
metric_update_ops.append(loss_update)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(sess, coord=coordinator)
variables.local_variables_initializer().run()
diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
index 703537abf0..f92148b788 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py
@@ -88,7 +88,7 @@ class RandomWindowInputFnTests(test.TestCase):
window_size=window_size, batch_size=batch_size)
result, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
@@ -261,7 +261,7 @@ class WholeDatasetInputFnTests(test.TestCase):
def _whole_dataset_input_fn_test_template(
self, time_series_reader, num_features, num_samples):
result, _ = input_pipeline.WholeDatasetInputFn(time_series_reader)()
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables.local_variables_initializer())
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -340,7 +340,7 @@ class AllWindowInputFnTests(test.TestCase):
window_size=window_size)
features, _ = input_fn()
init_op = variables.local_variables_initializer()
- with self.test_session() as session:
+ with self.cached_session() as session:
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
session.run(init_op)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
index 02d2524b66..c0de42b15b 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils_test.py
@@ -55,7 +55,7 @@ class MathUtilsTest(test.TestCase):
running_sum = running_sum + current_contribution
# pylint: enable=g-no-augmented-assignment
transition_power = numpy.dot(transition, transition_power)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.power_sums_tensor(
array_size, transition, addition).eval())
@@ -66,7 +66,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(powers.shape[0]):
result.append(numpy.linalg.matrix_power(matrix, powers[i]))
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(result,
math_utils.matrix_to_powers(matrix, powers).eval(),
rtol=1e-5,
@@ -78,7 +78,7 @@ class MathUtilsTest(test.TestCase):
result = []
for i in range(batch.shape[0]):
result.append(numpy.linalg.matrix_power(batch[i], powers[i]))
- with self.test_session():
+ with self.cached_session():
# TODO(allenl): Numerical errors seem to be creeping in. Maybe it can be
# made slightly more stable?
self.assertAllClose(result,
@@ -91,7 +91,7 @@ class MathUtilsTest(test.TestCase):
left_transpose = numpy.transpose(left, [0, 2, 1])
right = numpy.random.normal(size=[2, 3]).astype(numpy.float32)
expected_result = numpy.dot(left, right)
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.batch_times_matrix(
left, right).eval())
@@ -114,7 +114,7 @@ class MathUtilsTest(test.TestCase):
right_transpose = numpy.transpose(right, [0, 2, 1])
expected_result = numpy.transpose(numpy.dot(right_transpose, left.T),
[0, 2, 1])
- with self.test_session():
+ with self.cached_session():
self.assertAllClose(expected_result,
math_utils.matrix_times_batch(
left, right).eval())
@@ -132,7 +132,7 @@ class MathUtilsTest(test.TestCase):
adj_x=True, adj_y=True).eval())
def test_make_diagonal_undefined_shapes(self):
- with self.test_session():
+ with self.cached_session():
completely_undefined = array_ops.placeholder(dtype=dtypes.float32)
partly_undefined = array_ops.placeholder(
shape=[None, None], dtype=dtypes.float32)
@@ -152,7 +152,7 @@ class MathUtilsTest(test.TestCase):
[5., 6.]]}))
def test_make_diagonal_mostly_defined_shapes(self):
- with self.test_session():
+ with self.cached_session():
mostly_defined = array_ops.placeholder(
shape=[None, 2], dtype=dtypes.float32)
blocked = math_utils.block_diagonal([[[2.]],
@@ -192,7 +192,7 @@ class TestMakeToeplitzMatrix(test.TestCase):
def _test_make_toeplitz_matrix(self, inputs, output_expected):
output_tf = math_utils.make_toeplitz_matrix(inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
output_tf_np = sess.run(output_tf)
self.assertAllClose(output_tf_np, output_expected)
@@ -201,13 +201,13 @@ class TestMakeCovarianceMatrix(test.TestCase):
def test_zero_size_matrix(self):
raw = numpy.zeros([0, 0])
- with self.test_session():
+ with self.cached_session():
constructed = math_utils.sign_magnitude_positive_definite(raw=raw).eval()
self.assertEqual((0, 0), constructed.shape)
def test_sign_magnitude_positive_definite(self):
for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session():
+ with self.cached_session():
matrix_tensor = math_utils.sign_magnitude_positive_definite(
raw=constant_op.constant([[-1., -2.], [3., 4.]], dtype=dtype),
off_diagonal_scale=constant_op.constant(-1., dtype=dtype),
@@ -230,7 +230,8 @@ class TestLookupTable(test.TestCase):
name="test_lookup")
def stack_tensor(base_tensor):
return array_ops.stack([base_tensor + 1, base_tensor + 2])
- with self.test_session() as session:
+
+ with self.cached_session() as session:
((float_output, double_output), int_output) = session.run(
hash_table.lookup([2, 1, 0]))
def expected_output_before_insert(base_tensor):
diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
index cfd31cc70d..a049dbe773 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py
@@ -29,7 +29,7 @@ class ModelUtilsTest(test.TestCase):
def test_parameter_switching(self):
parameter = array_ops.constant(5)
overridden_parameter = array_ops.constant(3)
- with self.test_session():
+ with self.cached_session():
getter = model_utils.parameter_switch({overridden_parameter: 4})
self.assertEqual(5, getter(parameter))
self.assertEqual(4, getter(overridden_parameter))
diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
index 5f7e3da2db..42ba6e1c25 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/state_management_test.py
@@ -127,7 +127,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -178,7 +178,7 @@ class ChainingStateManagerTest(test.TestCase):
result_model_outputs = chainer.define_loss(
model=stub_model, features=result_input_fn()[0],
mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
@@ -221,7 +221,7 @@ class ChainingStateManagerTest(test.TestCase):
chainer.initialize_graph(model=stub_model)
model_outputs = chainer.define_loss(
model=stub_model, features=features, mode=estimator_lib.ModeKeys.TRAIN)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
coordinator = coordinator_lib.Coordinator()
queue_runner_impl.start_queue_runners(session, coord=coordinator)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 56e451e2e3..298ffc1ded 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -16,6 +16,7 @@ package(
"//cloud/vmm/testing/tests/tpu:__subpackages__",
"//learning/brain:__subpackages__",
"//learning/deepmind:__subpackages__",
+ "//medical/pathology:__subpackages__",
"//tensorflow:__subpackages__",
],
)
@@ -166,6 +167,7 @@ py_library(
name = "keras_support",
srcs = [
"python/tpu/keras_support.py",
+ "python/tpu/keras_tpu_variables.py",
],
srcs_version = "PY2AND3",
visibility = [
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 537d94b797..3c0456dc2f 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -33,6 +33,7 @@
@@shard
@@batch_parallel
@@rewrite
+@@outside_compilation
@@CrossShardOptimizer
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index 06553929dc..ea8e0e00ed 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -18,28 +18,111 @@ limitations under the License.
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_OP("AllToAll")
+ .Input("input: T")
+ .Input("group_assignment: int32")
+ .Output("output: T")
+ .Attr("T: {bfloat16, float}")
+ .Attr("concat_dimension: int")
+ .Attr("split_dimension: int")
+ .Attr("split_count: int")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ int64 rank;
+ if (c->RankKnown(input)) {
+ rank = c->Rank(input);
+ } else {
+ return errors::InvalidArgument("input's rank is unknown.");
+ }
+ int concat_dimension;
+ int split_dimension;
+
+ TF_RETURN_IF_ERROR(c->GetAttr("concat_dimension", &concat_dimension));
+
+ if (concat_dimension < 0 || concat_dimension >= rank) {
+ return errors::InvalidArgument("concat_dimension ", concat_dimension,
+ " is out of range of input rank ", rank);
+ }
+
+ TF_RETURN_IF_ERROR(c->GetAttr("split_dimension", &split_dimension));
+ if (split_dimension < 0 || split_dimension >= rank) {
+ return errors::InvalidArgument("split_dimension ", split_dimension,
+ " is out of range of input rank ", rank);
+ }
+
+ std::vector<DimensionHandle> dims;
+ dims.resize(rank);
+
+ for (int32 i = 0; i < rank; ++i) {
+ int64 in_idx = i;
+ if (i == concat_dimension) {
+ in_idx = split_dimension;
+ } else if (i == split_dimension) {
+ in_idx = concat_dimension;
+ }
+
+ dims[i] = c->Dim(input, in_idx);
+ }
+
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+An Op to exchange data across TPU replicas. On each replica, the input is
+split into `split_count` blocks along `split_dimension` and send to the other
+replicas given group_assignment. After receiving `split_count` - 1 blocks from
+other replicas, we concatenate the blocks along `concat_dimension` as the
+output.
+
+For example, suppose there are 2 TPU replicas:
+replica 0 receives input: `[[A, B]]`
+replica 1 receives input: `[[C, D]]`
+
+group_assignment=`[[0, 1]]`
+concat_dimension=0
+split_dimension=1
+split_count=2
+
+replica 0's output: `[[A], [C]]`
+replica 1's output: `[[B], [D]]`
+
+input: The local input to the sum.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
+concat_dimension: The dimension number to concatenate.
+split_dimension: The dimension number to split.
+split_count: The number of splits, this number must equal to the sub-group
+ size(group_assignment.get_shape()[1])
+output: The exchanged result.
+T: The type of elements to be exchanged.
+)doc");
REGISTER_OP("CrossReplicaSum")
.Input("input: T")
+ .Input("group_assignment: int32")
.Output("output: T")
.Attr("T: {bfloat16, float}")
- .Attr("group_assignment: list(int) = []")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
-An Op to sum inputs across replicated TPU instances. Each
-instance supplies its own input. If group_assignment is empty, the output of
-each is the sum of all the inputs, otherwise the output of each is the sum of
-the inputs belonging to the same group.
+An Op to sum inputs across replicated TPU instances. Each instance supplies its
+own input.
-For example, suppose there are 4 TPU instances: `[A, B, C, D]`. Passing
-group_assignment=`[0,1,0,1]` sets `A, C` as group 0, and `B, D` as group 1.
-Thus we get the outputs: `[A+C, B+D, A+C, B+D]`.
+For example, suppose there are 8 TPU instances: `[A, B, C, D, E, F, G, H]`.
+Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
+and `B, D, F, H` as group 1. Thus we get the outputs:
+`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
input: The local input to the sum.
+group_assignment: An int32 tensor with shape
+ [num_groups, num_replicas_per_group]. `group_assignment[i]` represents the
+ replica ids in the ith subgroup.
output: The sum of all the distributed inputs.
T: The type of elements to be summed.
-group_assignment: The list of group ids. `group_assignment[i]` represents the
- group id of replica i.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 8e6e9aa0cd..b498599962 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,7 +156,8 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)));
+ stub->NewSession(&context, new_session_request, &new_session_response)))
+ << new_session_response.error_message();
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index 2b13343efa..f88dc51636 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -79,12 +79,15 @@ message StepInfoResult {
// The step duration in picoseconds.
optional uint64 duration_ps = 2;
// The infeed duration in picoseconds.
- // Can turn into a map if we want a variable number of ops.
optional uint64 infeed_duration_ps = 3;
+ // The outfeed duration in picoseconds.
+ optional uint64 host_outfeed_ps = 8;
// The start time of this step in picoseconds.
optional uint64 begin_ps = 4;
// The waiting time within this step in picoseconds.
optional uint64 wait_duration_ps = 5;
+ // The unit b outfeed duration in picoseconds.
+ optional uint64 unit_b_outfeed_ps = 9;
// The time spent on cross-replica-sum in picoseconds.
optional uint64 crs_duration_ps = 6;
// Percentage of unit b time spent on infeed.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index 2cc17d6d92..fc1320501b 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -9,8 +9,8 @@ message ClippingLimits {
google.protobuf.FloatValue upper = 2; // +inf if not set
}
-// Get the learning rate from a <yet to be determined> source that can change
-// dynamically.
+// Get the learning rate from the parameters of the SendTPUEmbeddingGradients
+// op.
message DynamicLearningRate {
}
@@ -119,7 +119,9 @@ message OptimizationParameters {
// Whether to use gradient accumulation (do two passes over the input
// gradients: one to accumulate them into a temporary array and another to
- // apply them using the actual optimization algorithm).
+ // apply them using the actual optimization algorithm). This feature is
+ // experimental -- it has not been fully verified and may cause training
+ // crashes and/or failures.
bool use_gradient_accumulation = 15;
// Optimization algorithm parameters; which field is selected determines which
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index bf442d9116..d92a0652bb 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -21,8 +21,10 @@ from __future__ import print_function
import platform
+from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.platform import tf_logging as logging
if platform.system() != "Windows":
# pylint: disable=wildcard-import,unused-import,g-import-not-at-top
@@ -36,10 +38,85 @@ if platform.system() != "Windows":
_tpu_ops = loader.load_op_library(
resource_loader.get_path_to_datafile("_tpu_ops.so"))
+ def _create_default_group_assignment():
+ num_shards = tpu_function.get_tpu_context().number_of_shards
+ if num_shards is None:
+ logging.warning(
+ "cross_replica_sum should be used within a tpu_shard_context, but "
+ "got unset number_of_shards. Assuming 1.")
+ num_shards = 1
+ group_assignment = [list(range(num_shards))]
+ return group_assignment
+
+ def all_to_all(x,
+ concat_dimension,
+ split_dimension,
+ split_count,
+ group_assignment=None,
+ name=None):
+ """Exchange data across TPU replicas.
+
+ Args:
+ x: The local tensor.
+ concat_dimension: The dimension number to concatenate.
+ split_dimension: The dimension number to split.
+ split_count: The number of splits, this number must equal to the sub-group
+ size(group_assignment.get_shape()[1])
+ group_assignment: Optional 2d int32 lists with shape [num_groups,
+ num_replicas_per_group]. `group_assignment[i]` represents the replica
+ ids in the ith subgroup.
+ name: Optional op name.
+
+ Returns:
+ A `Tensor` which is concatenated by data from different replicas.
+ """
+ if group_assignment is None:
+ group_assignment = _create_default_group_assignment()
+ return gen_tpu_ops.all_to_all(
+ x,
+ group_assignment,
+ concat_dimension=concat_dimension,
+ split_dimension=split_dimension,
+ split_count=split_count,
+ name=name)
+
+ @ops.RegisterGradient("AllToAll")
+ def _all_to_all_grad(op, grad):
+ # The gradient of a all-to-all is also a all-to-all but the
+ # split_dimension and concat_dimension is swapped.
+ # The graident with respect to group_assignment is None.
+ return [
+ gen_tpu_ops.all_to_all(
+ grad,
+ op.inputs[1],
+ concat_dimension=op.get_attr("split_dimension"),
+ split_dimension=op.get_attr("concat_dimension"),
+ split_count=op.get_attr("split_count")), None
+ ]
+
+ def cross_replica_sum(x, group_assignment=None, name=None):
+ """Sum the input tensor accorss replicas according to group_assignment.
+
+ Args:
+ x: The local tensor to the sum.
+ group_assignment: Optional 2d int32 lists with shape [num_groups,
+ num_replicas_per_group]. `group_assignment[i]` represents the replica
+ ids in the ith subgroup.
+ name: Optional op name.
+
+ Returns:
+ A `Tensor` which is summed across replicas.
+ """
+ if group_assignment is None:
+ group_assignment = _create_default_group_assignment()
+
+ return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
+
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
# The gradient of a cross replica sum is also a cross-replica sum.
- return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
+ # The graident with respect to group_assignment is None.
+ return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
# This extra type checking exists to give a more helpful error message in
# the common case that uint8 and int64 values are infed. Remove when both
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index a5e8277ba5..d8c3872363 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -58,29 +58,38 @@ from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_reso
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_function
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
+from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers as keras_optimizers
from tensorflow.python.keras.engine import base_layer
+from tensorflow.python.keras.engine import training_arrays
+from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.layers import embeddings
+from tensorflow.python.keras.utils.generic_utils import make_batches
+from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import tf_inspect
_SESSIONS = {}
@@ -96,9 +105,9 @@ def tpu_session(cluster_resolver):
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
+ logging.info('Connecting to: %s', master)
graph = ops.Graph()
session = tf_session.Session(graph=graph, target=master, config=config)
-
with graph.as_default():
session.run(tpu.initialize_system())
@@ -109,25 +118,94 @@ def tpu_session(cluster_resolver):
def reset_tpu_sessions():
_SESSIONS.clear()
+try:
+ from scipy.sparse import issparse # pylint: disable=g-import-not-at-top
+except ImportError:
+ issparse = None
-# Work-around dependency cycle between DistributionStrategy and TPU lib.
-def TPUDistributionStrategy(tpu_cluster_resolver=None): # pylint: disable=invalid-name
- """Construct a TPUDistributionStrategy."""
- from tensorflow.contrib.distribute.python import tpu_strategy # pylint: disable=g-import-not-at-top
- # TODO -- remove this when TPUStrategy API is consistent (b/112705069)
- if tpu_cluster_resolver is None:
- tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
- args, _, _, _ = tf_inspect.getargspec(tpu_strategy.TPUStrategy.__init__)
- if len(args) == 3:
- logging.info('Detected new TPUStrategy API.')
- return tpu_strategy.TPUStrategy(tpu_cluster_resolver, steps_per_run=1)
- else:
- logging.info('Detected old TPUStrategy API.')
- strategy = tpu_strategy.TPUStrategy(num_cores_per_host=8)
- strategy._tpu_cluster_resolver = tpu_cluster_resolver
+def get_tpu_system_metadata(tpu_cluster_resolver):
+ """Retrieves TPU system metadata given a TPUClusterResolver."""
+ master = tpu_cluster_resolver.master()
+
+ # pylint: disable=protected-access
+ cluster_spec = tpu_cluster_resolver.cluster_spec()
+ cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None
+ tpu_system_metadata = (
+ tpu_system_metadata_lib._query_tpu_system_metadata(
+ master,
+ cluster_def=cluster_def,
+ query_topology=False))
+
+ return tpu_system_metadata
+
+
+class TPUDistributionStrategy(object):
+ """The strategy to run Keras model on TPU."""
- return strategy
+ def __init__(self, tpu_cluster_resolver=None, using_single_core=False):
+ """Construct a TPUDistributionStrategy.
+
+ Args:
+ tpu_cluster_resolver: Any instance of `TPUClusterResolver`. If None, will
+ create one with '' as master address.
+ using_single_core: Bool. This is the debugging option, which might be
+ removed in future once the model replication functionality is mature
+ enough. If `False` (default behavior), the system automatically finds
+ the best configuration, in terms of number of TPU cores, for the model
+ replication, typically using all avaiable TPU cores. If overwrites as
+ `True`, force the model replication using single core, i.e., no
+ replication.
+ """
+
+ if tpu_cluster_resolver is None:
+ tpu_cluster_resolver = tpu_cluster_resolver_lib.TPUClusterResolver('')
+
+ metadata = get_tpu_system_metadata(tpu_cluster_resolver)
+ self._tpu_metadata = metadata
+ self._tpu_cluster_resolver = tpu_cluster_resolver
+ self._num_cores = 1 if using_single_core else metadata.num_cores
+
+ # Walk device list to identify TPU worker for enqueue/dequeue operations.
+ worker_re = re.compile('/job:([^/]+)')
+ for device in metadata.devices:
+ if 'TPU:0' in device.name:
+ self._worker_name = worker_re.search(device.name).group(1)
+ break
+
+ def _make_assignment_for_model(self, cpu_model):
+ """Makes a `TPUAssignment` for the passed in `cpu_model`."""
+ num_cores = self._num_cores
+ if num_cores > 1 and cpu_model.stateful:
+ logging.warning(
+ 'Model replication does not currently support stateful models. '
+ 'Degrading to a single core.')
+ num_cores = 1
+
+ return TPUAssignment(
+ worker_name=self._worker_name, num_cores=num_cores)
+
+
+class TPUAssignment(object):
+ """This is object holding TPU resources assignment for the concrete model.
+
+ `TPUDistributionStrategy` is responsible to create the instance of
+ `TPUAssignment`, so, it can dynamically adjust the `num_cores` to use based on
+ model and input batch sizes.
+ """
+
+ def __init__(self, worker_name, num_cores):
+ self._worker_name = worker_name
+ self._num_cores = num_cores
+
+ @property
+ def worker_name(self):
+ return self._worker_name
+
+ @property
+ def num_towers(self):
+ # TODO(xiejw): Support automatically assign num_cores based on inputs.
+ return self._num_cores
class TPUEmbedding(embeddings.Embedding):
@@ -180,6 +258,8 @@ class KerasCrossShardOptimizer(keras_optimizers.Optimizer):
return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
def set_weights(self, weights):
+ # TODO(power): Figure out whether we really need this given there is no
+ # caller for this API yet.
self._opt.set_weights()
def get_weights(self):
@@ -204,9 +284,9 @@ def _valid_name(tensor_name):
def _replicated_optimizer(opt):
"""Wrap the optimizer `opt` with CrossShardOptimizer if applicable."""
- if tpu_function.get_tpu_context().number_of_shards == 1:
- return opt
-
+ # Always wrap `opt` with CrossShardOptimizer, even if we are running on a
+ # single core. This ensures Keras properly tracks and initializes optimizer
+ # variables.
if isinstance(opt, keras_optimizers.TFOptimizer):
return tpu_optimizer.CrossShardOptimizer(opt.optimizer)
else:
@@ -447,8 +527,8 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_dict[tensor] = value
return infeed_dict
- def __init__(self, distribution_strategy):
- self._strategy = distribution_strategy
+ def __init__(self, tpu_assignment):
+ self._tpu_assignment = tpu_assignment
def _split_tensors(self, inputs):
"""Split input data across shards.
@@ -461,16 +541,16 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
Returns:
List of lists containing the input to feed to each TPU shard.
"""
- if self._strategy.num_towers == 1:
+ if self._tpu_assignment.num_towers == 1:
return [inputs]
batch_size = inputs[0].shape[0]
- assert batch_size % self._strategy.num_towers == 0, (
- 'batch_size must be divisible by strategy.num_towers (%s vs %s)' %
- (batch_size, self._strategy.num_towers))
- shard_size = batch_size // self._strategy.num_towers
+ assert batch_size % self._tpu_assignment.num_towers == 0, (
+ 'batch_size must be divisible by the number of TPU cores in use (%s '
+ 'vs %s)' % (batch_size, self._tpu_assignment.num_towers))
+ shard_size = batch_size // self._tpu_assignment.num_towers
input_list = []
- for index in range(self._strategy.num_towers):
+ for index in range(self._tpu_assignment.num_towers):
shard_inputs = [
x[index * shard_size:(index + 1) * shard_size] for x in inputs
]
@@ -485,8 +565,9 @@ class TPUNumpyInfeedManager(TPUInfeedManager):
infeed_op = []
shard_infeed_tensors = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_tensors = []
with ops.device('/device:TPU:%d' % shard_id):
for spec in input_specs:
@@ -525,30 +606,31 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
# TODO(saeta): Verify tpu_model_op is as expected!
return {}
- def __init__(self, dataset, distribution_strategy, tpu_session):
+ # pylint: disable=redefined-outer-name
+ def __init__(self, dataset, tpu_assignment, tpu_session):
"""Constructs a TPUDatasetInfeedManager.
Must be called within a `KerasTPUModel.tpu_session` context!
Args:
dataset: A `tf.data.Dataset` to infeed.
- distribution_strategy: The `TPUDistributionStrategy` used to configure the
+ tpu_assignment: The `TPUAssignment` used to configure the
Keras TPU model.
tpu_session: The `tf.Session` object used for running the TPU model.
"""
self._verify_dataset_shape(dataset)
self._dataset = dataset
- self._strategy = distribution_strategy
+ self._tpu_assignment = tpu_assignment
dummy_x_shape = dataset.output_shapes[0].as_list()
- dummy_x_shape[0] *= distribution_strategy.num_towers
+ dummy_x_shape[0] *= tpu_assignment.num_towers
dummy_y_shape = dataset.output_shapes[1].as_list()
- dummy_y_shape[0] *= distribution_strategy.num_towers
+ dummy_y_shape[0] *= tpu_assignment.num_towers
self._iterator = dataset.make_initializable_iterator()
tpu_session.run(self._iterator.initializer)
self._get_next_ops = []
ctrl_deps = []
- for i in range(distribution_strategy.num_towers):
+ for i in range(tpu_assignment.num_towers):
with ops.control_dependencies(ctrl_deps): # Ensure deterministic
# TODO(saeta): Ensure correct placement!
get_next_op = self._iterator.get_next()
@@ -612,7 +694,7 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
'currently requires static shapes. The provided '
'dataset only has a partially defined shape. '
'(Dimension %d of output tensor %d is not statically known '
- 'for output shapes: %s.%s)' % (i, j, dataset.output_shapes, hint))
+ 'for output shapes: %s.%s)' % (j, i, dataset.output_shapes, hint))
@property
def dummy_x(self):
@@ -628,10 +710,11 @@ class TPUDatasetInfeedManager(TPUInfeedManager):
def build_infeed_from_input_specs(self, input_specs, execution_mode):
shard_infeed_tensors = self._get_next_ops
- assert len(shard_infeed_tensors) == self._strategy.num_towers
+ assert len(shard_infeed_tensors) == self._tpu_assignment.num_towers
infeed_ops = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
infeed_ops.append(
tpu_ops.infeed_enqueue_tuple(
shard_infeed_tensors[shard_id],
@@ -654,10 +737,10 @@ class TPUFunction(object):
instead of being injected as `feed_dict` items or fetches.
"""
- def __init__(self, model, execution_mode, strategy):
+ def __init__(self, model, execution_mode, tpu_assignment):
self.model = model
self.execution_mode = execution_mode
- self._strategy = strategy
+ self._tpu_assignment = tpu_assignment
self._compilation_cache = {}
self._cloned_model = None
@@ -709,8 +792,8 @@ class TPUFunction(object):
# Clone our CPU model, running within the TPU device context.
with TPURewriteContext(tpu_input_map):
with variable_scope.variable_scope('tpu_model_%s' % id(self.model)):
- # TODO(power): Replicate variables.
- with ops.device('/device:TPU:0'):
+ with keras_tpu_variables.replicated_scope(
+ self._tpu_assignment.num_towers):
self._cloned_model = models.clone_model(self.model)
# Create a copy of the optimizer for this graph.
@@ -780,7 +863,7 @@ class TPUFunction(object):
# `execute op` replicates `_model_fn` `num_replicas` times, with each shard
# running on a different logical core.
compile_op, execute_op = tpu.split_compile_and_replicate(
- _model_fn, inputs=[[]] * self._strategy.num_towers)
+ _model_fn, inputs=[[]] * self._tpu_assignment.num_towers)
# Generate CPU side operations to enqueue features/labels and dequeue
# outputs from the model call.
@@ -788,8 +871,9 @@ class TPUFunction(object):
input_specs, self.execution_mode)
# Build output ops.
outfeed_op = []
- for shard_id in range(self._strategy.num_towers):
- with ops.device('/device:CPU:0'):
+ for shard_id in range(self._tpu_assignment.num_towers):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
outfeed_op.extend(
tpu_ops.outfeed_dequeue_tuple(
dtypes=[spec.dtype for spec in self._outfeed_spec],
@@ -807,7 +891,7 @@ class TPUFunction(object):
def _test_model_compiles(self, tpu_model_ops):
"""Verifies that the given TPUModelOp can be compiled via XLA."""
logging.info('Started compiling')
- start_time = time.clock()
+ start_time = time.time()
result = K.get_session().run(tpu_model_ops.compile_op)
proto = tpu_compilation_result.CompilationResultProto()
@@ -816,38 +900,52 @@ class TPUFunction(object):
raise RuntimeError('Compilation failed: {}'.format(
proto.status_error_message))
- end_time = time.clock()
+ end_time = time.time()
logging.info('Finished compiling. Time elapsed: %s secs',
end_time - start_time)
- def __call__(self, inputs):
- assert isinstance(inputs, list)
+ def _lookup_infeed_manager(self, inputs):
+ """Return an existing manager, or construct a new InfeedManager for inputs.
+
+ _lookup_infeed_manager will return an existing InfeedManager if one has been
+ previously assigned for this model and input. If not, it will construct a
+ new TPUNumpyInfeedManager.
+
+ Args:
+ inputs: A NumPy input to the model.
+
+ Returns:
+ A `TPUInfeedManager` object to manage infeeds for this input.
+ """
+ if inputs is None:
+ return None
- infeed_manager = None
for x, mgr in self.model._numpy_to_infeed_manager_list:
if inputs[0] is x:
- infeed_manager = mgr
- break
- if infeed_manager is None:
- infeed_manager = TPUNumpyInfeedManager(self.model._strategy)
+ return mgr
+ return TPUNumpyInfeedManager(self.model._tpu_assignment)
- # Strip sample weight from inputs
- if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
- self.execution_mode == model_fn_lib.ModeKeys.EVAL):
- input_tensors = self.model._feed_inputs + self.model._feed_targets
- inputs = inputs[:len(input_tensors)]
- else:
- input_tensors = self.model._feed_inputs
+ def _tpu_model_ops_for_input_specs(self, input_specs, infeed_manager):
+ """Looks up the corresponding `TPUModelOp` for a given `input_specs`.
- infeed_instance = infeed_manager.make_infeed_instance(inputs)
- del inputs # To avoid accident usage.
- input_specs = infeed_instance.make_input_specs(input_tensors)
+ It instantiates a new copy of the model for each unique input shape.
+
+ Args:
+ input_specs: The specification of the inputs to train on.
+ infeed_manager: The infeed manager responsible for feeding in data.
+
+ Returns:
+ A `TPUModelOp` instance that can be used to execute a step of the model.
+ """
+ if input_specs is None or infeed_manager is None:
+ # Note: this condition is possible during the prologue or epilogue of the
+ # pipelined loop.
+ return None
# XLA requires every operation in the graph has a fixed shape. To
# handle varying batch sizes we recompile a new sub-graph for each
# unique input shape.
shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
-
if shape_key not in self._compilation_cache:
with self.model.tpu_session():
logging.info('New input shapes; (re-)compiling: mode=%s, %s',
@@ -857,24 +955,47 @@ class TPUFunction(object):
self._compilation_cache[shape_key] = new_tpu_model_ops
self._test_model_compiles(new_tpu_model_ops)
- # Initialize our TPU weights on the first compile.
- self.model._initialize_weights(self._cloned_model)
- tpu_model_ops = self._compilation_cache[shape_key]
+ return self._compilation_cache[shape_key]
- infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
+ def _construct_input_tensors_and_inputs(self, inputs):
+ """Returns input tensors and numpy array inputs corresponding to `inputs`.
- with self.model.tpu_session() as session:
- _, _, outfeed_outputs = session.run([
- tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
- tpu_model_ops.outfeed_op
- ], infeed_dict)
+ Args:
+ inputs: NumPy inputs.
+
+ Returns:
+ A tuple of `input_tensors`, and `inputs`.
+ """
+ if inputs is None:
+ # Note: this condition is possible during the prologue or epilogue of the
+ # pipelined loop.
+ return None, None
+ # Strip sample weight from inputs
+ if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
+ self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ input_tensors = self.model._feed_inputs + self.model._feed_targets
+ inputs = inputs[:len(input_tensors)]
+ return input_tensors, inputs
+ else:
+ input_tensors = self.model._feed_inputs
+ return input_tensors, inputs
+
+ def _process_outputs(self, outfeed_outputs):
+ """Processes the outputs of a model function execution.
- # TODO(xiejw): Decide how to reduce outputs, or just discard all but first.
+ Args:
+ outfeed_outputs: The sharded outputs of the TPU computation.
+
+ Returns:
+ The aggregated outputs of the TPU computation to be used in the rest of
+ the model execution.
+ """
+ # TODO(xiejw): Decide how to reduce outputs, or discard all but first.
if self.execution_mode == model_fn_lib.ModeKeys.PREDICT:
outputs = [[]] * len(self._outfeed_spec)
outputs_per_replica = len(self._outfeed_spec)
- for i in range(self._strategy.num_towers):
+ for i in range(self._tpu_assignment.num_towers):
output_group = outfeed_outputs[i * outputs_per_replica:(i + 1) *
outputs_per_replica]
for j in range(outputs_per_replica):
@@ -882,7 +1003,139 @@ class TPUFunction(object):
return [np.concatenate(group) for group in outputs]
else:
- return outfeed_outputs[:len(outfeed_outputs) // self._strategy.num_towers]
+ return outfeed_outputs[:len(outfeed_outputs) //
+ self._tpu_assignment.num_towers]
+
+ def __call__(self, inputs):
+ """__call__ executes the function on the computational hardware.
+
+ It handles executing infeed, and preprocessing in addition to executing the
+ model on the TPU hardware.
+
+ Note: `__call__` has a sibling method `pipeline_run` which performs the same
+ operations, but with software pipelining.
+
+ Args:
+ inputs: The inputs to use to train.
+
+ Returns:
+ The output of the computation for the given mode it is executed in.
+
+ Raises:
+ RuntimeError: If there is an inappropriate use of the function.
+ """
+ assert isinstance(inputs, list)
+
+ infeed_manager = self._lookup_infeed_manager(inputs)
+ input_tensors, inputs = self._construct_input_tensors_and_inputs(inputs)
+ infeed_instance = infeed_manager.make_infeed_instance(inputs)
+ del inputs # To avoid accident usage.
+ input_specs = infeed_instance.make_input_specs(input_tensors)
+ tpu_model_ops = self._tpu_model_ops_for_input_specs(input_specs,
+ infeed_manager)
+ infeed_dict = infeed_instance.make_feed_dict(tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
+
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ tpu_model_ops.infeed_op, tpu_model_ops.execute_op,
+ tpu_model_ops.outfeed_op
+ ], infeed_dict)
+ return self._process_outputs(outfeed_outputs)
+
+ def pipeline_run(self, cur_step_inputs, next_step_inputs):
+ """pipeline_run executes the function on the computational hardware.
+
+ pipeline_run performs the same computation as __call__, however it runs the
+ infeed in a software pipelined fashion compared to the on-device execution.
+
+ Note: it is the responsibility of the caller to call `pipeline_run` in the
+ following sequence:
+ - Once with `cur_step_inputs=None` and `next_step_inputs=list(...)`
+ - `n` times with `cur_step_inputs` and `next_step_inputs` as `list`s
+ - Once with `cur_step_inputs=list(...)` and `next_step_inputs=None`
+ Additionally, it is the responsibility of the caller to pass
+ `next_step_inputs` as `cur_step_inputs` on the next invocation of
+ `pipeline_run`.
+
+ Args:
+ cur_step_inputs: The current step's inputs.
+ next_step_inputs: The next step's inputs.
+
+ Returns:
+ The output of the computation for the given mode it is executed in.
+
+ Raises:
+ RuntimeError: If there is an inappropriate use of the function.
+ """
+ # Software pipelined case.
+ next_step_infeed_manager = self._lookup_infeed_manager(next_step_inputs)
+ cur_step_infeed_manager = self._lookup_infeed_manager(cur_step_inputs)
+
+ if (next_step_infeed_manager is not None
+ and cur_step_infeed_manager is not None):
+ assert type(next_step_infeed_manager) is type(cur_step_infeed_manager)
+
+ next_input_tensors, next_step_inputs = (
+ self._construct_input_tensors_and_inputs(next_step_inputs))
+ cur_input_tensors, cur_step_inputs = (
+ self._construct_input_tensors_and_inputs(cur_step_inputs))
+
+ cur_infeed_instance = None
+ if cur_step_infeed_manager:
+ cur_infeed_instance = cur_step_infeed_manager.make_infeed_instance(
+ cur_step_inputs)
+ next_infeed_instance = None
+ if next_step_infeed_manager:
+ next_infeed_instance = next_step_infeed_manager.make_infeed_instance(
+ next_step_inputs)
+
+ del cur_step_inputs # Avoid accidental re-use.
+ del next_step_inputs # Avoid accidental re-use.
+
+ cur_tpu_model_ops = None
+ next_tpu_model_ops = None
+ infeed_dict = None
+
+ if cur_infeed_instance and cur_input_tensors and cur_step_infeed_manager:
+ cur_input_specs = cur_infeed_instance.make_input_specs(
+ cur_input_tensors)
+ cur_tpu_model_ops = self._tpu_model_ops_for_input_specs(
+ cur_input_specs, cur_step_infeed_manager)
+
+ if (next_infeed_instance
+ and next_input_tensors
+ and next_step_infeed_manager):
+ next_input_specs = next_infeed_instance.make_input_specs(
+ next_input_tensors)
+ next_tpu_model_ops = self._tpu_model_ops_for_input_specs(
+ next_input_specs, next_step_infeed_manager)
+ infeed_dict = next_infeed_instance.make_feed_dict(next_tpu_model_ops)
+
+ # Initialize our TPU weights on the first compile.
+ self.model._initialize_weights(self._cloned_model)
+
+ if next_tpu_model_ops and cur_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ _, _, outfeed_outputs = session.run([
+ next_tpu_model_ops.infeed_op, cur_tpu_model_ops.execute_op,
+ cur_tpu_model_ops.outfeed_op
+ ], infeed_dict)
+ return self._process_outputs(outfeed_outputs)
+ if cur_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ _, outfeed_outputs = session.run([
+ cur_tpu_model_ops.execute_op, cur_tpu_model_ops.outfeed_op])
+ return self._process_outputs(outfeed_outputs)
+ if next_tpu_model_ops:
+ with self.model.tpu_session() as session:
+ session.run(next_tpu_model_ops.infeed_op, infeed_dict)
+ return None
+ raise RuntimeError('Internal error: both current & next tpu_model_ops '
+ 'were None')
+
class KerasTPUModel(models.Model):
@@ -903,16 +1156,15 @@ class KerasTPUModel(models.Model):
self.predict_function = None
self.test_function = None
self.train_function = None
- self._strategy = strategy
- cluster_resolver = self._strategy._tpu_cluster_resolver
+ cluster_resolver = strategy._tpu_cluster_resolver
self._tpu_name_or_address = cluster_resolver.get_master()
self._cpu_model = cpu_model
+ self._tpu_assignment = strategy._make_assignment_for_model(cpu_model)
self._tpu_model = None
self._tpu_weights_initialized = False
self._session = tpu_session(cluster_resolver)
- self._graph = self._session.graph
# If the input CPU model has already been compiled, compile our TPU model
# immediately.
@@ -931,7 +1183,7 @@ class KerasTPUModel(models.Model):
return {
'cpu_model': self._cpu_model,
'tpu_name_or_address': self._tpu_name_or_address,
- 'strategy': self._strategy,
+ 'tpu_assignment': self._tpu_assignment,
}
def compile(self,
@@ -975,6 +1227,10 @@ class KerasTPUModel(models.Model):
steps_per_epoch=None,
validation_steps=None,
**kwargs):
+ if context.executing_eagerly():
+ raise EnvironmentError('KerasTPUModel currently does not support eager '
+ 'mode.')
+
assert not self._numpy_to_infeed_manager_list # Ensure empty.
infeed_managers = [] # Managers to clean up at the end of the fit call.
@@ -987,7 +1243,8 @@ class KerasTPUModel(models.Model):
'https://github.com/tensorflow/tpu/tree/master/models/experimental'
'/keras')
if callable(x):
- with self.tpu_session() as sess:
+ with self.tpu_session() as sess,\
+ ops.device('/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
dataset = x()
if steps_per_epoch is None:
raise ValueError('When using tf.data as input to a model, you '
@@ -995,7 +1252,8 @@ class KerasTPUModel(models.Model):
if y is not None:
raise ValueError('When using tf.data as input to a model, y must be '
'None')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
x = infeed_manager.dummy_x
@@ -1016,7 +1274,8 @@ class KerasTPUModel(models.Model):
if validation_steps is None:
raise ValueError('When using tf.data as validation for a model, you '
'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(dataset, self._strategy, sess)
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
val_x = infeed_manager.dummy_x
@@ -1026,7 +1285,28 @@ class KerasTPUModel(models.Model):
self._numpy_to_infeed_manager_list = infeed_managers
try:
- return super(KerasTPUModel, self).fit(
+ if not kwargs.get('_pipeline', True):
+ logging.info(
+ 'Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
+ kwargs.pop('_pipeline')
+ return super(KerasTPUModel, self).fit(
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs)
+ return self._pipeline_fit(
x,
y,
batch_size,
@@ -1045,23 +1325,479 @@ class KerasTPUModel(models.Model):
finally:
self._numpy_to_infeed_manager_list = []
+ def evaluate(self,
+ x=None,
+ y=None,
+ batch_size=None,
+ verbose=1,
+ sample_weight=None,
+ steps=None):
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with self.tpu_session() as sess:
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ sess)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).evaluate(
+ x,
+ y,
+ batch_size,
+ verbose,
+ sample_weight,
+ steps)
+ finally:
+ self._numpy_to_infeed_manager_list = []
+
+ def _pipeline_fit(self,
+ x,
+ y,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ validation_split,
+ validation_data,
+ shuffle,
+ class_weight,
+ sample_weight,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps,
+ **kwargs):
+ # Similar to super.fit(...), but modified to support software pipelining.
+
+ # Backwards compatibility
+ if batch_size is None and steps_per_epoch is None:
+ batch_size = 32
+ # Legacy support
+ if 'nb_epoch' in kwargs:
+ logging.warning('The `nb_epoch` argument in `fit` has been renamed '
+ '`epochs`.')
+ epochs = kwargs.pop('nb_epoch')
+ if kwargs:
+ raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
+
+ # Validate and standardize user data
+ x, y, sample_weights = self._standardize_user_data(
+ x,
+ y,
+ sample_weight=sample_weight,
+ class_weight=class_weight,
+ batch_size=batch_size,
+ check_steps=True,
+ steps_name='steps_per_epoch',
+ steps=steps_per_epoch,
+ validation_split=validation_split)
+
+ # Prepare validation data
+ val_x, val_y, val_sample_weights = self._prepare_validation_data(
+ validation_data,
+ validation_split,
+ validation_steps,
+ x,
+ y,
+ sample_weights,
+ batch_size)
+ return self._pipeline_fit_loop(
+ x,
+ y,
+ sample_weights=sample_weights,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ val_inputs=val_x,
+ val_targets=val_y,
+ val_sample_weights=val_sample_weights,
+ shuffle=shuffle,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps)
+
+ def _pipeline_fit_loop(self,
+ inputs,
+ targets,
+ sample_weights,
+ batch_size,
+ epochs,
+ verbose,
+ callbacks,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ shuffle,
+ initial_epoch,
+ steps_per_epoch,
+ validation_steps):
+ self._make_train_function()
+ sample_weights = sample_weights or []
+ val_sample_weights = val_sample_weights or []
+ if self.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = inputs + targets + sample_weights + [1]
+ else:
+ ins = inputs + targets + sample_weights
+
+ do_validation = False
+ if val_inputs:
+ do_validation = True
+ if (steps_per_epoch is None and verbose and inputs and
+ hasattr(inputs[0], 'shape') and hasattr(val_inputs[0], 'shape')):
+ print('Train on %d samples, validate on %d samples' %
+ (inputs[0].shape[0], val_inputs[0].shape[0]))
+
+ if validation_steps:
+ do_validation = True
+ if steps_per_epoch is None:
+ raise ValueError('Can only use `validation_steps` when doing step-wise '
+ 'training, i.e. `steps_per_epoch` must be set.')
+
+ num_training_samples = training_utils.check_num_samples(
+ ins, batch_size, steps_per_epoch, 'steps_per_epoch')
+ count_mode = 'steps' if steps_per_epoch else 'samples'
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ self,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ samples=num_training_samples,
+ validation_steps=validation_steps,
+ verbose=verbose,
+ count_mode=count_mode)
+
+ if num_training_samples is not None:
+ index_array = np.arange(num_training_samples)
+
+ # To prevent a slowdown, we find beforehand the arrays that need conversion.
+ feed = self._feed_inputs + self._feed_targets + self._feed_sample_weights
+ indices_for_conversion_to_dense = []
+ for i in range(len(feed)):
+ if issparse is not None and issparse(ins[i]) and not K.is_sparse(feed[i]):
+ indices_for_conversion_to_dense.append(i)
+
+ callbacks.on_train_begin()
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in self.stateful_metric_functions:
+ m.reset_states()
+ # Update callbacks
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ if steps_per_epoch is not None:
+ # Step-wise fit loop.
+ self._pipeline_fit_loop_step_wise(
+ ins=ins,
+ callbacks=callbacks,
+ steps_per_epoch=steps_per_epoch,
+ epochs=epochs,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ validation_steps=validation_steps,
+ epoch_logs=epoch_logs)
+ else:
+ # Sample-wise fit loop.
+ self._pipeline_fit_loop_sample_wise(
+ ins=ins,
+ callbacks=callbacks,
+ index_array=index_array,
+ shuffle=shuffle,
+ batch_size=batch_size,
+ num_training_samples=num_training_samples,
+ indices_for_conversion_to_dense=indices_for_conversion_to_dense,
+ do_validation=do_validation,
+ val_inputs=val_inputs,
+ val_targets=val_targets,
+ val_sample_weights=val_sample_weights,
+ validation_steps=validation_steps,
+ epoch_logs=epoch_logs)
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
+ if callbacks.model.stop_training:
+ break
+ callbacks.on_train_end()
+ return self.history
+
+ def _pipeline_fit_loop_sample_wise(self,
+ ins,
+ callbacks,
+ index_array,
+ shuffle,
+ batch_size,
+ num_training_samples,
+ indices_for_conversion_to_dense,
+ do_validation,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ validation_steps,
+ epoch_logs):
+ f = self.train_function
+ if shuffle == 'batch':
+ index_array = training_utils.batch_shuffle(index_array, batch_size)
+ elif shuffle:
+ np.random.shuffle(index_array)
+ batches = make_batches(num_training_samples, batch_size)
+
+ ins_last_batch = None
+ last_batch_logs = None
+ batch_index = 0
+
+ for batch_index, (batch_start, batch_end) in enumerate(batches):
+ batch_ids = index_array[batch_start:batch_end]
+ try:
+ if isinstance(ins[-1], int):
+ # Do not slice the training phase flag.
+ ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]]
+ else:
+ ins_batch = slice_arrays(ins, batch_ids)
+ except TypeError:
+ raise TypeError('TypeError while preparing batch. If using HDF5 '
+ 'input data, pass shuffle="batch".')
+
+ # Pipeline batch logs
+ next_batch_logs = {}
+ next_batch_logs['batch'] = batch_index
+ next_batch_logs['size'] = len(batch_ids)
+ if batch_index > 0:
+ # Callbacks operate one step behind in software pipeline.
+ callbacks.on_batch_begin(batch_index - 1, last_batch_logs)
+ for i in indices_for_conversion_to_dense:
+ ins_batch[i] = ins_batch[i].toarray()
+
+ outs = f.pipeline_run(cur_step_inputs=ins_last_batch,
+ next_step_inputs=ins_batch)
+ ins_last_batch = ins_batch
+
+ if batch_index == 0:
+ assert outs is None
+ else:
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ last_batch_logs[l] = o # pylint: disable=unsupported-assignment-operation
+ callbacks.on_batch_end(batch_index - 1, last_batch_logs)
+ if callbacks.model.stop_training:
+ return
+ last_batch_logs = next_batch_logs
+
+ # Final batch
+ callbacks.on_batch_begin(batch_index, last_batch_logs)
+ outs = f.pipeline_run(cur_step_inputs=ins_last_batch, next_step_inputs=None)
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ last_batch_logs[l] = o
+ callbacks.on_batch_end(batch_index, last_batch_logs)
+ if callbacks.model.stop_training:
+ return
+
+ if do_validation:
+ val_outs = training_arrays.test_loop(
+ self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ batch_size=batch_size,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(self.metrics_names, val_outs):
+ epoch_logs['val_' + l] = o
+
+ def _pipeline_fit_loop_step_wise(self,
+ ins,
+ callbacks,
+ steps_per_epoch,
+ epochs,
+ do_validation,
+ val_inputs,
+ val_targets,
+ val_sample_weights,
+ validation_steps,
+ epoch_logs):
+ f = self.train_function
+
+ # Loop prologue
+ try:
+ outs = f.pipeline_run(cur_step_inputs=None, next_step_inputs=ins)
+ assert outs is None # Function shouldn't return anything!
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data on the first step '
+ 'of the epoch, preventing further training. Check to '
+ 'make sure your paths are correct and you have '
+ 'permissions to read the files. Skipping validation')
+
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ if step_index < steps_per_epoch - 1:
+ next_step_inputs = ins
+ else:
+ next_step_inputs = None
+ outs = f.pipeline_run(cur_step_inputs=ins,
+ next_step_inputs=next_step_inputs)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your '
+ 'dataset can generate at least `steps_per_batch * '
+ 'epochs` batches (in this case, %d batches). You '
+ 'may need to use the repeat() function when '
+ 'building your dataset.' % steps_per_epoch * epochs)
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+ for l, o in zip(self.metrics_names, outs):
+ batch_logs[l] = o
+
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+
+ if do_validation:
+ val_outs = training_arrays.test_loop(self,
+ val_inputs,
+ val_targets,
+ sample_weights=val_sample_weights,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(self.metrics_names, val_outs):
+ epoch_logs['val_' + l] = o
+
+ def _prepare_validation_data(self,
+ validation_data,
+ validation_split,
+ validation_steps,
+ x,
+ y,
+ sample_weights,
+ batch_size):
+ """Prepares the validation dataset.
+
+ Args:
+ validation_data: The validation data (if provided)
+ validation_split: The validation split (if provided)
+ validation_steps: The validation steps (if provided)
+ x: The main training data x (if provided)
+ y: The main training data y (if provided)
+ sample_weights: The sample weights (if provided)
+ batch_size: The training batch size (if provided)
+
+ Returns:
+ A 3-tuple of (val_x, val_y, val_sample_weights).
+
+ Raises:
+ ValueError: If the provided arguments are not compatible with
+ `KerasTPUModel`.
+ """
+ # Note: this is similar to a section of $tf/python/keras/engine/training.py
+ # It differns in that tf.data objects are not allowed to be passed directly.
+ # Additionally, it handles validating shapes & types appropriately for use
+ # in TPUs.
+ if validation_data:
+ if (isinstance(validation_data, iterator_ops.Iterator) or
+ isinstance(validation_data, iterator_ops.EagerIterator) or
+ isinstance(validation_data, dataset_ops.Dataset)):
+ raise ValueError('KerasTPUModel cannot handle a Dataset or Iterator '
+ 'for validation_data. Please instead pass a function '
+ 'that returns a `tf.data.Dataset`.')
+ if len(validation_data) == 2:
+ val_x, val_y = validation_data # pylint: disable=unpacking-non-sequence
+ val_sample_weight = None
+ elif len(validation_data) == 3:
+ val_x, val_y, val_sample_weight = validation_data # pylint: disable=unpacking-non-sequence
+ else:
+ raise ValueError('When passing a `validation_data` argument, it must '
+ 'contain either 2 items (x_val, y_val), or 3 items '
+ '(x_val, y_val, val_sample_weights). However we '
+ 'received `validation_data=%s`' % validation_data)
+ val_x, val_y, val_sample_weights = self._standardize_user_data(
+ val_x,
+ val_y,
+ sample_weight=val_sample_weight,
+ batch_size=batch_size,
+ steps=validation_steps)
+ elif validation_split and 0. < validation_split < 1.:
+ if training_utils.has_symbolic_tensors(x):
+ raise ValueError('If your data is in the form of symbolic tensors, you '
+ 'cannot use `validation_split`.')
+ if hasattr(x[0], 'shape'):
+ split_at = int(x[0].shape[0] * (1. - validation_split))
+ else:
+ split_at = int(len(x[0]) * (1. - validation_split))
+
+ x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at))
+ y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at))
+ sample_weights, val_sample_weights = (slice_arrays(
+ sample_weights, 0, split_at), slice_arrays(sample_weights, split_at))
+ elif validation_steps:
+ val_x = []
+ val_y = []
+ val_sample_weights = []
+ else:
+ val_x = None
+ val_y = None
+ val_sample_weights = None
+
+ return val_x, val_y, val_sample_weights
+
def _make_train_function(self):
if not self.train_function:
self.train_function = TPUFunction(
- self, model_fn_lib.ModeKeys.TRAIN, strategy=self._strategy)
+ self,
+ model_fn_lib.ModeKeys.TRAIN,
+ tpu_assignment=self._tpu_assignment)
return self.train_function
def _make_test_function(self):
if not self.test_function:
self.test_function = TPUFunction(
- self, model_fn_lib.ModeKeys.EVAL, strategy=self._strategy)
+ self, model_fn_lib.ModeKeys.EVAL, tpu_assignment=self._tpu_assignment)
return self.test_function
def _make_predict_function(self):
if not self.predict_function:
self.predict_function = TPUFunction(
- self, model_fn_lib.ModeKeys.PREDICT, strategy=self._strategy)
+ self,
+ model_fn_lib.ModeKeys.PREDICT,
+ tpu_assignment=self._tpu_assignment)
return self.predict_function
def _initialize_weights(self, cloned_model):
@@ -1115,7 +1851,7 @@ class KerasTPUModel(models.Model):
@contextlib.contextmanager
def tpu_session(self):
"""Yields a TPU session and sets it as the default Keras session."""
- with self._graph.as_default():
+ with self._session.graph.as_default():
default_session = K.get_session()
# N.B. We have to call `K.set_session()` AND set our session as the
# TF default. `K.get_session()` surprisingly does not return the value
@@ -1133,6 +1869,7 @@ class KerasTPUModel(models.Model):
self._session.close()
+# pylint: disable=bad-continuation
def _validate_shapes(model):
"""Validate that all layers in `model` have constant shape."""
for layer in model.layers:
@@ -1160,10 +1897,13 @@ Layer: %(layer)s
Input shape: %(input_shape)s
Output shape: %(output_shape)s
""" % {
- 'layer': layer,
- 'input_shape': layer.input_shape,
- 'output_shape': layer.output_shape
- })
+ 'layer': layer,
+ 'input_shape': layer.input_shape,
+ 'output_shape': layer.output_shape
+ })
+
+
+# pylint: enable=bad-continuation
@experimental
@@ -1205,5 +1945,10 @@ def tpu_model(model, strategy=None):
if strategy is None:
strategy = TPUDistributionStrategy()
+ else:
+ if not isinstance(strategy, TPUDistributionStrategy):
+ raise TypeError(
+ '`strategy` must have type `tf.contrib.tpu.TPUDistributionStrategy`. '
+ 'Got: {}'.format(type(strategy)))
return KerasTPUModel(cpu_model=model, strategy=strategy)
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
new file mode 100644
index 0000000000..170977d8ab
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -0,0 +1,287 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Distributed variable implementation for TPUs.
+
+N.B. This is an experimental feature that should only be used for Keras support.
+
+It is unsupported and will be removed in favor of Distribution Strategy soon.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import variable_scope
+
+
+@contextlib.contextmanager
+def _handle_graph(handle):
+ with handle.graph.as_default():
+ yield
+
+
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while context is not None and not isinstance(
+ context, control_flow_ops.XLAControlFlowContext):
+ context = context.outer_context
+ return context
+
+
+class ReplicatedVariable(object):
+ """A replicated variable for use on TPUs.
+
+ When accessed inside a tpu.replicate() context, this variable acts as if it
+ is a single variable whose handle is a replicated input to the computation.
+
+ Outside a tpu.replicate() context currently this object has pretty murky
+ semantics, especially with respect to things such as
+ * initialization
+ * colocation.
+ """
+
+ def __init__(self, name, variables):
+ self._name = name
+ self._primary_var = variables[0]
+ self._vars = variables
+ self._cached_value = None
+ self._dtype = variables[0].dtype
+
+ @property
+ def handle(self):
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is None:
+ return self._primary_var.handle
+
+ return tpu_context.get_replicated_var_handle(self)
+
+ @contextlib.contextmanager
+ def _assign_dependencies(self):
+ """Makes assignments depend on the cached value, if any.
+
+ This prevents undefined behavior with reads not ordered wrt writes.
+
+ Yields:
+ None.
+ """
+ if self._cached_value is not None:
+ with ops.control_dependencies([self._cached_value]):
+ yield
+ else:
+ yield
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group([v.initializer for v in self._vars])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def op(self):
+ return self.get().op
+
+ @property
+ def is_tensor_like(self):
+ return True
+
+ def _read_variable_op(self):
+ if _enclosing_tpu_context() is None:
+ return self._primary_var.read_value()
+ v = gen_resource_variable_ops.read_variable_op(self.handle, self._dtype)
+ return v
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def is_initialized(self, name=None):
+ return self._vars[0].is_initialized(name=name)
+
+ def __getitem__(self, *args):
+ return self.read_value().__getitem__(*args)
+
+ def assign(self, value, use_locking=None, name=None, read_value=False):
+ """Assign `value` to all replicas.
+
+ Outside of the tpu.rewrite context, assign explicitly to all replicas.
+ Inside of the tpu.rewrite context, assigns to the local replica.
+
+ Arguments:
+ value: Tensor to assign
+ use_locking: ignored
+ name: ignored
+ read_value: return the value from the assignment
+ Returns:
+ Assignment operation, or new value of the variable if `read_value` is True
+ """
+ del use_locking
+ if _enclosing_tpu_context() is None:
+ assign_ops = []
+ with self._assign_dependencies():
+ for var in self._vars:
+ assign_ops.append(var.assign(value, use_locking=None, name=name))
+
+ if read_value:
+ with ops.control_dependencies(assign_ops):
+ return self.read_value()
+ else:
+ return control_flow_ops.group(assign_ops)
+
+ with _handle_graph(self.handle), self._assign_dependencies():
+ value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
+ assign_op = gen_resource_variable_ops.assign_variable_op(
+ self.handle, value_tensor, name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_op
+
+ def assign_add(self, delta, use_locking=None, name=None, read_value=True):
+ del use_locking
+ with _handle_graph(self.handle), self._assign_dependencies():
+ assign_add_op = gen_resource_variable_ops.assign_add_variable_op(
+ self.handle,
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_add_op
+
+ def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
+ del use_locking
+ with _handle_graph(self.handle), self._assign_dependencies():
+ assign_sub_op = gen_resource_variable_ops.assign_sub_variable_op(
+ self.handle,
+ ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op()
+ return assign_sub_op
+
+ def get(self):
+ return self._primary_var
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._primary_var._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ return NotImplemented
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+def replicated_fetch_function(var):
+ # pylint: disable=protected-access
+ return ([var._dense_var_to_tensor()], lambda v: v[0])
+ # pylint: enable=protected-access
+
+
+ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
+ops.register_dense_tensor_like_type(ReplicatedVariable)
+session_lib.register_session_run_conversion_functions(
+ ReplicatedVariable, replicated_fetch_function)
+
+
+def replicated_scope(num_replicas):
+ """Variable scope for constructing replicated variables."""
+
+ def _replicated_variable_getter(getter, name, *args, **kwargs):
+ """Getter that constructs replicated variables."""
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ variables = []
+ index = {}
+ for i in range(num_replicas):
+ replica_name = "{}/{}".format(name, i)
+ with ops.device("device:TPU:{}".format(i)):
+ v = getter(*args, name=replica_name, **kwargs)
+ variables.append(v)
+ index[i] = v
+ result = ReplicatedVariable(name, variables)
+
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ if v in l:
+ l.remove(v)
+ g.add_to_collections(collections, result)
+
+ return result
+
+ return variable_scope.variable_scope(
+ "", custom_getter=_replicated_variable_getter)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 7fa06d6d56..0f9f7cd91b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -42,9 +42,9 @@ _BLACKLISTED_OPS = set([
"Placeholder",
])
-# These operations will currently fail to compile, but we should be able to
-# support them eventually via CPU offload or extending our operation set.
-_NOT_IMPLEMENTED_OPS = set([
+# XLA doesn't currently support reading of intermediate tensors, thus some ops
+# are not supported.
+_UNSUPPORTED_OPS = set([
"AudioSummary",
"AudioSummaryV2",
"HistogramSummary",
@@ -78,10 +78,10 @@ def initialize_system(embedding_config=None, job=None):
embedding_config: If not None, an `EmbeddingLayerConfiguration` proto
describing the desired configuration of the hardware embedding lookup
tables. If embedding_config is None, no hardware embeddings can be used.
- job: The job (the XXX in TensorFlow device specification /job:XXX)
- that contains the TPU devices that will be initialized. If job=None
- it is assumed there is only one job in the TensorFlow flock, and an
- error will be returned if this assumption does not hold.
+ job: The job (the XXX in TensorFlow device specification /job:XXX) that
+ contains the TPU devices that will be initialized. If job=None it is
+ assumed there is only one job in the TensorFlow flock, and an error will
+ be returned if this assumption does not hold.
Returns:
A serialized `TopologyProto` that describes the TPU system. Note:
the topology must be evaluated using `Session.run` before it can be used.
@@ -118,9 +118,9 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ
is a unique name.
- We use a `ControlFlowContext` to perform the annotation since it
- integrates with Tensorflow constructs like ResourceVariables. For example,
- if a `ResourceVariable` is constructed inside a tpu.replicate() block, the
+ We use a `ControlFlowContext` to perform the annotation since it integrates
+ with Tensorflow constructs like ResourceVariables. For example, if a
+ `ResourceVariable` is constructed inside a tpu.replicate() block, the
`ResourceVariable` implementation can use
`with ops.control_dependencies(None)` to build the variable's definition
outside the replicated computation.
@@ -149,6 +149,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._gradient_colocation_stack = []
self._host_compute_core = []
self._name = name
+ self._name_as_bytes = compat.as_bytes(name)
self._unsupported_ops = []
self._pivot = pivot
self._replicated_vars = {}
@@ -156,8 +157,8 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
def get_replicated_var_handle(self, var):
"""Returns a variable handle for replicated TPU variable 'var'.
- This is an method used by an experimental replicated variable
- implementation and is not intended as a public API.
+ This is a method used by an experimental replicated variable implementation
+ and is not intended as a public API.
Args:
var: The replicated TPU variable.
@@ -210,28 +211,24 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if gradient_uid == "__unsupported__":
raise NotImplementedError(
"No gradient_uid calling gradient within outside_compilation")
- # When we take the gradient of an op X in an
- # outside_compilation cluster C in a forward computation we
- # would like to put the ops corresponding to the gradient of
- # X into a new outside_compilation cluster C'. However, if
- # we take the gradient of X twice, the second one should get
- # yet another new outside_compilation cluster C''.
+ # When we take the gradient of an op X in an outside_compilation
+ # cluster C in a forward computation we would like to put the ops
+ # corresponding to the gradient of X into a new outside_compilation
+ # cluster C'. However, if we take the gradient of X twice, the second
+ # one should get yet another new outside_compilation cluster C''.
#
- # The mechanism we adopt is to use a 'root_cluster' which is
- # the cluster that X was in before we took gradients, and a
- # 'gradient_uid' which is different for every invocation of
- # gradients, and put the gradient of X in cluster
- # 'root_cluster.gradient_uid'.
+ # The mechanism we adopt is to use a 'root_cluster' which is the
+ # cluster that X was in before we took gradients, and a 'gradient_uid'
+ # which is different for every invocation of gradients, and put the
+ # gradient of X in cluster 'root_cluster.gradient_uid'.
#
- # When taking a gradient of a gradient, some ops will be
- # colocated with Op in the forward pass (e.g., cluster
- # root_cluster) and some in the backward pass (e.g., cluster
- # root_cluster.initial_gradient_uid). We need all of the
- # grad-of-grad ops to be in the same cluster to avoid cyclic
- # dependencies between clusters. We adopt a heuristic that
- # puts any op clustered with root_cluster.<xxx> in
- # root_cluster.gradient_uid, even if xxx was
- # initial_gradient_uid.
+ # When taking a gradient of a gradient, some ops will be colocated
+ # with Op in the forward pass (e.g., cluster root_cluster) and some in
+ # the backward pass (e.g., cluster root_cluster.initial_gradient_uid).
+ # We need all of the grad-of-grad ops to be in the same cluster to
+ # avoid cyclic dependencies between clusters. We adopt a heuristic
+ # that puts any op clustered with root_cluster.<xxx> in
+ # root_cluster.gradient_uid, even if xxx was initial_gradient_uid.
self._in_gradient_colocation = op
parts = outside_attr.split(".")
cluster = parts[0] + "." + gradient_uid
@@ -323,16 +320,13 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return self._host_compute_core
def AddOp(self, op):
- self._AddOpInternal(op)
-
- def _AddOpInternal(self, op):
# pylint: disable=protected-access
if op.type in _BLACKLISTED_OPS:
logging.error("Operation of type %s (%s) is not supported on the TPU. "
"Execution will fail if this op is used in the graph. " %
(op.type, op.name))
- if op.type in _NOT_IMPLEMENTED_OPS:
+ if op.type in _UNSUPPORTED_OPS:
self._unsupported_ops.append(op)
if any(x.dtype._is_ref_dtype for x in op.inputs):
@@ -342,7 +336,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if _TPU_REPLICATE_ATTR in op.node_def.attr:
raise ValueError("TPU computations cannot be nested")
op._set_attr(_TPU_REPLICATE_ATTR,
- attr_value_pb2.AttrValue(s=compat.as_bytes(self._name)))
+ attr_value_pb2.AttrValue(s=self._name_as_bytes))
if self._outside_compilation_cluster:
op._set_attr(
_OUTSIDE_COMPILATION_ATTR,
@@ -356,11 +350,12 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# Remove any control edges from outer control flow contexts. These may cause
# mismatched frame errors.
- control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
+ (internal_control_inputs,
+ external_control_inputs) = self._RemoveExternalControlEdges(op)
if not op.inputs:
# Add a control edge from the control pivot to this op.
- if not control_inputs:
+ if not internal_control_inputs:
# pylint: disable=protected-access
op._add_control_input(self.GetControlPivot())
# pylint: enable=protected-access
@@ -371,19 +366,19 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
if real_x != x:
op._update_input(index, real_x) # pylint: disable=protected-access
- if external_inputs:
+ if external_control_inputs:
# Use an identity to pull control inputs as data inputs. Note that we
# ignore ops which don't have outputs. TODO(phawkins): fix that.
with ops.control_dependencies(None):
self.Enter()
- external_inputs = [
+ external_control_inputs = [
array_ops.identity(x.outputs[0]).op
- for x in external_inputs
+ for x in external_control_inputs
if x.outputs
]
self.Exit()
# pylint: disable=protected-access
- op._add_control_inputs(external_inputs)
+ op._add_control_inputs(external_control_inputs)
# pylint: enable=protected-access
# Mark op's outputs as seen by this context and any outer contexts.
@@ -399,6 +394,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._outer_context.AddInnerOp(op)
def AddValue(self, val):
+ """Add `val` to the current context and its outer context recursively."""
if val.name in self._values:
# Use the real value if it comes from outer context.
result = self._external_values.get(val.name)
@@ -415,7 +411,7 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
return result
def AddInnerOp(self, op):
- self._AddOpInternal(op)
+ self.AddOp(op)
if self._outer_context:
self._outer_context.AddInnerOp(op)
@@ -656,13 +652,31 @@ def split_compile_and_replicate(computation,
# TODO(phawkins): consider removing this code. It will
# be less confusing to clients if they knowingly choose to use resource
# variables.
+ # Partitioned variables is not supported (b/112311320).
+ def custom_getter(getter, name, *args, **kwargs):
+ """Variables on TPU have a few restrictions."""
+ partitioner = kwargs["partitioner"]
+ if partitioner is not None:
+ kwargs["partitioner"] = None
+ logging.warning(
+ "Partitioned variables are not supported on TPU. Got "
+ "`partitioner` that is {} for variable {}. "
+ "Setting `partitioner` to `None`."
+ .format(partitioner, name))
+ return getter(name, *args, **kwargs)
+
vscope = variable_scope.get_variable_scope()
+
saved_use_resource = vscope.use_resource
+ saved_custom_getter = vscope.custom_getter
+
vscope.set_use_resource(True)
+ vscope.set_custom_getter(custom_getter)
outputs = computation(*computation_inputs)
vscope.set_use_resource(saved_use_resource)
+ vscope.set_custom_getter(saved_custom_getter)
# If the computation returns `None`, make it an empty tuple.
if outputs is None:
@@ -765,11 +779,10 @@ def shard(computation,
name=None):
"""Shards `computation` for parallel execution.
- `inputs` must be a list of Tensors or None (equivalent to an empty
- list), each of which has a corresponding split axis (from
- `input_shard_axes`). Each input is split into `num_shards` pieces
- along the corresponding axis, and computation is applied to each
- shard in parallel.
+ `inputs` must be a list of Tensors or None (equivalent to an empty list), each
+ of which has a corresponding split axis (from `input_shard_axes`). Each input
+ is split into `num_shards` pieces along the corresponding axis, and
+ computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
@@ -791,10 +804,9 @@ def shard(computation,
Args:
computation: A Python function that builds a computation to apply to each
shard of the input.
- inputs: A list of input tensors or None (equivalent to an empty
- list). Each input tensor has a corresponding shard axes, given
- by `input_shard_axes`, which must have size divisible by
- `num_shards`.
+ inputs: A list of input tensors or None (equivalent to an empty list). Each
+ input tensor has a corresponding shard axes, given by `input_shard_axes`,
+ which must have size divisible by `num_shards`.
num_shards: The number of shards.
input_shard_axes: A list of dimensions along which to shard `inputs`, or
`None`. `None` means "shard all inputs along dimension 0". If not `None`,
@@ -913,9 +925,9 @@ def batch_parallel(computation,
Convenience wrapper around shard().
- `inputs` must be a list of Tensors or None (equivalent to an empty
- list). Each input is split into `num_shards` pieces along the 0-th
- dimension, and computation is applied to each shard in parallel.
+ `inputs` must be a list of Tensors or None (equivalent to an empty list).
+ Each input is split into `num_shards` pieces along the 0-th dimension, and
+ computation is applied to each shard in parallel.
Tensors are broadcast to all shards if they are lexically captured by
`computation`. e.g.,
@@ -933,9 +945,8 @@ def batch_parallel(computation,
Args:
computation: A Python function that builds a computation to apply to each
shard of the input.
- inputs: A list of input tensors or None (equivalent to an empty
- list). The 0-th dimension of each Tensor must have size
- divisible by `num_shards`.
+ inputs: A list of input tensors or None (equivalent to an empty list). The
+ 0-th dimension of each Tensor must have size divisible by `num_shards`.
num_shards: The number of shards.
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
@@ -968,14 +979,14 @@ def rewrite(computation,
"""Rewrites `computation` for execution on a TPU system.
Args:
- computation: A Python function that builds a computation to apply
- to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors.
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors.
- `computation` may return a list of operations and tensors. Tensors must
+ `computation` may return a list of operations and tensors. Tensors must
come before operations in the returned list. The return value of
`rewrite` is a list of tensors corresponding to the tensors from the
- from `computation`.
+ output of `computation`.
All `Operation`s returned from `computation` will be executed when
evaluating any of the returned output tensors.
@@ -1070,12 +1081,12 @@ class _TPUInferenceContext(control_flow_ops.XLAControlFlowContext):
def validate_inference_rewrite_for_variables(graph):
"""Validates whether rewrite_for_inference() 'worked' for variables.
- The rewrite_for_inference() method is supposed to append
- GuaranteeConstOps after ReadVariableOps, but this mechanism works only
- if you are using tf.get_variable() to create and access variables in your
- tpu computation. This validation method can be called immediately after
- calling tpu.rewrite_for_inference() to check whether GuaranteeConstOps
- where added to the graph.
+ The rewrite_for_inference() method is supposed to append GuaranteeConstOps
+ after ReadVariableOps, but this mechanism works only if you are using
+ tf.get_variable() to create and access variables in your tpu computation.
+ This validation method can be called immediately after calling
+ tpu.rewrite_for_inference() to check whether GuaranteeConstOps where added
+ to the graph.
Typical usages:
tpu.validate_inference_rewrite_for_variables(tf.get_default_graph())
@@ -1089,10 +1100,9 @@ def validate_inference_rewrite_for_variables(graph):
"""
if not any([x.type == "GuaranteeConst" for x in graph.get_operations()]):
raise RuntimeError(
- "No GuaranteeConst ops found in the graph after "
- "running tpu.rewrite_for_inference(...). Please "
- "check that you are using tf.get_variable() to "
- "create and access variables in your tpu "
+ "No GuaranteeConst ops found in the graph after running "
+ "tpu.rewrite_for_inference(...). Please check that you are using "
+ "tf.get_variable() to create and access variables in your tpu "
"computation.")
@@ -1108,16 +1118,16 @@ def rewrite_for_inference(computation,
in your computation, it moves the ReadVariableOps outside the TPU
computation, and adds GuaranteeConst ops just after the ReadVariableOps.
This mechanism works only if you are using tf.get_variable() to create and
- access variables in your tpu computation. You can validate whether
- this worked, by calling validate_inference_rewrite_for_variables() method
+ access variables in your tpu computation. You can validate whether this
+ worked, by calling validate_inference_rewrite_for_variables() method
immediately after this method to check whether GuaranteeConstOps where
added to the graph.
Args:
- computation: A Python function that builds a computation to apply
- to the input. If the function takes n inputs, 'inputs' should be
- a list of n tensors. If the function returns m outputs, rewrite
- will return a list of m tensors.
+ computation: A Python function that builds a computation to apply to the
+ input. If the function takes n inputs, 'inputs' should be a list of n
+ tensors. If the function returns m outputs, rewrite will return a list of
+ m tensors.
inputs: A list of input tensors or `None` (equivalent to an empty list).
infeed_queue: If not `None`, the `InfeedQueue` from which to append a tuple
of arguments as inputs to `computation`.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 2e4050bd99..1ff04f5c26 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -804,11 +804,14 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
per_host_sharded_inputs.append(flattened_inputs)
if inputs_structure_recorder.flattened_input_dims:
+ input_partition_dims = inputs_structure_recorder.flattened_input_dims
+ if signals:
+ input_partition_dims += [None] * len(signals)
# pylint: disable=protected-access
infeed_queue = tpu_feed._PartitionedInfeedQueue(
number_of_tuple_elements=len(per_host_sharded_inputs[0]),
host_id=host_id,
- input_partition_dims=inputs_structure_recorder.flattened_input_dims,
+ input_partition_dims=input_partition_dims,
device_assignment=ctx.device_assignment)
per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(
per_host_sharded_inputs)
@@ -2821,8 +2824,6 @@ def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
"""Executes `model_fn_wrapper` multiple times on all TPU shards."""
- num_cores = ctx.num_cores
-
(single_tpu_predict_step, host_calls, captured_scaffold_fn,
captured_predict_hooks
) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)
@@ -2841,7 +2842,7 @@ def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):
(dummy_predict_op,) = tpu.shard(
multi_tpu_predict_steps_on_single_shard,
inputs=[],
- num_shards=num_cores,
+ num_shards=ctx.num_replicas,
outputs_from_all_shards=False,
device_assignment=ctx.device_assignment)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
index 74a675b645..1e11de6421 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
@@ -44,8 +43,9 @@ class CrossShardOptimizer(optimizer.Optimizer):
reduction: The reduction to apply to the shard losses.
name: Optional name prefix for the operations created when applying
gradients. Defaults to "CrossShardOptimizer".
- group_assignment: Optional list of group ids for applying the optimizer
- to subgroups.
+ group_assignment: Optional 2d int32 lists with shape
+ [num_groups, num_replicas_per_group] which describles how to apply
+ optimizer to subgroups.
Raises:
ValueError: If reduction is not a valid cross-shard reduction.
@@ -74,11 +74,22 @@ class CrossShardOptimizer(optimizer.Optimizer):
"""
if not group_assignment:
return None
- if len(group_assignment) != num_shards:
- raise ValueError("The size of group_assignment does not equal to "
- "num_shard({0}). Got group_assignment={1}".format(
- num_shards, self._group_assignment))
- subgroup_size_list = dict(collections.Counter(group_assignment)).values()
+ if not (isinstance(group_assignment, list) and
+ all(isinstance(i, list) for i in group_assignment)):
+ raise ValueError("group_assignment must be a list of list. Got {}".format(
+ group_assignment))
+
+ replica_ids = set()
+ for g in group_assignment:
+ for i in g:
+ replica_ids.add(i)
+
+ if set(range(num_shards)) != replica_ids:
+ raise ValueError("group_assignment must be a permutation of range({0})."
+ " Got group_assignment={1}".format(
+ num_shards, group_assignment))
+
+ subgroup_size_list = [len(group) for group in group_assignment]
if all(subgroup_size_list[0] == size for size in subgroup_size_list):
return subgroup_size_list[0]
else:
diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
index 81278ea82c..afeef978f3 100644
--- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
+++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py
@@ -108,7 +108,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
expected_seq4_batch1, expected_seq4_batch2,
key=None, make_keys_unique=False):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
next_batch = sqss.batch_sequences_with_states(
input_key=key if key is not None else self.key,
input_sequences=self.sequences,
@@ -332,7 +332,7 @@ class BatchSequencesWithStatesTest(test.TestCase):
"seq4": self.sequences["seq4"],
}
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
".*should be a multiple of: 3, but saw "
"value: 4. Consider setting pad=True."):
diff --git a/tensorflow/contrib/training/python/training/bucket_ops_test.py b/tensorflow/contrib/training/python/training/bucket_ops_test.py
index 504f1fcd41..b259e0ee83 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops_test.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops_test.py
@@ -112,7 +112,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(32):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -162,7 +162,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[None], [None, None], [None, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(15):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -204,7 +204,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3], [None, None]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(64):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -286,7 +286,7 @@ class BucketTest(test.TestCase):
self.assertAllEqual(
[[32], [32, None], [32, 3]],
[out.get_shape().as_list() for out in bucketed_dynamic[1]])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for v in range(128):
self.enqueue_inputs(sess, {
self.scalar_int_feed: v,
@@ -405,7 +405,7 @@ class BucketBySequenceLengthTest(test.TestCase):
num_pairs_to_enqueue - (batch_size - 1) * num_buckets,
num_pairs_dequeued)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
# Feed the inputs, then close the input thread.
diff --git a/tensorflow/contrib/training/python/training/evaluation.py b/tensorflow/contrib/training/python/training/evaluation.py
index 01bac891da..16a647bf66 100644
--- a/tensorflow/contrib/training/python/training/evaluation.py
+++ b/tensorflow/contrib/training/python/training/evaluation.py
@@ -296,6 +296,7 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
def begin(self):
if self._replace_summary_op:
+ # This can still remain None if there are no summaries.
self._summary_op = summary.merge_all()
self._global_step = training_util.get_or_create_global_step()
@@ -304,10 +305,12 @@ class SummaryAtEndHook(session_run_hook.SessionRunHook):
self._summary_writer = summary.FileWriterCache.get(self._log_dir)
def end(self, session):
- global_step = training_util.global_step(session, self._global_step)
- summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_op is not None:
+ global_step = training_util.global_step(session, self._global_step)
+ summary_str = session.run(self._summary_op, self._feed_dict)
+ if self._summary_writer:
+ self._summary_writer.add_summary(summary_str, global_step)
if self._summary_writer:
- self._summary_writer.add_summary(summary_str, global_step)
self._summary_writer.flush()
diff --git a/tensorflow/contrib/training/python/training/evaluation_test.py b/tensorflow/contrib/training/python/training/evaluation_test.py
index c36d00e842..ddd135f047 100644
--- a/tensorflow/contrib/training/python/training/evaluation_test.py
+++ b/tensorflow/contrib/training/python/training/evaluation_test.py
@@ -67,7 +67,7 @@ class CheckpointIteratorTest(test.TestCase):
global_step = variables.get_or_create_global_step()
saver = saver_lib.Saver() # Saves the global step.
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib.global_variables_initializer())
save_path = os.path.join(checkpoint_dir, 'model.ckpt')
saver.save(session, save_path, global_step=global_step)
@@ -427,9 +427,11 @@ class EvaluateRepeatedlyTest(test.TestCase):
names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
return names_to_values, names_to_updates
- def _verify_summaries(self, output_dir, names_to_values):
+ def _verify_events(self, output_dir, names_to_values):
"""Verifies that the given `names_to_values` are found in the summaries.
+ Also checks that a GraphDef was written out to the events file.
+
Args:
output_dir: An existing directory where summaries are found.
names_to_values: A dictionary of strings to values.
@@ -440,7 +442,13 @@ class EvaluateRepeatedlyTest(test.TestCase):
self.assertEqual(len(output_filepath), 1)
events = summary_iterator.summary_iterator(output_filepath[0])
- summaries = [e.summary for e in events if e.summary.value]
+ summaries = []
+ graph_def = None
+ for event in events:
+ if event.summary.value:
+ summaries.append(event.summary)
+ elif event.graph_def:
+ graph_def = event.graph_def
values = []
for summary in summaries:
for value in summary.value:
@@ -448,6 +456,7 @@ class EvaluateRepeatedlyTest(test.TestCase):
saved_results = {v.tag: v.simple_value for v in values}
for name in names_to_values:
self.assertAlmostEqual(names_to_values[name], saved_results[name], 5)
+ self.assertIsNotNone(graph_def)
def testSummariesAreFlushedToDisk(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed')
@@ -475,7 +484,23 @@ class EvaluateRepeatedlyTest(test.TestCase):
],
max_number_of_evaluations=1)
- self._verify_summaries(logdir, names_to_values)
+ self._verify_events(logdir, names_to_values)
+
+ def testSummaryAtEndHookWithoutSummaries(self):
+ logdir = os.path.join(self.get_temp_dir(),
+ 'summary_at_end_hook_without_summaires')
+ if gfile.Exists(logdir):
+ gfile.DeleteRecursively(logdir)
+
+ with ops.Graph().as_default():
+ # Purposefully don't add any summaries. The hook will just dump the
+ # GraphDef event.
+ hook = evaluation.SummaryAtEndHook(log_dir=logdir)
+ hook.begin()
+ with self.cached_session() as session:
+ hook.after_create_session(session, None)
+ hook.end(session)
+ self._verify_events(logdir, {})
if __name__ == '__main__':
diff --git a/tensorflow/contrib/training/python/training/resample_test.py b/tensorflow/contrib/training/python/training/resample_test.py
index 774241a816..8665a24883 100644
--- a/tensorflow/contrib/training/python/training/resample_test.py
+++ b/tensorflow/contrib/training/python/training/resample_test.py
@@ -44,7 +44,7 @@ class ResampleTest(test.TestCase):
([3], [0, 0, 0]),
([0, 1, 2, 3], [1, 2, 2, 3, 3, 3]),
]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for inputs, expected in cases:
array_inputs = numpy.array(inputs, dtype=numpy.int32)
actual = sess.run(resample._repeat_range(array_inputs))
@@ -65,7 +65,7 @@ class ResampleTest(test.TestCase):
init = control_flow_ops.group(variables.local_variables_initializer(),
variables.global_variables_initializer())
- with self.test_session() as s:
+ with self.cached_session() as s:
s.run(init) # initialize
# outputs
@@ -112,7 +112,7 @@ class ResampleTest(test.TestCase):
init = control_flow_ops.group(variables.local_variables_initializer(),
variables.global_variables_initializer())
expected_sum_op = math_ops.reduce_sum(vals)
- with self.test_session() as s:
+ with self.cached_session() as s:
s.run(init)
expected_sum = n * s.run(expected_sum_op)
@@ -147,7 +147,7 @@ class ResampleTest(test.TestCase):
resampled = resample.resample_at_rate([vals], rates)
- with self.test_session() as s:
+ with self.cached_session() as s:
rs, = s.run(resampled, {
vals: list(range(count)),
rates: numpy.zeros(
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_test.py b/tensorflow/contrib/training/python/training/sampling_ops_test.py
index bf7fb4fd48..1aeff7dc80 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_test.py
@@ -146,7 +146,7 @@ class StratifiedSampleTest(test.TestCase):
for illegal_label in illegal_labels:
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run([val_tf, lbl_tf],
feed_dict={label_ph: illegal_label,
@@ -154,7 +154,7 @@ class StratifiedSampleTest(test.TestCase):
for illegal_prob in illegal_probs:
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run([prob_tf],
feed_dict={label_ph: valid_labels,
@@ -172,7 +172,7 @@ class StratifiedSampleTest(test.TestCase):
summary_op = logging_ops.merge_summary(
ops.get_collection(ops.GraphKeys.SUMMARIES))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -197,7 +197,7 @@ class StratifiedSampleTest(test.TestCase):
batch_size,
init_probs=[0, .3, 0, .7, 0],
enqueue_many=True)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -228,7 +228,7 @@ class StratifiedSampleTest(test.TestCase):
# Run graph to make sure there are no shape-related runtime errors.
for vals, labels in legal_input_pairs:
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([val_tf, labels_tf],
feed_dict={vals_ph: vals,
labels_ph: labels})
@@ -253,7 +253,7 @@ class StratifiedSampleTest(test.TestCase):
self.assertEqual(len(val_list), len(val_input_batch))
self.assertTrue(isinstance(lbls, ops.Tensor))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
@@ -283,7 +283,7 @@ class StratifiedSampleTest(test.TestCase):
# Run session and keep track of how frequently the labels and values appear.
data_l = []
label_l = []
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Need to initialize variables that keep running total of classes seen.
variables.global_variables_initializer().run()
@@ -374,7 +374,7 @@ class RejectionSampleTest(test.TestCase):
'rejection_sample/prob_with_checks:0')
# Run session that should fail.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
for illegal_prob in [-0.1, 1.1]:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(prob_tensor, feed_dict={prob_ph: illegal_prob})
@@ -393,7 +393,7 @@ class RejectionSampleTest(test.TestCase):
sample = sampling_ops.rejection_sample(tensor_list, accept_prob_fn,
batch_size)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(coord=coord)
diff --git a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
index ca78c0029e..73ad859ab3 100644
--- a/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
+++ b/tensorflow/contrib/training/python/training/sampling_ops_threading_test.py
@@ -59,7 +59,7 @@ class SamplingOpsThreadingTest(test.TestCase):
out_tensor = queue.dequeue()
# Run the multi-threaded session.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Need to initialize variables that keep running total of classes seen.
variables.global_variables_initializer().run()
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
index 7aebd9d9fe..8932b905c9 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver_test.py
@@ -36,7 +36,7 @@ from tensorflow.python.platform import test
class SequenceQueueingStateSaverTest(test.TestCase):
def testSequenceInputWrapper(self):
- with self.test_session():
+ with self.cached_session():
length = 3
key = "key"
padded_length = 4
@@ -54,7 +54,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertTrue(isinstance(input_wrapper.context["context1"], ops.Tensor))
def testStateSaverWithTwoSimpleSteps(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_value = 2
batch_size = constant_op.constant(batch_size_value)
num_unroll = 2
@@ -159,7 +159,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertEqual(0, state_saver.barrier.ready_size().eval())
def testStateSaverFailsIfPaddedLengthIsNotMultipleOfNumUnroll(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(32)
num_unroll = 17
bad_padded_length = 3
@@ -194,7 +194,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
})
def _testStateSaverFailsIfCapacityTooSmall(self, batch_size):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
num_unroll = 2
length = array_ops.placeholder(dtypes.int32)
key = array_ops.placeholder(dtypes.string)
@@ -243,7 +243,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self._testStateSaverFailsIfCapacityTooSmall(batch_size)
def testStateSaverFailsIfInconsistentPaddedLength(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(32)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -282,7 +282,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
def testStateSaverFailsIfInconsistentWriteState(self):
# TODO(b/26910386): Identify why this infrequently causes timeouts.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(1)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -326,7 +326,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
def testStateSaverWithManyInputsReadWriteThread(self):
batch_size_value = 32
num_proc_threads = 100
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = constant_op.constant(batch_size_value)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
@@ -490,7 +490,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertGreater(processed_count[0], 2 * 20 * batch_size_value)
def testStateSaverProcessesExamplesInOrder(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size_value = 32
batch_size = constant_op.constant(batch_size_value)
num_unroll = 17
@@ -563,7 +563,7 @@ class SequenceQueueingStateSaverTest(test.TestCase):
self.assertEqual(get_ready_size.eval(), 0)
def testStateSaverCanHandleVariableBatchsize(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
batch_size = array_ops.placeholder(dtypes.int32)
num_unroll = 17
length = array_ops.placeholder(dtypes.int32)
diff --git a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
index 4a46e9a49e..3269d5fef2 100644
--- a/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
+++ b/tensorflow/contrib/training/python/training/sgdr_learning_rate_decay_test.py
@@ -62,7 +62,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def get_sgdr_values(self, lr, initial_period_steps, t_mul, iters):
"""Get an array with learning rate values from the consecutive steps
using current tensorflow implementation."""
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
decay = sgdr_decay(lr, step, initial_period_steps, t_mul)
@@ -76,7 +76,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
"""Compare values generated by tensorflow implementation to the values
generated by the original implementation
(https://github.com/loshchil/SGDR/blob/master/SGDR_WRNs.py)."""
- with self.test_session():
+ with self.cached_session():
lr = 10.0
init_steps = 2
t_mul = 3
@@ -92,7 +92,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def testMDecay(self):
"""Test m_mul argument. Check values for learning rate at the beginning
of the first, second, third and fourth period. """
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
lr = 0.1
@@ -121,7 +121,7 @@ class SGDRDecayTest(test_util.TensorFlowTestCase):
def testCos(self):
"""Check learning rate values at the beginning, in the middle
and at the end of the period."""
- with self.test_session():
+ with self.cached_session():
step = placeholder(dtypes.int32)
lr = 0.2
t_e = 1000
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index df0a186f4f..d9b0511a98 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -79,7 +79,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0, 0, 0]], sess.run(value))
value_1, _ = sess.run([value, enqueue_negative])
self.assertAllEqual([[1, 0, 0]], value_1)
@@ -101,7 +101,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual([0], sess.run(value))
value_1, _ = sess.run([value, enqueue_negative])
self.assertEqual([1], value_1)
@@ -126,7 +126,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]],
array_ops.expand_dims(
value[0], axis=0))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_0, _ = sess.run([value, enqueue_negative])
self.assertAllEqual([0, 1], value_0)
value_1, _ = sess.run([value, enqueue_zeroth])
@@ -147,7 +147,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i)
for i in range(1000)
]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value_0, _ = sess.run((value, enqueue_many_more))
self.assertEqual([0], value_0)
rest = []
@@ -174,7 +174,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()
enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
i = 0
while i < 4:
received, _ = sess.run((value, enqueue))
@@ -199,7 +199,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
batch_size=1, padded_shapes=[2]))
iterator = dataset.make_one_shot_iterator()
_, value = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(
r"Incompatible input shapes at component 0 between "
r"input dataset this dataset: \[3\] vs. \[2\]"):
@@ -224,7 +224,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
np.array(
[[1]], dtype=np.int32))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with self.assertRaisesOpError(
"mismatched number of tensors. Queue expects 1 tensors but "
"tried to insert 2"):
@@ -274,7 +274,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
with ops.control_dependencies([enqueue_rest_op]):
calc = array_ops.identity(value_head)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc))
self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc))
self.assertAllEqual([[6, 6]], sess.run(calc))
@@ -304,7 +304,7 @@ class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
iterator = dataset.make_one_shot_iterator()
_, (unused_count, padded_value) = iterator.get_next()
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]],
sess.run(padded_value))
self.assertAllEqual([[6] * 6], sess.run(padded_value))
diff --git a/tensorflow/contrib/training/python/training/training_test.py b/tensorflow/contrib/training/python/training/training_test.py
index 94cf7788b2..3b524ac8c7 100644
--- a/tensorflow/contrib/training/python/training/training_test.py
+++ b/tensorflow/contrib/training/python/training/training_test.py
@@ -62,7 +62,7 @@ class ClipGradsTest(test.TestCase):
clipped_gradients_to_variables = training.clip_gradient_norms(
gradients_to_variables, 3.0)
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib2.global_variables_initializer())
self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())
@@ -75,7 +75,7 @@ class ClipGradsTest(test.TestCase):
clipped_gradients_to_variables = training.clip_gradient_norms_fn(3.0)(
gradients_to_variables)
- with self.test_session() as session:
+ with self.cached_session() as session:
session.run(variables_lib2.global_variables_initializer())
self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())
@@ -122,7 +122,7 @@ class CreateTrainOpTest(test.TestCase):
moving_variance = variables_lib.get_variables_by_name('moving_variance')[
0]
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
mean, variance = session.run([moving_mean, moving_variance])
@@ -155,7 +155,7 @@ class CreateTrainOpTest(test.TestCase):
moving_variance = variables_lib.get_variables_by_name('moving_variance')[
0]
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
mean, variance = session.run([moving_mean, moving_variance])
@@ -186,7 +186,7 @@ class CreateTrainOpTest(test.TestCase):
global_step = variables_lib.get_or_create_global_step()
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
@@ -209,7 +209,7 @@ class CreateTrainOpTest(test.TestCase):
global_step = variables_lib.get_or_create_global_step()
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize all variables
session.run(variables_lib2.global_variables_initializer())
@@ -535,7 +535,7 @@ class TrainTest(test.TestCase):
train_biases = training.create_train_op(
total_loss, optimizer, variables_to_train=[biases])
- with self.test_session() as session:
+ with self.cached_session() as session:
# Initialize the variables.
session.run(variables_lib2.global_variables_initializer())
diff --git a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
index ad3dce1784..d4951b156c 100644
--- a/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
+++ b/tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
@@ -63,7 +63,7 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
}
CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
- string key(std::move(parsed.FullKey().ToString()));
+ string key(parsed.FullKey());
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
Device* dst_dev;