aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/BUILD11
-rw-r--r--tensorflow/contrib/batching/BUILD58
-rw-r--r--tensorflow/contrib/batching/basic_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/serial_device_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/shared_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/test_util/BUILD19
-rw-r--r--tensorflow/contrib/batching/test_util/fake_clock_env.h21
-rw-r--r--tensorflow/contrib/batching/util/BUILD28
-rw-r--r--tensorflow/contrib/batching/util/periodic_function.h20
-rw-r--r--tensorflow/contrib/bigtable/README.md4
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py4
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py1
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt15
-rw-r--r--tensorflow/contrib/cmake/external/jemalloc.cmake50
-rw-r--r--tensorflow/contrib/cmake/external/protobuf.cmake2
-rw-r--r--tensorflow/contrib/cmake/make.bat38
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt3
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake23
-rw-r--r--tensorflow/contrib/data/README.md18
-rw-r--r--tensorflow/contrib/data/__init__.py11
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD560
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py226
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py987
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/bucketing_test.py824
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py632
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py71
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py148
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py76
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py79
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py811
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py125
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py359
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py281
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/BUILD164
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py65
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py103
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py225
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py85
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py223
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py183
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py58
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py109
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py851
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py948
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py78
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py1083
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py353
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py)40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/resample_test.py182
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py172
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/BUILD555
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py83
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py253
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py49
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py73
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py95
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py692
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py71
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py45
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py122
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py61
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py57
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py46
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py83
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py88
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py140
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py66
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py101
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py139
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py50
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py39
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py118
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py46
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py129
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py85
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py39
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py148
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py53
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py106
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py53
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py99
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py51
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py40
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py54
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py115
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py590
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py95
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py253
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py71
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py91
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py83
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py527
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py118
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD170
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py549
-rw-r--r--tensorflow/contrib/data/python/ops/counter.py13
-rw-r--r--tensorflow/contrib/data/python/ops/enumerate_ops.py15
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py37
-rw-r--r--tensorflow/contrib/data/python/ops/get_single_element.py29
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py441
-rw-r--r--tensorflow/contrib/data/python/ops/indexed_dataset_ops.py177
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py149
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py167
-rw-r--r--tensorflow/contrib/data/python/ops/map_defun.py56
-rw-r--r--tensorflow/contrib/data/python/ops/optimization.py171
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py107
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py486
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py34
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py674
-rw-r--r--tensorflow/contrib/data/python/ops/resampling.py260
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py137
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py56
-rw-r--r--tensorflow/contrib/data/python/ops/stats_ops.py201
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py88
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py43
-rw-r--r--tensorflow/contrib/data/python/ops/writers.py40
-rw-r--r--tensorflow/contrib/distribute/python/BUILD36
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py20
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py336
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py121
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py5
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py17
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py2
-rw-r--r--tensorflow/contrib/distribute/python/moving_averages_test.py141
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py17
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py22
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py2
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py6
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py40
-rw-r--r--tensorflow/contrib/distribute/python/values.py40
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py8
-rw-r--r--tensorflow/contrib/distributions/BUILD2
-rw-r--r--tensorflow/contrib/eager/python/datasets.py4
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb2
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py10
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py12
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py3
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py5
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head.py67
-rw-r--r--tensorflow/contrib/estimator/python/estimator/multi_head_test.py75
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py2
-rw-r--r--tensorflow/contrib/factorization/BUILD9
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops.py14
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py16
-rw-r--r--tensorflow/contrib/fused_conv/BUILD2
-rw-r--r--tensorflow/contrib/ignite/BUILD139
-rw-r--r--tensorflow/contrib/ignite/README.md167
-rw-r--r--tensorflow/contrib/ignite/__init__.py42
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc334
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h81
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h126
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_client.h84
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset.cc81
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset.h63
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc422
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h99
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc198
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client.h43
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc123
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc142
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc151
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h51
-rw-r--r--tensorflow/contrib/ignite/ops/dataset_ops.cc56
-rw-r--r--tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py772
-rw-r--r--tensorflow/contrib/ignite/python/ops/ignite_op_loader.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py)25
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/bin/start-plain.sh24
-rw-r--r--tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml39
-rw-r--r--tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py82
-rw-r--r--tensorflow/contrib/ignite/python/tests/sql/init.sql20
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/start_ignite.sh22
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/stop_ignite.sh19
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h7
-rw-r--r--tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc1
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py9
-rw-r--r--tensorflow/contrib/lite/BUILD26
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h16
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data_test.cc2
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc34
-rw-r--r--tensorflow/contrib/lite/delegates/flex/BUILD10
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.cc9
-rw-r--r--tensorflow/contrib/lite/experimental/micro/BUILD76
-rw-r--r--tensorflow/contrib/lite/experimental/micro/README.md114
-rw-r--r--tensorflow/contrib/lite/experimental/micro/compatibility.h32
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD31
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc55
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc1672
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h27
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/BUILD107
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc43
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h34
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc208
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc406
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc184
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc643
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc213
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc220
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h170
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc78
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h34
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc (renamed from tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h)16
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc310
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_interpreter.h71
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc197
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc80
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h46
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc83
-rw-r--r--tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc149
-rw-r--r--tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h51
-rw-r--r--tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc144
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/BUILD17
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill21
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc36
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl67
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/micro_test.h138
-rwxr-xr-xtensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh54
-rwxr-xr-xtensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh39
-rw-r--r--tensorflow/contrib/lite/experimental/micro/tools/make/Makefile166
-rwxr-xr-xtensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh73
-rw-r--r--tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc65
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md10
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md18
-rw-r--r--tensorflow/contrib/lite/interpreter.h15
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc6
-rw-r--r--tensorflow/contrib/lite/java/BUILD95
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl5
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD61
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD5
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java77
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml27
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml3
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java68
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java152
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java (renamed from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java)12
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java10
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java142
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java91
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java184
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java160
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java2
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java6
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java149
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/testdata/BUILD5
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt91
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java26
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java27
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java20
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc22
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h24
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc50
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/tensor_jni.h17
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java46
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java14
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java13
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD32
-rw-r--r--tensorflow/contrib/lite/kernels/activations.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc435
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc186
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc86
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc56
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h23
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc598
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h184
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h6
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc300
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_eval.cc912
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_eval.h79
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc235
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc158
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc310
-rw-r--r--tensorflow/contrib/lite/model.cc35
-rw-r--r--tensorflow/contrib/lite/model_flex_test.cc45
-rw-r--r--tensorflow/contrib/lite/model_test.cc22
-rw-r--r--tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h8
-rw-r--r--tensorflow/contrib/lite/profiling/profile_summarizer_test.cc6
-rw-r--r--tensorflow/contrib/lite/profiling/profiler_test.cc4
-rw-r--r--tensorflow/contrib/lite/python/convert.py8
-rw-r--r--tensorflow/contrib/lite/python/interpreter.py17
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc19
-rw-r--r--tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h1
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs12
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h243
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add_flex.binbin0 -> 1052 bytes
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py14
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py81
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py38
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc14
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h9
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc18
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc241
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h27
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc125
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc32
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD24
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc12
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h6
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py2
-rw-r--r--tensorflow/contrib/model_pruning/README.md1
-rw-r--r--tensorflow/contrib/opt/BUILD5
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo.py16
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py8
-rw-r--r--tensorflow/contrib/optimizer_v2/BUILD11
-rw-r--r--tensorflow/contrib/optimizer_v2/adadelta.py75
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad.py79
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad_test.py3
-rw-r--r--tensorflow/contrib/optimizer_v2/adam.py129
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py68
-rw-r--r--tensorflow/contrib/optimizer_v2/gradient_descent.py40
-rw-r--r--tensorflow/contrib/optimizer_v2/momentum.py69
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py1209
-rw-r--r--tensorflow/contrib/optimizer_v2/rmsprop.py154
-rw-r--r--tensorflow/contrib/quantize/BUILD1
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops.py28
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py115
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py37
-rw-r--r--tensorflow/contrib/rnn/BUILD2
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py137
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py245
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py40
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py48
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax.py27
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py28
-rw-r--r--tensorflow/contrib/stateless/BUILD13
-rw-r--r--tensorflow/contrib/stateless/__init__.py8
-rw-r--r--tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py154
-rw-r--r--tensorflow/contrib/stateless/python/stateless_ops.py214
-rw-r--r--tensorflow/contrib/tensorrt/BUILD20
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/BUILD7
-rw-r--r--tensorflow/contrib/tpu/BUILD3
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc96
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc3
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto6
-rw-r--r--tensorflow/contrib/tpu/profiler/version.h2
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto6
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py148
-rw-r--r--tensorflow/contrib/tpu/python/tpu/datasets.py4
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py285
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py14
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py37
-rw-r--r--tensorflow/contrib/tpu/tpu_estimator.md2
-rw-r--r--tensorflow/contrib/training/BUILD2
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py2
364 files changed, 16677 insertions, 24831 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 98dff965a9..fa06d351d4 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -29,6 +29,7 @@ py_library(
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/coder:coder_py",
"//tensorflow/contrib/compiler:compiler_py",
+ "//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/autograph",
"//tensorflow/contrib/constrained_optimization",
"//tensorflow/contrib/copy_graph:copy_graph_py",
@@ -123,6 +124,11 @@ py_library(
"//tensorflow/contrib/tensorrt:init_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
+ }) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite",
+ ],
+ "//conditions:default": [],
}),
)
@@ -184,5 +190,10 @@ cc_library(
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
"//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
],
+ }) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ "//conditions:default": [],
}),
)
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index b27a19b16c..648f3ebb05 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -7,64 +7,6 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-cc_library(
- name = "batch_scheduler_hdrs",
- hdrs = ["batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
- ],
-)
-
-cc_library(
- name = "batch_scheduler",
- hdrs = ["batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:batch_scheduler",
- ],
-)
-
-cc_library(
- name = "shared_batch_scheduler_hdrs",
- hdrs = ["shared_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs",
- ],
-)
-
-cc_library(
- name = "shared_batch_scheduler",
- hdrs = ["shared_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:shared_batch_scheduler",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "adaptive_shared_batch_scheduler",
- hdrs = ["adaptive_shared_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
- ],
-)
-
-cc_library(
- name = "serial_device_batch_scheduler",
- hdrs = ["serial_device_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler",
- ],
-)
-
-cc_library(
- name = "basic_batch_scheduler",
- hdrs = ["basic_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:basic_batch_scheduler",
- ],
-)
-
load(
"//tensorflow:tensorflow.bzl",
"py_test",
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h
deleted file mode 100644
index d9b37da693..0000000000
--- a/tensorflow/contrib/batching/basic_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
deleted file mode 100644
index 8e94e1fd8b..0000000000
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h
deleted file mode 100644
index bf6b708361..0000000000
--- a/tensorflow/contrib/batching/serial_device_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h
deleted file mode 100644
index 83a59695d7..0000000000
--- a/tensorflow/contrib/batching/shared_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD
deleted file mode 100644
index 7cb2d8079b..0000000000
--- a/tensorflow/contrib/batching/test_util/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-# Description: Utilities to aid testing.
-
-package(
- default_visibility = ["//tensorflow:internal"],
-)
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-cc_library(
- name = "fake_clock_env",
- testonly = 1,
- hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core/kernels/batching_util:fake_clock_env",
- ],
-)
diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h
deleted file mode 100644
index 40a39a5569..0000000000
--- a/tensorflow/contrib/batching/test_util/fake_clock_env.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
-#define TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
-
-#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD
deleted file mode 100644
index 8f81b6702f..0000000000
--- a/tensorflow/contrib/batching/util/BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-# Description: Utilities.
-
-package(
- default_visibility = ["//tensorflow:internal"],
-)
-
-licenses(["notice"]) # Apache 2.0
-
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-cc_library(
- name = "periodic_function_dynamic",
- hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
- "//third_party/eigen3",
- ],
-)
-
-cc_library(
- name = "periodic_function",
- visibility = ["//visibility:public"],
- deps = [
- ":periodic_function_dynamic",
- "//tensorflow/core/kernels/batching_util:periodic_function",
- ],
-)
diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h
deleted file mode 100644
index aa2ed0a385..0000000000
--- a/tensorflow/contrib/batching/util/periodic_function.h
+++ /dev/null
@@ -1,20 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-
-#include "tensorflow/core/kernels/batching_util/periodic_function.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
index f33eaf7e3d..2c44abed5e 100644
--- a/tensorflow/contrib/bigtable/README.md
+++ b/tensorflow/contrib/bigtable/README.md
@@ -203,7 +203,7 @@ def interleave_fn(index):
start = tf.string_join(['training_data_', start_idx_str])
end = tf.string_join(['training_data_', end_idx_str])
return table.scan_range(start_idx, end_idx, columns=columns)
-ds = ds.apply(tf.contrib.data.parallel_interleave(
+ds = ds.apply(tf.data.experimental.parallel_interleave(
interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1))
```
@@ -249,7 +249,7 @@ def make_row_key_dataset():
- ...
- fake-data-23498103
"""
- counter_dataset = tf.contrib.data.Counter()
+ counter_dataset = tf.data.experimental.Counter()
width = 8
row_key_prefix = 'fake-data-'
ds = counter_dataset.map(lambda index: tf.as_string(index,
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index cf56822ff4..7c87b0daeb 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -31,8 +31,8 @@ from six import iteritems
from six import string_types
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.util import loader
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -228,7 +228,7 @@ class BigtableTable(object):
"""Retrieves a sampling of row keys from the Bigtable table.
This dataset is most often used in conjunction with
- `tf.contrib.data.parallel_interleave` to construct a set of ranges for
+ `tf.data.experimental.parallel_interleave` to construct a set of ranges for
scanning in parallel.
Returns:
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1056894f18..f4a8e16c99 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -60,6 +60,7 @@ class TPUClusterResolver(ClusterResolver):
if (self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('/bns')) or
+ self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('grpc://'))):
return False
return True
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index c6d6f04168..60f53b8b75 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -30,7 +30,6 @@ endif()
option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
-option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF)
option(tensorflow_BUILD_CC_EXAMPLE "Build the C++ tutorial example" ON)
option(tensorflow_BUILD_PYTHON_BINDINGS "Build the Python bindings" ON)
option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
@@ -218,10 +217,6 @@ if (tensorflow_WIN_CPU_SIMD_OPTIONS)
endif()
endif()
-if (tensorflow_ENABLE_JEMALLOC_SUPPORT)
- add_definitions(-DTENSORFLOW_USE_JEMALLOC -DJEMALLOC_EXPORT=)
-endif()
-
# External dependencies
include(zlib)
include(gif)
@@ -329,12 +324,6 @@ if(tensorflow_ENABLE_GRPC_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl)
endif()
endif()
-if(tensorflow_ENABLE_JEMALLOC_SUPPORT)
- include(jemalloc)
- list(APPEND tensorflow_EXTERNAL_LIBRARIES ${jemalloc_STATIC_LIBRARIES})
- list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc)
- include_directories(${jemalloc_INCLUDE_DIRS})
-endif()
if(tensorflow_ENABLE_SNAPPY_SUPPORT)
include(snappy)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES})
@@ -363,9 +352,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
- else (tensorflow_ENABLE_MKLDNN_SUPPORT)
- add_definitions(-DINTEL_MKL_ML_ONLY)
- endif()
+ endif(tensorflow_ENABLE_MKLDNN_SUPPORT)
endif (tensorflow_ENABLE_MKL_SUPPORT)
if (tensorflow_ENABLE_GPU)
diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake
deleted file mode 100644
index afadcc007d..0000000000
--- a/tensorflow/contrib/cmake/external/jemalloc.cmake
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-include (ExternalProject)
-
-set(jemalloc_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include)
-set(jemalloc_URL https://mirror.bazel.build/github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz)
-set(jemalloc_HASH SHA256=f9be9a05fe906deb5c1c8ca818071a7d2e27d66fd87f5ba9a7bf3750bcedeaf0)
-set(jemalloc_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc)
-
-if (WIN32)
- set(jemalloc_INCLUDE_DIRS
- ${jemalloc_INCLUDE_DIRS}
- ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include/msvc_compat
- )
- if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
- set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib)
- else()
- set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/jemalloc.lib)
- endif()
-else()
- set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.a)
-endif()
-
-ExternalProject_Add(jemalloc
- PREFIX jemalloc
- URL ${jemalloc_URL}
- URL_HASH ${jemalloc_HASH}
- DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
- BUILD_IN_SOURCE 1
- BUILD_BYPRODUCTS ${jemalloc_STATIC_LIBRARIES}
- BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc
- INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step."
- CMAKE_CACHE_ARGS
- -DCMAKE_BUILD_TYPE:STRING=Release
- -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
- -Dwith-jemalloc-prefix:STRING=jemalloc_
- -Dwithout-export:BOOL=ON
-)
diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake
index f56fb35a0f..56a57a2340 100644
--- a/tensorflow/contrib/cmake/external/protobuf.cmake
+++ b/tensorflow/contrib/cmake/external/protobuf.cmake
@@ -16,7 +16,7 @@ include (ExternalProject)
set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)
set(PROTOBUF_URL https://github.com/google/protobuf.git)
-set(PROTOBUF_TAG v3.6.0)
+set(PROTOBUF_TAG v3.6.1)
if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
diff --git a/tensorflow/contrib/cmake/make.bat b/tensorflow/contrib/cmake/make.bat
new file mode 100644
index 0000000000..d52b24e01d
--- /dev/null
+++ b/tensorflow/contrib/cmake/make.bat
@@ -0,0 +1,38 @@
+%echo off
+
+cd /d %~dp0
+
+if exist _build rd /s /q _build
+
+mkdir _build
+chdir _build
+
+
+rem cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install
+
+CALL :NORMALIZEPATH "..\..\..\.."
+SET SOURCE_DIR=%RETVAL%
+
+echo %SOURCE_DIR%
+
+SET SOURCE_DIR=F:\frameworks\tensorflow\
+
+CALL :NORMALIZEPATH "../../../tools/git/gen_git_source.py"
+SET SOURCE_PYTHON_SCRIPT=%RETVAL%
+
+CALL :NORMALIZEPATH "../../../core/util/version_info.cc"
+SET SOURCE_VERSION_CC=%RETVAL%
+
+python %SOURCE_PYTHON_SCRIPT% --raw_generate %SOURCE_VERSION_CC% --source_dir %SOURCE_DIR% --git_tag_override=
+
+cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install
+
+EXIT /B
+
+:NORMALIZEPATH
+ SET RETVAL=%~dpfn1
+ EXIT /B
+
+
+
+ \ No newline at end of file
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 2975b167ec..6e72670142 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -134,7 +134,6 @@ tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
-tensorflow/contrib/data/python/kernel_tests/serialization
tensorflow/contrib/data/python/ops
tensorflow/contrib/decision_trees
tensorflow/contrib/decision_trees/proto
@@ -206,6 +205,8 @@ tensorflow/contrib/integrate/python
tensorflow/contrib/integrate/python/ops
tensorflow/contrib/kafka/python
tensorflow/contrib/kafka/python/ops
+tensorflow/contrib/ignite/python
+tensorflow/contrib/ignite/python/ops
tensorflow/contrib/keras
tensorflow/contrib/keras/api
tensorflow/contrib/keras/api/keras
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 067c299a71..7e806685b8 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -258,14 +258,21 @@ add_dependencies(tf_core_lib ${tensorflow_EXTERNAL_DEPENDENCIES} tf_protos_cc)
# force_rebuild always runs forcing ${VERSION_INFO_CC} target to run
# ${VERSION_INFO_CC} would cache, but it depends on a phony never produced
# target.
-set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc)
-add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC})
-add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo)
-add_custom_command(OUTPUT
- ${VERSION_INFO_CC}
- COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py
- ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE}
- DEPENDS __force_rebuild)
+# This code forces rebuild every time, not needed as version from git is fetched only once
+# move to make.bat which mimicks make.sh
+
+if (NOT WIN32)
+
+ set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc)
+ add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC})
+ add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo)
+ add_custom_command(OUTPUT
+ ${VERSION_INFO_CC}
+ COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py
+ ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE}
+ DEPENDS __force_rebuild)
+endif()
+
set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc)
########################################################
diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md
index 848782e8d8..90be7a66ca 100644
--- a/tensorflow/contrib/data/README.md
+++ b/tensorflow/contrib/data/README.md
@@ -1,10 +1,12 @@
`tf.contrib.data` API
=====================
-NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead.
-We are continuing to support existing code using the `tf.contrib.data` APIs in
-the current version of TensorFlow, but will eventually remove support. The
-`tf.data` APIs are subject to backwards compatibility guarantees.
+NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead,
+or `tf.data.experimental` for the experimental transformations previously hosted
+in this module. We are continuing to support existing code using the
+`tf.contrib.data` APIs in the current version of TensorFlow, but will eventually
+remove support. The non-experimental `tf.data` APIs are subject to backwards
+compatibility guarantees.
Porting your code to `tf.data`
------------------------------
@@ -25,13 +27,13 @@ instead apply them using `Dataset.apply()` transformation. The full list of
changes is as follows:
* `dataset.dense_to_sparse_batch(...)` is now
- `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`.
+ `dataset.apply(tf.data.experimental.dense_to_sparse_batch(...)`.
* `dataset.enumerate(...)` is now
- `dataset.apply(tf.contrib.data.enumerate_dataset(...))`.
+ `dataset.apply(tf.data.experimental.enumerate_dataset(...))`.
* `dataset.group_by_window(...)` is now
- `dataset.apply(tf.contrib.data.group_by_window(...))`.
+ `dataset.apply(tf.data.experimental.group_by_window(...))`.
* `dataset.ignore_errors()` is now
- `dataset.apply(tf.contrib.data.ignore_errors())`.
+ `dataset.apply(tf.data.experimental.ignore_errors())`.
* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`.
The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 3cb51279c3..c3d3e981fa 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -96,10 +96,6 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
-
-# Optimization constant that can be used to enable auto-tuning.
-from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
-
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -114,11 +110,12 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
-from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
-from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
-from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 33784afa3f..42f538b4ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -8,51 +8,17 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "batch_dataset_op_test",
- size = "medium",
- srcs = ["batch_dataset_op_test.py"],
+ name = "assert_element_shape_test",
+ srcs = ["assert_element_shape_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss", # (b/79552534)
- "no_pip",
- ],
deps = [
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "bucketing_test",
- size = "medium",
- srcs = ["bucketing_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/kernel_tests:test_base",
@@ -62,147 +28,6 @@ py_test(
)
py_test(
- name = "csv_dataset_op_test",
- size = "medium",
- srcs = ["csv_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "dataset_constructor_op_test",
- size = "medium",
- srcs = ["dataset_constructor_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "manual",
- "nomac", # b/62040583
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_test(
- name = "directed_interleave_dataset_test",
- size = "medium",
- srcs = ["directed_interleave_dataset_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:random_seed",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "get_single_element_test",
- size = "small",
- srcs = ["get_single_element_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:get_single_element",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "indexed_dataset_ops_test",
- srcs = ["indexed_dataset_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "interleave_dataset_op_test",
- size = "medium",
- srcs = ["interleave_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "notap",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator:estimator_py",
- ],
-)
-
-py_test(
name = "lmdb_dataset_op_test",
size = "medium",
srcs = ["lmdb_dataset_op_test.py"],
@@ -229,252 +54,18 @@ py_test(
)
py_test(
- name = "map_dataset_op_test",
- size = "medium",
- srcs = ["map_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "noasan", # times out
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "filter_dataset_op_test",
- size = "medium",
- srcs = ["filter_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "map_defun_op_test",
+ name = "reduce_dataset_test",
size = "small",
- srcs = ["map_defun_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ srcs = ["reduce_dataset_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:map_defun",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- ],
-)
-
-py_test(
- name = "parsing_ops_test",
- size = "small",
- srcs = ["parsing_ops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:parsing_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_test",
- size = "small",
- srcs = ["prefetching_ops_test.py"],
- additional_deps = [
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
- tags = ["no_windows_gpu"],
-)
-
-py_test(
- name = "range_dataset_op_test",
- size = "small",
- srcs = ["range_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:counter",
- "//tensorflow/contrib/data/python/ops:enumerate_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "reader_dataset_ops_test_base",
- testonly = 1,
- srcs = [
- "reader_dataset_ops_test_base.py",
- ],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/data/python/ops:get_single_element",
+ "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "reader_dataset_ops_test",
- size = "medium",
- srcs = ["reader_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resample_test",
- size = "medium",
- srcs = ["resample_test.py"],
- shard_count = 2,
- srcs_version = "PY2AND3",
- tags = [
- "noasan",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:resampling",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:util",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "scan_dataset_op_test",
- size = "small",
- srcs = ["scan_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "shuffle_dataset_op_test",
- size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
@@ -496,142 +87,3 @@ py_test(
"@absl_py//absl/testing:parameterized",
],
)
-
-py_library(
- name = "sql_dataset_op_test_base",
- srcs = ["sql_dataset_op_test_base.py"],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/kernel_tests:test_base",
- "@org_sqlite//:python",
- ],
-)
-
-py_test(
- name = "sql_dataset_op_test",
- size = "small",
- srcs = ["sql_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":sql_dataset_op_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- ],
-)
-
-py_test(
- name = "stats_dataset_ops_test",
- size = "medium",
- srcs = ["stats_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- ":stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "stats_dataset_test_base",
- srcs = ["stats_dataset_test_base.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/kernel_tests:test_base",
- ],
-)
-
-py_test(
- name = "threadpool_dataset_ops_test",
- size = "small",
- srcs = ["threadpool_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:threadpool",
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:script_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "unique_dataset_op_test",
- size = "small",
- srcs = ["unique_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "window_dataset_op_test",
- size = "medium",
- srcs = ["window_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "writer_ops_test",
- size = "small",
- srcs = ["writer_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:writers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:util",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
new file mode 100644
index 0000000000..0456463a19
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
@@ -0,0 +1,226 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class AssertElementShapeTest(test_base.DatasetTestBase):
+
+ def test_assert_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (
+ tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
deleted file mode 100644
index fed7de5f2b..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ /dev/null
@@ -1,987 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
-
- def testDenseToSparseBatchDataset(self):
- components = np.random.randint(12, size=(100,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start + 4] for _ in range(c)],
- results.values)
- self.assertAllEqual([min(4,
- len(components) - start), 12],
- results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithUnknownShape(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x, x], x)).apply(
- batching.dense_to_sparse_batch(
- 4, [5, None])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j, z]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)
- for z in range(c)], results.indices)
- self.assertAllEqual([
- c
- for c in components[start:start + 4] for _ in range(c)
- for _ in range(c)
- ], results.values)
- self.assertAllEqual([
- min(4,
- len(components) - start), 5,
- np.max(components[start:start + 4])
- ], results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithInvalidShape(self):
- input_tensor = array_ops.constant([[1]])
- with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
-
- def testDenseToSparseBatchDatasetShapeErrors(self):
- input_tensor = array_ops.placeholder(dtypes.int32)
- iterator = (
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Initialize with an input tensor of incompatible rank.
- sess.run(init_op, feed_dict={input_tensor: [[1]]})
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "incompatible with the row shape"):
- sess.run(get_next)
-
- # Initialize with an input tensor that is larger than `row_shape`.
- sess.run(init_op, feed_dict={input_tensor: range(13)})
- with self.assertRaisesRegexp(errors.DataLossError,
- "larger than the row shape"):
- sess.run(get_next)
-
- def testUnbatchScalarDataset(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = (dtypes.int32,) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i,) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithStrings(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
- expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors(st)
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- st_row = sess.run(next_element)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchDatasetWithDenseAndSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- dense_elem, st_row = sess.run(next_element)
- self.assertEqual(i, dense_elem)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchSingleElementTupleDataset(self):
- data = tuple([(math_ops.range(10),) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32,),) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i,),) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchMultiElementTupleDataset(self):
- data = tuple([(math_ops.range(10 * i, 10 * i + 10),
- array_ops.fill([10], "hi")) for i in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32, dtypes.string),) * 3
- data = data.batch(2)
- self.assertAllEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertAllEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
- sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchEmpty(self):
- data = dataset_ops.Dataset.from_tensors(
- (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
- constant_op.constant([], shape=[0, 4, 0])))
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchStaticShapeMismatch(self):
- data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
- np.arange(9)))
- with self.assertRaises(ValueError):
- data.apply(batching.unbatch())
-
- def testUnbatchDynamicShapeMismatch(self):
- ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
- ph2 = array_ops.placeholder(dtypes.int32, shape=None)
- data = dataset_ops.Dataset.from_tensors((ph1, ph2))
- data = data.apply(batching.unbatch())
- iterator = data.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- # Mismatch in the 0th dimension.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: np.arange(8).astype(np.int32)
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- # No 0th dimension (i.e. scalar value) for one component.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: 7
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- def testBatchAndDropRemainder(self):
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size))
- .make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component, result_component in zip(components, result):
- for j in range(test_batch_size):
- self.assertAllEqual(component[(i * test_batch_size + j)],
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testBatchAndDropRemainderSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(12).map(_sparse).apply(
- batching.batch_and_drop_remainder(5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchAndDropRemainder(self):
- els = []
- for length in [3, 6, 9, 4, 12, 10, 2]:
- els.append((np.array(length), np.arange(length) + 1,
- np.array(length * 2)))
-
- dataset = dataset_ops.Dataset.from_tensors(els[0])
- for el in els[1:]:
- dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(
- batching.padded_batch_and_drop_remainder(
- batch_size, ([], [None], []))).make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component_idx, result_component in enumerate(result):
- for j in range(test_batch_size):
- data_idx = i * test_batch_size + j
- comp = result_component[j]
- unpadded = comp[comp > 0]
- if np.isscalar(comp):
- # The boolean mask indexing above adds a dim back. Rm it.
- unpadded = unpadded[0]
- self.assertAllEqual(els[data_idx][component_idx], unpadded)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPaddedBatchAndDropRemainderSparseError(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- with self.assertRaises(TypeError):
- _ = dataset_ops.Dataset.range(10).map(_map_fn).apply(
- batching.padded_batch_and_drop_remainder(5))
-
- def testBatchAndDropRemainderShapeInference(self):
- components = (array_ops.placeholder(dtypes.int32),
- (array_ops.placeholder(dtypes.int32, shape=[None]),
- array_ops.placeholder(dtypes.int32, shape=[20, 30])))
-
- # Test with a statically known batch size.
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(128)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([128], dataset.output_shapes[1][0].as_list())
- self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())
-
- # Test with a dynamic batch size: the static shape will be unknown, because
- # `batch_size` is a placeholder.
- batch_size = array_ops.placeholder(dtypes.int64)
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([None], dataset.output_shapes[1][0].as_list())
- self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
-
- @parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
- )
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
- """Test a dataset that maps a TF function across its input elements."""
- # The pipeline is TensorSliceDataset ->
- # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Empty batch should be an initialization time error.
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
-
- @parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
- )
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
- dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]),
- batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
- if drop_remainder:
- self.assertEqual([4, 1], iterator.output_shapes.as_list())
- else:
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- if not drop_remainder:
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (dataset_ops.Dataset.range(10)
- .apply(batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]), 4))
- .make_one_shot_iterator())
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(50000)
- .apply(batching.map_and_batch(lambda x: x, batch_size=100))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(5):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100, drop_remainder=True))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(4):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMapAndBatchFails(self):
- """Test a dataset that maps a TF function across its input elements."""
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.check_numerics(
- constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
- sess.run(init_op, feed_dict={batch_size: 14})
-
- def testMapAndBatchShapeMismatch(self):
- """Test a dataset that maps a TF function across its input elements."""
-
- def generator():
- yield [1]
- yield [2]
- yield [3]
- yield [[4, 5, 6]]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator, output_types=dtypes.int32)
- batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "number of elements does not match"):
- sess.run(get_next)
-
- def testMapAndBatchImplicitDispose(self):
- # Tests whether a map and batch dataset will be cleaned up correctly when
- # the pipeline does not run it until exhaustion.
- # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
- # MapAndBatchDataset(f=square_3, batch_size=100).
- components = (np.arange(1000),
- np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
- np.array(37.0) * np.arange(1000))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
- 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
- dataset = dataset.prefetch(5)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for _ in range(3):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
- )
- def testMapAndBatchOutOfRangeError(self, threshold):
-
- def raising_py_fn(i):
- if i >= threshold:
- raise StopIteration()
- else:
- return i
-
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(threshold // 10):
- self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
- if threshold % 10 != 0:
- self.assertAllEqual(
- [threshold // 10 * 10 + j for j in range(threshold % 10)],
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
- )
- def testMapAndBatchTypes(self, element, dtype):
- def gen():
- yield element
-
- dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
- batching.map_and_batch(lambda x: x, batch_size=10))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- for _ in range(10):
- self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-
-
-class RestructuredDatasetTest(test_base.DatasetTestBase):
-
- def test_assert_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(expected_shapes, dataset.output_shapes)
-
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def test_assert_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
- tensor_shape.TensorShape((None, 4))) # Partial shape
- result = dataset.apply(
- batching.assert_element_shape(partial_expected_shape))
- # Partial shapes are merged with actual shapes:
- actual_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(actual_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
-class UnbatchDatasetBenchmark(test.Benchmark):
-
- def benchmarkNativeUnbatch(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.apply(batching.unbatch())
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (native) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_native_batch_size_%d" %
- batch_size)
-
- # Include a benchmark of the previous `unbatch()` implementation that uses
- # a composition of more primitive ops. Eventually we'd hope to generate code
- # that is as good in both cases.
- def benchmarkOldUnbatchImplementation(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (unfused) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
- batch_size)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
deleted file mode 100644
index ae401f786c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ /dev/null
@@ -1,824 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import random
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class GroupByReducerTest(test_base.DatasetTestBase):
-
- def checkResults(self, dataset, shapes, values):
- self.assertEqual(shapes, dataset.output_shapes)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- for expected in values:
- got = sess.run(get_next)
- self.assertEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSum(self):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(lambda x: x % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testAverage(self):
-
- def reduce_fn(x, y):
- return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
- x[1] + 1), x[1] + 1
-
- reducer = grouping.Reducer(
- init_func=lambda _: (0.0, 0.0),
- reduce_func=reduce_fn,
- finalize_func=lambda x, _: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(
- lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
-
- def testConcat(self):
- components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
- reducer = grouping.Reducer(
- init_func=lambda x: "",
- reduce_func=lambda x, y: x + y[0],
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensor_slices(components),
- dataset_ops.Dataset.range(2 * i))).apply(
- grouping.group_by_reducer(lambda x, y: y % 2, reducer))
- self.checkResults(
- dataset,
- shapes=tensor_shape.scalar(),
- values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
-
- def testSparseSum(self):
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1], dtype=np.int64)),
- dense_shape=np.array([1, 1]))
-
- reducer = grouping.Reducer(
- init_func=lambda _: _sparse(np.int64(0)),
- reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
- finalize_func=lambda x: x.values[0])
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
- grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testChangingStateShape(self):
-
- def reduce_fn(x, _):
- # Statically known rank, but dynamic length.
- larger_dim = array_ops.concat([x[0], x[0]], 0)
- # Statically unknown rank.
- larger_rank = array_ops.expand_dims(x[1], 0)
- return larger_dim, larger_rank
-
- reducer = grouping.Reducer(
- init_func=lambda x: ([0], 1),
- reduce_func=reduce_fn,
- finalize_func=lambda x, y: (x, y))
-
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
- grouping.group_by_reducer(lambda x: x, reducer))
- self.assertEqual([None], dataset.output_shapes[0].as_list())
- self.assertIs(None, dataset.output_shapes[1].ndims)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual([0] * (2**i), x)
- self.assertAllEqual(np.array(1, ndmin=i), y)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testTypeMismatch(self):
- reducer = grouping.Reducer(
- init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
- reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The element types for the new state must match the initial state."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64(0), reducer))
-
- # TODO(b/78665031): Remove once non-scalar keys are supported.
- def testInvalidKeyShape(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
-
- # TODO(b/78665031): Remove once non-int64 keys are supported.
- def testInvalidKeyType(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: "wrong", reducer))
-
- def testTuple(self):
- def init_fn(_):
- return np.array([], dtype=np.int64), np.int64(0)
-
- def reduce_fn(state, value):
- s1, s2 = state
- v1, v2 = value
- return array_ops.concat([s1, [v1]], 0), s2 + v2
-
- def finalize_fn(s1, s2):
- return s1, s2
-
- reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
- grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual(x, np.asarray([x for x in range(10)]))
- self.assertEqual(y, 45)
-
-
-class GroupByWindowTest(test_base.DatasetTestBase):
-
- def testSimple(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
- .apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- result = sess.run(get_next)
- self.assertTrue(
- all(x % 2 == 0
- for x in result) or all(x % 2 == 1)
- for x in result)
- counts.append(result.shape[0])
-
- self.assertEqual(len(components), sum(counts))
- num_full_batches = len([c for c in counts if c == 4])
- self.assertGreaterEqual(num_full_batches, 24)
- self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
-
- def testImmediateOutput(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- # The input is infinite, so this test demonstrates that:
- # 1. We produce output without having to consume the entire input,
- # 2. Different buckets can produce output at different rates, and
- # 3. For deterministic input, the output is deterministic.
- for _ in range(3):
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
-
- def testSmallGroups(self):
- components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- # The small outputs at the end are deterministically produced in key
- # order.
- self.assertAllEqual([0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1], sess.run(get_next))
-
- def testEmpty(self):
- iterator = (
- dataset_ops.Dataset.range(4).apply(
- grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Window size must be greater than zero, but got 0."):
- print(sess.run(get_next))
-
- def testReduceFuncError(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
-
- def reduce_func(_, xs):
- # Introduce an incorrect padded shape that cannot (currently) be
- # detected at graph construction time.
- return xs.padded_batch(
- 4,
- padded_shapes=(tensor_shape.TensorShape([]),
- constant_op.constant([5], dtype=dtypes.int64) * -1))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
- grouping.group_by_window(lambda x, _: x % 2, reduce_func,
- 32)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def testConsumeWindowDatasetMoreThanOnce(self):
- components = np.random.randint(50, size=(200,)).astype(np.int64)
-
- def reduce_func(key, window):
- # Apply two different kinds of padding to the input: tight
- # padding, and quantized (to a multiple of 10) padding.
- return dataset_ops.Dataset.zip((
- window.padded_batch(
- 4, padded_shapes=tensor_shape.TensorShape([None])),
- window.padded_batch(
- 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
- ))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
- .apply(grouping.group_by_window(
- lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
- reduce_func, 4))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- tight_result, multiple_of_10_result = sess.run(get_next)
- self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
- self.assertAllEqual(tight_result,
- multiple_of_10_result[:, :tight_result.shape[1]])
- counts.append(tight_result.shape[0])
- self.assertEqual(len(components), sum(counts))
-
-
-# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
-# Currently, they use a constant batch size, though should be made to use a
-# different batch size per key.
-class BucketTest(test_base.DatasetTestBase):
-
- def _dynamicPad(self, bucket, window, window_size):
- # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
- # generic form of padded_batch that pads every component
- # dynamically and does not rely on static shape information about
- # the arguments.
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
- [None]), tensor_shape.TensorShape([3])))))
-
- def testSingleBucket(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: 0,
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- which_bucket, bucketed_values = sess.run(get_next)
-
- self.assertEqual(0, which_bucket)
-
- expected_scalar_int = np.arange(32, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
- for i in range(32):
- expected_unk_int64[i, :i] = i
- expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values[2])
-
- def testEvenOddBuckets(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches (one containing even values, one containing odds)
- which_bucket_even, bucketed_values_even = sess.run(get_next)
- which_bucket_odd, bucketed_values_odd = sess.run(get_next)
-
- # Count number of bucket_tensors.
- self.assertEqual(3, len(bucketed_values_even))
- self.assertEqual(3, len(bucketed_values_odd))
-
- # Ensure bucket 0 was used for all minibatch entries.
- self.assertAllEqual(0, which_bucket_even)
- self.assertAllEqual(1, which_bucket_odd)
-
- # Test the first bucket outputted, the events starting at 0
- expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i] = 2 * i
- expected_vec3_str = np.vstack(
- 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
-
- # Test the second bucket outputted, the odds starting at 1
- expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
- expected_vec3_str = np.vstack(
- 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
-
- def testEvenOddBucketsFilterOutAllOdd(self):
-
- def _map_fn(v):
- return {
- "x": v,
- "y": array_ops.fill([v], v),
- "z": array_ops.fill([3], string_ops.as_string(v))
- }
-
- def _dynamic_pad_fn(bucket, window, _):
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, {
- "x": tensor_shape.TensorShape([]),
- "y": tensor_shape.TensorShape([None]),
- "z": tensor_shape.TensorShape([3])
- })))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
- .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
- lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches ([0, 2, ...] and [64, 66, ...])
- which_bucket0, bucketed_values_even0 = sess.run(get_next)
- which_bucket1, bucketed_values_even1 = sess.run(get_next)
-
- # Ensure that bucket 1 was completely filtered out
- self.assertAllEqual(0, which_bucket0)
- self.assertAllEqual(0, which_bucket1)
- self.assertAllEqual(
- np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
- self.assertAllEqual(
- np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
-
- def testDynamicWindowSize(self):
- components = np.arange(100).astype(np.int64)
-
- # Key fn: even/odd
- # Reduce fn: batches of 5
- # Window size fn: even=5, odd=10
-
- def window_size_func(key):
- window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
- return window_sizes[key]
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
- None, window_size_func))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- batches = 0
- while True:
- result = sess.run(get_next)
- is_even = all(x % 2 == 0 for x in result)
- is_odd = all(x % 2 == 1 for x in result)
- self.assertTrue(is_even or is_odd)
- expected_batch_size = 5 if is_even else 10
- self.assertEqual(expected_batch_size, result.shape[0])
- batches += 1
-
- self.assertEqual(batches, 15)
-
-
-def _element_length_fn(x, y=None):
- del y
- return array_ops.shape(x)[0]
-
-
-def _to_sparse_tensor(record):
- return sparse_tensor.SparseTensor(**record)
-
-
-def _format_record(array, sparse):
- if sparse:
- return {
- "values": array,
- "indices": [[i] for i in range(len(array))],
- "dense_shape": (len(array),)
- }
- return array
-
-
-def _get_record_type(sparse):
- if sparse:
- return {
- "values": dtypes.int64,
- "indices": dtypes.int64,
- "dense_shape": dtypes.int64
- }
- return dtypes.int32
-
-
-def _get_record_shape(sparse):
- if sparse:
- return {
- "values": tensor_shape.TensorShape([None,]),
- "indices": tensor_shape.TensorShape([None, 1]),
- "dense_shape": tensor_shape.TensorShape([1,])
- }
- return tensor_shape.TensorShape([None])
-
-
-class BucketBySequenceLength(test_base.DatasetTestBase):
-
- def testBucket(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25, 35]
-
- def build_dataset(sparse):
- def _generator():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- record_len = length - 1
- for _ in range(batch_size):
- elements.append([1] * record_len)
- record_len = length
- random.shuffle(elements)
- for el in elements:
- yield (_format_record(el, sparse),)
- dataset = dataset_ops.Dataset.from_generator(
- _generator,
- (_get_record_type(sparse),),
- (_get_record_shape(sparse),))
- if sparse:
- dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
- return dataset
-
- def _test_bucket_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(
- grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- batch_sizes,
- no_padding=no_padding))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- shape = batch.dense_shape if no_padding else batch.shape
- batch_size = shape[0]
- length = shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- sum_check = batch.values.sum() if no_padding else batch.sum()
- self.assertEqual(sum_check, batch_size * length - 1)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
-
- for no_padding in (True, False):
- _test_bucket_by_padding(no_padding)
-
- def testPadToBoundary(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25]
-
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes[:-1], lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
- for _ in range(batch_sizes[-1]):
- el = [1] * (boundaries[-1] + 5)
- yield (el,)
-
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(3):
- batches.append(sess.run(batch))
- with self.assertRaisesOpError("bucket_boundaries"):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- batch_sizes = batch_sizes[:-1]
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
- sorted(lengths_val))
-
- def testPadToBoundaryNoExtraneousPadding(self):
-
- boundaries = [3, 7, 11]
- batch_sizes = [2, 2, 2, 2]
- lengths = range(1, 11)
-
- def element_gen():
- for length in lengths:
- yield ([1] * length,)
-
- element_len = lambda element: array_ops.shape(element)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(5):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
-
- self.assertAllEqual(batches[0], [[1, 0],
- [1, 1]])
- self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1]])
- self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
-
- def testTupleElements(self):
-
- def build_dataset(sparse):
- def _generator():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (_format_record(x, sparse), y)
- dataset = dataset_ops.Dataset.from_generator(
- generator=_generator,
- output_types=(_get_record_type(sparse), dtypes.int32),
- output_shapes=(_get_record_shape(sparse),
- tensor_shape.TensorShape([])))
- if sparse:
- dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
- return dataset
-
- def _test_tuple_elements_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=_element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8],
- no_padding=no_padding))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
-
- for no_padding in (True, False):
- _test_tuple_elements_by_padding(no_padding)
-
- def testBucketSparse(self):
- """Tests bucketing of sparse tensors (case where `no_padding` == True).
-
- Test runs on following dataset:
- [
- [0],
- [0, 1],
- [0, 1, 2]
- ...
- [0, ..., max_len - 1]
- ]
- Sequences are bucketed by length and batched with
- `batch_size` < `bucket_size`.
- """
-
- min_len = 0
- max_len = 100
- batch_size = 7
- bucket_size = 10
-
- def _build_dataset():
- input_data = [range(i+1) for i in range(min_len, max_len)]
- def generator_fn():
- for record in input_data:
- yield _format_record(record, sparse=True)
- dataset = dataset_ops.Dataset.from_generator(
- generator=generator_fn,
- output_types=_get_record_type(sparse=True))
- dataset = dataset.map(_to_sparse_tensor)
- return dataset
-
- def _compute_expected_batches():
- """Computes expected batch outputs and stores in a set."""
- all_expected_sparse_tensors = set()
- for bucket_start_len in range(min_len, max_len, bucket_size):
- for batch_offset in range(0, bucket_size, batch_size):
- batch_start_len = bucket_start_len + batch_offset
- batch_end_len = min(batch_start_len + batch_size,
- bucket_start_len + bucket_size)
- expected_indices = []
- expected_values = []
- for length in range(batch_start_len, batch_end_len):
- for val in range(length + 1):
- expected_indices.append((length - batch_start_len, val))
- expected_values.append(val)
- expected_sprs_tensor = (tuple(expected_indices),
- tuple(expected_values))
- all_expected_sparse_tensors.add(expected_sprs_tensor)
- return all_expected_sparse_tensors
-
- def _compute_batches(dataset):
- """Computes actual batch outputs of dataset and stores in a set."""
- batch = dataset.make_one_shot_iterator().get_next()
- all_sparse_tensors = set()
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- output = sess.run(batch)
- sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
- tuple(output.values))
- all_sparse_tensors.add(sprs_tensor)
- return all_sparse_tensors
-
- dataset = _build_dataset()
- boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- [batch_size] * (len(boundaries) + 1),
- no_padding=True))
- batches = _compute_batches(dataset)
- expected_batches = _compute_expected_batches()
- self.assertEqual(batches, expected_batches)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
deleted file mode 100644
index 5b3c512b64..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ /dev/null
@@ -1,632 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for CsvDatasetOp."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import string
-import tempfile
-import time
-import zlib
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import googletest
-from tensorflow.python.platform import test
-
-
-@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test_base.DatasetTestBase):
-
- def _setup_files(self, inputs, linebreak='\n', compression_type=None):
- filenames = []
- for i, ip in enumerate(inputs):
- fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
- contents = linebreak.join(ip).encode('utf-8')
- if compression_type is None:
- with open(fn, 'wb') as f:
- f.write(contents)
- elif compression_type == 'GZIP':
- with gzip.GzipFile(fn, 'wb') as f:
- f.write(contents)
- elif compression_type == 'ZLIB':
- contents = zlib.compress(contents)
- with open(fn, 'wb') as f:
- f.write(contents)
- else:
- raise ValueError('Unsupported compression_type', compression_type)
- filenames.append(fn)
- return filenames
-
- def _make_test_datasets(self, inputs, **kwargs):
- # Test by comparing its output to what we could get with map->decode_csv
- filenames = self._setup_files(inputs)
- dataset_expected = core_readers.TextLineDataset(filenames)
- dataset_expected = dataset_expected.map(
- lambda l: parsing_ops.decode_csv(l, **kwargs))
- dataset_actual = readers.CsvDataset(filenames, **kwargs)
- return (dataset_actual, dataset_expected)
-
- def _test_by_comparison(self, inputs, **kwargs):
- """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
- dataset_actual, dataset_expected = self._make_test_datasets(
- inputs, **kwargs)
- self.assertDatasetsEqual(dataset_actual, dataset_expected)
-
- def _verify_output_or_err(self,
- dataset,
- expected_output=None,
- expected_err_re=None):
- if expected_err_re is None:
- # Verify that output is expected, without errors
- nxt = self.getNext(dataset)
- expected_output = [[
- v.encode('utf-8') if isinstance(v, str) else v for v in op
- ] for op in expected_output]
- for value in expected_output:
- op = self.evaluate(nxt())
- self.assertAllEqual(op, value)
- with self.assertRaises(errors.OutOfRangeError):
- self.evaluate(nxt())
- else:
- # Verify that OpError is produced as expected
- with self.assertRaisesOpError(expected_err_re):
- nxt = self.getNext(dataset)
- while True:
- try:
- self.evaluate(nxt())
- except errors.OutOfRangeError:
- break
-
- def _test_dataset(
- self,
- inputs,
- expected_output=None,
- expected_err_re=None,
- linebreak='\n',
- compression_type=None, # Used for both setup and parsing
- **kwargs):
- """Checks that elements produced by CsvDataset match expected output."""
- # Convert str type because py3 tf strings are bytestrings
- filenames = self._setup_files(inputs, linebreak, compression_type)
- kwargs['compression_type'] = compression_type
- dataset = readers.CsvDataset(filenames, **kwargs)
- self._verify_output_or_err(dataset, expected_output, expected_err_re)
-
- def testCsvDataset_requiredFields(self):
- record_defaults = [[]] * 4
- inputs = [['1,2,3,4']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_int(self):
- record_defaults = [[0]] * 4
- inputs = [['1,2,3,4', '5,6,7,8']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_float(self):
- record_defaults = [[0.0]] * 4
- inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_string(self):
- record_defaults = [['']] * 4
- inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withEmptyFields(self):
- record_defaults = [[0]] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
- def testCsvDataset_errWithUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4']]
- self._test_dataset(
- inputs,
- expected_err_re='Unquoted fields cannot have quotes inside',
- record_defaults=record_defaults)
-
- def testCsvDataset_errWithUnescapedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['"a"b","c","d"']]
- self._test_dataset(
- inputs,
- expected_err_re=
- 'Quote inside a string has to be escaped by another quote',
- record_defaults=record_defaults)
-
- def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
- filenames = self._setup_files(inputs)
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(dataset, [['e', 'f', 'g']])
-
- def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
- filenames = self._setup_files(inputs)
- dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
- dataset = dataset.apply(error_ops.ignore_errors())
- self._verify_output_or_err(dataset, [['e', 'f', 'g']])
-
- def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
- record_defaults = [['']] * 3
- inputs = [['1,2"3,4']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, use_quote_delim=False)
-
- def testCsvDataset_mixedTypes(self):
- record_defaults = [
- constant_op.constant([], dtype=dtypes.int32),
- constant_op.constant([], dtype=dtypes.float32),
- constant_op.constant([], dtype=dtypes.string),
- constant_op.constant([], dtype=dtypes.float64)
- ]
- inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withUseQuoteDelimFalse(self):
- record_defaults = [['']] * 4
- inputs = [['1,2,"3,4"', '"5,6",7,8']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, use_quote_delim=False)
-
- def testCsvDataset_withFieldDelim(self):
- record_defaults = [[0]] * 4
- inputs = [['1:2:3:4', '5:6:7:8']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, field_delim=':')
-
- def testCsvDataset_withNaValue(self):
- record_defaults = [[0]] * 4
- inputs = [['1,NA,3,4', 'NA,6,7,8']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, na_value='NA')
-
- def testCsvDataset_withSelectCols(self):
- record_defaults = [['']] * 2
- inputs = [['1,2,3,4', '"5","6","7","8"']]
- self._test_by_comparison(
- inputs, record_defaults=record_defaults, select_cols=[1, 2])
-
- def testCsvDataset_withSelectColsTooHigh(self):
- record_defaults = [[0]] * 2
- inputs = [['1,2,3,4', '5,6,7,8']]
- self._test_dataset(
- inputs,
- expected_err_re='Expect 2 fields but have 1 in record',
- record_defaults=record_defaults,
- select_cols=[3, 4])
-
- def testCsvDataset_withOneCol(self):
- record_defaults = [['NA']]
- inputs = [['0', '', '2']]
- self._test_dataset(
- inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
-
- def testCsvDataset_withMultipleFiles(self):
- record_defaults = [[0]] * 4
- inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withLeadingAndTrailingSpaces(self):
- record_defaults = [[0.0]] * 4
- inputs = [['0, 1, 2, 3']]
- expected = [[0.0, 1.0, 2.0, 3.0]]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_errorWithMissingDefault(self):
- record_defaults = [[]] * 2
- inputs = [['0,']]
- self._test_dataset(
- inputs,
- expected_err_re='Field 1 is required but missing in record!',
- record_defaults=record_defaults)
-
- def testCsvDataset_errorWithFewerDefaultsThanFields(self):
- record_defaults = [[0.0]] * 2
- inputs = [['0,1,2,3']]
- self._test_dataset(
- inputs,
- expected_err_re='Expect 2 fields but have more in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_errorWithMoreDefaultsThanFields(self):
- record_defaults = [[0.0]] * 5
- inputs = [['0,1,2,3']]
- self._test_dataset(
- inputs,
- expected_err_re='Expect 5 fields but have 4 in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withHeader(self):
- record_defaults = [[0]] * 2
- inputs = [['col1,col2', '1,2']]
- expected = [[1, 2]]
- self._test_dataset(
- inputs,
- expected,
- record_defaults=record_defaults,
- header=True,
- )
-
- def testCsvDataset_withHeaderAndNoRecords(self):
- record_defaults = [[0]] * 2
- inputs = [['col1,col2']]
- expected = []
- self._test_dataset(
- inputs,
- expected,
- record_defaults=record_defaults,
- header=True,
- )
-
- def testCsvDataset_errorWithHeaderEmptyFile(self):
- record_defaults = [[0]] * 2
- inputs = [[]]
- expected_err_re = "Can't read header of file"
- self._test_dataset(
- inputs,
- expected_err_re=expected_err_re,
- record_defaults=record_defaults,
- header=True,
- )
-
- def testCsvDataset_withEmptyFile(self):
- record_defaults = [['']] * 2
- inputs = [['']] # Empty file
- self._test_dataset(
- inputs, expected_output=[], record_defaults=record_defaults)
-
- def testCsvDataset_errorWithEmptyRecord(self):
- record_defaults = [['']] * 2
- inputs = [['', '1,2']] # First record is empty
- self._test_dataset(
- inputs,
- expected_err_re='Expect 2 fields but have 1 in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withChainedOps(self):
- # Testing that one dataset can create multiple iterators fine.
- # `repeat` creates multiple iterators from the same C++ Dataset.
- record_defaults = [[0]] * 4
- inputs = [['1,,3,4', '5,6,,8']]
- ds_actual, ds_expected = self._make_test_datasets(
- inputs, record_defaults=record_defaults)
- self.assertDatasetsEqual(
- ds_actual.repeat(5).prefetch(1),
- ds_expected.repeat(5).prefetch(1))
-
- def testCsvDataset_withTypeDefaults(self):
- # Testing using dtypes as record_defaults for required fields
- record_defaults = [dtypes.float32, [0.0]]
- inputs = [['1.0,2.0', '3.0,4.0']]
- self._test_dataset(
- inputs,
- [[1.0, 2.0], [3.0, 4.0]],
- record_defaults=record_defaults,
- )
-
- def testMakeCsvDataset_fieldOrder(self):
- data = [[
- '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
- '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
- ]]
- file_path = self._setup_files(data)
-
- ds = readers.make_csv_dataset(
- file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self.getNext(ds)
-
- result = list(self.evaluate(nxt()).values())
-
- self.assertEqual(result, sorted(result))
-
-## The following tests exercise parsing logic for quoted fields
-
- def testCsvDataset_withQuoted(self):
- record_defaults = [['']] * 4
- inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
- def testCsvDataset_withOneColAndQuotes(self):
- record_defaults = [['']]
- inputs = [['"0"', '"1"', '"2"']]
- self._test_dataset(
- inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
-
- def testCsvDataset_withNewLine(self):
- # In this case, we expect it to behave differently from
- # TextLineDataset->map(decode_csv) since that flow has bugs
- record_defaults = [['']] * 4
- inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_withNewLineInUnselectedCol(self):
- record_defaults = [['']]
- inputs = [['1,"2\n3",4', '5,6,7']]
- self._test_dataset(
- inputs,
- expected_output=[['1'], ['5']],
- record_defaults=record_defaults,
- select_cols=[0])
-
- def testCsvDataset_withMultipleNewLines(self):
- # In this case, we expect it to behave differently from
- # TextLineDataset->map(decode_csv) since that flow has bugs
- record_defaults = [['']] * 4
- inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
- expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
- self._test_dataset(inputs, expected, record_defaults=record_defaults)
-
- def testCsvDataset_errorWithTerminateMidRecord(self):
- record_defaults = [['']] * 4
- inputs = [['a,b,c,"a']]
- self._test_dataset(
- inputs,
- expected_err_re=
- 'Reached end of file without closing quoted field in record',
- record_defaults=record_defaults)
-
- def testCsvDataset_withEscapedQuotes(self):
- record_defaults = [['']] * 4
- inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
- self._test_by_comparison(inputs, record_defaults=record_defaults)
-
-
-## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
-## and different types of line breaks
-
- def testCsvDataset_withInvalidBufferSize(self):
- record_defaults = [['']] * 4
- inputs = [['a,b,c,d']]
- self._test_dataset(
- inputs,
- expected_err_re='buffer_size should be positive',
- record_defaults=record_defaults,
- buffer_size=0)
-
- def _test_dataset_on_buffer_sizes(self,
- inputs,
- expected,
- linebreak,
- record_defaults,
- compression_type=None,
- num_sizes_to_test=20):
- # Testing reading with a range of buffer sizes that should all work.
- for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
- self._test_dataset(
- inputs,
- expected,
- linebreak=linebreak,
- compression_type=compression_type,
- record_defaults=record_defaults,
- buffer_size=i)
-
- def testCsvDataset_withLF(self):
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\n', record_defaults=record_defaults)
-
- def testCsvDataset_withCR(self):
- # Test that when the line separator is '\r', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r', record_defaults=record_defaults)
-
- def testCsvDataset_withCRLF(self):
- # Test that when the line separator is '\r\n', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['abc,def,ghi', '0,1,2', ',,']]
- expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
-
- def testCsvDataset_withBufferSizeAndQuoted(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\n', record_defaults=record_defaults)
-
- def testCsvDataset_withCRAndQuoted(self):
- # Test that when the line separator is '\r', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r', record_defaults=record_defaults)
-
- def testCsvDataset_withCRLFAndQuoted(self):
- # Test that when the line separator is '\r\n', parsing works with all buffer
- # sizes
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
-
- def testCsvDataset_withGzipCompressionType(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs,
- expected,
- linebreak='\r\n',
- compression_type='GZIP',
- record_defaults=record_defaults)
-
- def testCsvDataset_withZlibCompressionType(self):
- record_defaults = [['NA']] * 3
- inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
- expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
- ['NA', 'NA', 'NA']]
- self._test_dataset_on_buffer_sizes(
- inputs,
- expected,
- linebreak='\r\n',
- compression_type='ZLIB',
- record_defaults=record_defaults)
-
- def testCsvDataset_withScalarDefaults(self):
- record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
- def testCsvDataset_with2DDefaults(self):
- record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
- inputs = [[',,,', '1,1,1,', ',2,2,2']]
-
- if context.executing_eagerly():
- err_spec = errors.InvalidArgumentError, (
- 'Each record default should be at '
- 'most rank 1.')
- else:
- err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
-
- with self.assertRaisesWithPredicateMatch(*err_spec):
- self._test_dataset(
- inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
- record_defaults=record_defaults)
-
-
-class CsvDatasetBenchmark(test.Benchmark):
- """Benchmarks for the various ways of creating a dataset from CSV files.
- """
- FLOAT_VAL = '1.23456E12'
- STR_VAL = string.ascii_letters * 10
-
- def _setUp(self, str_val):
- # Since this isn't test.TestCase, have to manually create a test dir
- gfile.MakeDirs(googletest.GetTempDir())
- self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
-
- self._num_cols = [4, 64, 256]
- self._num_per_iter = 5000
- self._filenames = []
- for n in self._num_cols:
- fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
- with open(fn, 'wb') as f:
- # Just write 100 rows and use `repeat`... Assumes the cost
- # of creating an iterator is not significant
- row = ','.join([str_val for _ in range(n)])
- f.write('\n'.join([row for _ in range(100)]))
- self._filenames.append(fn)
-
- def _tearDown(self):
- gfile.DeleteRecursively(self._temp_dir)
-
- def _runBenchmark(self, dataset, num_cols, prefix):
- dataset = dataset.skip(self._num_per_iter - 1)
- deltas = []
- for _ in range(10):
- next_element = dataset.make_one_shot_iterator().get_next()
- with session.Session() as sess:
- start = time.time()
- # NOTE: This depends on the underlying implementation of skip, to have
- # the net effect of calling `GetNext` num_per_iter times on the
- # input dataset. We do it this way (instead of a python for loop, or
- # batching N inputs in one iter) so that the overhead from session.run
- # or batch doesn't dominate. If we eventually optimize skip, this has
- # to change.
- sess.run(next_element)
- end = time.time()
- deltas.append(end - start)
- # Median wall time per CSV record read and decoded
- median_wall_time = np.median(deltas) / self._num_per_iter
- print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
- median_wall_time))
- self.report_benchmark(
- iters=self._num_per_iter,
- wall_time=median_wall_time,
- name='%s_with_cols_%d' % (prefix, num_cols))
-
- def benchmarkMapWithFloats(self):
- self._setUp(self.FLOAT_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [[0.0]] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
- self._tearDown()
-
- def benchmarkMapWithStrings(self):
- self._setUp(self.STR_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [['']] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
- self._tearDown()
-
- def benchmarkCsvDatasetWithFloats(self):
- self._setUp(self.FLOAT_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [[0.0]] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
- self._tearDown()
-
- def benchmarkCsvDatasetWithStrings(self):
- self._setUp(self.STR_VAL)
- for i in range(len(self._filenames)):
- num_cols = self._num_cols[i]
- kwargs = {'record_defaults': [['']] * num_cols}
- dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
- dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
- self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
- self._tearDown()
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
deleted file mode 100644
index 722e87e555..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class DatasetConstructorTest(test_base.DatasetTestBase):
-
- def testRestructureDataset(self):
- components = (array_ops.placeholder(dtypes.int32),
- (array_ops.placeholder(dtypes.int32, shape=[None]),
- array_ops.placeholder(dtypes.int32, shape=[20, 30])))
- dataset = dataset_ops.Dataset.from_tensors(components)
-
- i32 = dtypes.int32
-
- test_cases = [((i32, i32, i32), None),
- (((i32, i32), i32), None),
- ((i32, i32, i32), (None, None, None)),
- ((i32, i32, i32), ([17], [17], [20, 30]))]
-
- for new_types, new_shape_lists in test_cases:
- # pylint: disable=protected-access
- new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
- # pylint: enable=protected-access
- self.assertEqual(new_types, new.output_types)
- if new_shape_lists is not None:
- for expected_shape_list, shape in zip(
- nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
- if expected_shape_list is None:
- self.assertIs(None, shape.ndims)
- else:
- self.assertEqual(expected_shape_list, shape.as_list())
-
- fail_cases = [((i32, dtypes.int64, i32), None),
- ((i32, i32, i32, i32), None),
- ((i32, i32, i32), ((None, None), None)),
- ((i32, i32, i32), (None, None, None, None)),
- ((i32, i32, i32), (None, [None], [21, 30]))]
-
- for new_types, new_shape_lists in fail_cases:
- with self.assertRaises(ValueError):
- # pylint: disable=protected-access
- new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
- # pylint: enable=protected-access
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
deleted file mode 100644
index bc10c21472..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import random_seed
-from tensorflow.python.platform import test
-
-
-class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
-
- def testBasic(self):
- selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
- input_datasets = [
- dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
- ]
- dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
- input_datasets)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for _ in range(100):
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def _normalize(self, vec):
- return vec / vec.sum()
-
- def _chi2(self, expected, actual):
- actual = np.asarray(actual)
- expected = np.asarray(expected)
- diff = actual - expected
- chi2 = np.sum(diff * diff / expected, axis=0)
- return chi2
-
- def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
- # Create a dataset that samples each integer in `[0, num_datasets)`
- # with probability given by `weights[i]`.
- dataset = interleave_ops.sample_from_datasets([
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(num_datasets)
- ], weights)
- dataset = dataset.take(num_samples)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- freqs = np.zeros([num_datasets])
- for _ in range(num_samples):
- freqs[sess.run(next_element)] += 1
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- return freqs
-
- def testSampleFromDatasets(self):
- random_seed.set_random_seed(1619)
- num_samples = 5000
- rand_probs = self._normalize(np.random.random_sample((15,)))
-
- # Use chi-squared test to assert that the observed distribution matches the
- # expected distribution. Based on the implementation in
- # "tensorflow/python/kernel_tests/multinomial_op_test.py".
- for probs in [[.85, .05, .1], rand_probs, [1.]]:
- probs = np.asarray(probs)
- classes = len(probs)
- freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
- self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
-
- # Also check that `weights` as a dataset samples correctly.
- probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
- freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
- self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
-
- def testSelectFromDatasets(self):
- words = [b"foo", b"bar", b"baz"]
- datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
- choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
- choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
- dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in choice_array:
- self.assertEqual(words[i], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testErrors(self):
- with self.assertRaisesRegexp(ValueError,
- r"vector of length `len\(datasets\)`"):
- interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.range(20)],
- weights=[0.25, 0.25, 0.25, 0.25])
-
- with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
- interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.range(10),
- dataset_ops.Dataset.range(20)],
- weights=[1, 1])
-
- with self.assertRaisesRegexp(TypeError, "must have the same type"):
- interleave_ops.sample_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(0.0)
- ])
-
- with self.assertRaisesRegexp(TypeError, "tf.int64"):
- interleave_ops.choose_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(1)
- ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))
-
- with self.assertRaisesRegexp(TypeError, "scalar"):
- interleave_ops.choose_from_datasets([
- dataset_ops.Dataset.from_tensors(0),
- dataset_ops.Dataset.from_tensors(1)
- ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
deleted file mode 100644
index 6d01bf585c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Benchmarks FilterDataset input pipeline op."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.client import session
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class FilterBenchmark(test.Benchmark):
-
- # This benchmark compares the performance of pipeline with multiple chained
- # filter with and without filter fusion.
- def benchmarkFilters(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkFilters(chain_length, False)
- self._benchmarkFilters(chain_length, True)
-
- def _benchmarkFilters(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
- if optimize_dataset:
- dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(10):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Filter dataset {} chain length: {} Median wall time: {}".format(
- opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_filter_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
deleted file mode 100644
index d4d3d4adb2..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ /dev/null
@@ -1,79 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental indexed dataset ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import unittest
-
-from tensorflow.contrib.data.python.ops import indexed_dataset_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.platform import test
-
-
-class IndexedDatasetOpsTest(test_base.DatasetTestBase):
-
- def testLowLevelIndexedDatasetOps(self):
- identity = ged_ops.experimental_identity_indexed_dataset(
- ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = ged_ops.experimental_materialized_index_dataset_handle(
- container="",
- shared_name="",
- output_types=[dtypes.uint64],
- output_shapes=[[]])
- materialize = ged_ops.experimental_indexed_dataset_materialize(
- identity, handle)
- index = array_ops.placeholder(dtypes.uint64)
- get_op = ged_ops.experimental_indexed_dataset_get(
- handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
-
- with self.cached_session() as sess:
- sess.run(materialize)
- self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
-
- def testIdentityIndexedDataset(self):
- ds = indexed_dataset_ops.IdentityIndexedDataset(16)
- materialized = ds.materialize()
- with self.cached_session() as sess:
- sess.run(materialized.initializer)
- placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
- for i in range(16):
- output = sess.run(
- materialized.get(placeholder), feed_dict={placeholder: i})
- self.assertEqual([i], output)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(materialized.get(placeholder), feed_dict={placeholder: 16})
-
- @unittest.skip("Requisite functionality currently unimplemented.")
- def testIdentityIndexedDatasetIterator(self):
- ds = indexed_dataset_ops.IdentityIndexedDataset(16)
- itr = ds.make_initializable_iterator()
- n = itr.get_next()
- with self.cached_session() as sess:
- sess.run(itr.initializer)
- for i in range(16):
- output = sess.run(n)
- self.assertEqual(i, output)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(n)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
deleted file mode 100644
index 28bd670ab5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ /dev/null
@@ -1,811 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-import math
-import threading
-import time
-
-from six.moves import zip_longest
-
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
-
- def setUp(self):
-
- self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
- self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
- self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
- self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
- self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
- self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])
-
- self.error = None
- self.repeat_count = 2
-
- # Set up threading events used to sequence when items are produced that
- # are subsequently interleaved. These events allow us to deterministically
- # simulate slowdowns and force sloppiness.
- self.read_coordination_events = {}
- self.write_coordination_events = {}
- # input values [4, 5, 6] are the common case for the tests; set defaults
- for i in range(4, 7):
- self.read_coordination_events[i] = threading.Semaphore(0)
- self.write_coordination_events[i] = threading.Event()
-
- def map_py_fn(x):
- self.write_coordination_events[x].wait()
- self.write_coordination_events[x].clear()
- self.read_coordination_events[x].release()
- if self.error:
- err = self.error
- self.error = None
- raise err # pylint: disable=raising-bad-type
- return x * x
-
- def map_fn(x):
- return script_ops.py_func(map_py_fn, [x], x.dtype)
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- dataset = dataset.repeat(x)
- return dataset.map(map_fn)
-
- self.dataset = (
- dataset_ops.Dataset.from_tensor_slices(self.input_values)
- .repeat(self.repeat_count).apply(
- interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
- self.block_length, self.sloppy,
- self.buffer_output_elements,
- self.prefetch_input_elements)))
- self.iterator = self.dataset.make_initializable_iterator()
- self.init_op = self.iterator.initializer
- self.next_element = self.iterator.get_next()
-
- def _interleave(self, lists, cycle_length, block_length):
- """Python implementation of interleave used for testing."""
- num_open = 0
-
- # `all_iterators` acts as a queue of iterators over each element of `lists`.
- all_iterators = [iter(l) for l in lists]
-
- # `open_iterators` are the iterators whose elements are currently being
- # interleaved.
- open_iterators = []
- for i in range(cycle_length):
- if all_iterators:
- open_iterators.append(all_iterators.pop(0))
- num_open += 1
- else:
- open_iterators.append(None)
-
- while num_open or all_iterators:
- for i in range(cycle_length):
- if open_iterators[i] is None:
- if all_iterators:
- open_iterators[i] = all_iterators.pop(0)
- num_open += 1
- else:
- continue
- for _ in range(block_length):
- try:
- yield next(open_iterators[i])
- except StopIteration:
- open_iterators[i] = None
- num_open -= 1
- break
-
- def testPythonImplementation(self):
- input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
- [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
-
- # Cycle length 1 acts like `Dataset.flat_map()`.
- expected_elements = itertools.chain(*input_lists)
- for expected, produced in zip(expected_elements,
- self._interleave(input_lists, 1, 1)):
- self.assertEqual(expected, produced)
-
- # Cycle length > 1.
- expected_elements = [
- 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
- 6, 5, 6, 5, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- def testPythonImplementationBlockLength(self):
- input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
- expected_elements = [
- 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
- 5, 6, 6, 5, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- def testPythonImplementationEmptyLists(self):
- input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
- [6, 6, 6, 6, 6, 6]]
-
- expected_elements = [
- 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- def _clear_coordination_events(self):
- for i in range(4, 7):
- self.read_coordination_events[i] = threading.Semaphore(0)
- self.write_coordination_events[i].clear()
-
- def _allow_all_map_threads(self):
- for i in range(4, 7):
- self.write_coordination_events[i].set()
-
- def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
- # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
- # `Dataset.flat_map()` and is single-threaded. No synchronization required.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 1,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: prefetch_input_elements,
- })
-
- for expected_element in self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
- self.write_coordination_events[expected_element].set()
- self.assertEqual(expected_element * expected_element,
- sess.run(self.next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testSingleThreaded(self):
- self._testSingleThreaded()
-
- def testSingleThreadedSloppy(self):
- self._testSingleThreaded(sloppy=True)
-
- def testSingleThreadedPrefetch1Itr(self):
- self._testSingleThreaded(prefetch_input_elements=1)
-
- def testSingleThreadedPrefetch1ItrSloppy(self):
- self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
-
- def testSingleThreadedRagged(self):
- # Tests a sequence with wildly different elements per iterator.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [3, 7, 4],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
-
- # Add coordination values for 3 and 7
- self.read_coordination_events[3] = threading.Semaphore(0)
- self.write_coordination_events[3] = threading.Event()
- self.read_coordination_events[7] = threading.Semaphore(0)
- self.write_coordination_events[7] = threading.Event()
-
- for expected_element in self._interleave(
- [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
- self.write_coordination_events[expected_element].set()
- output = sess.run(self.next_element)
- self.assertEqual(expected_element * expected_element, output)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def _testTwoThreadsNoContention(self, sloppy=False):
- # num_threads > 1.
- # Explicit coordination should result in `Dataset.interleave()` behavior
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 1)):
- self.write_coordination_events[expected_element].set()
- if done_first_event: # First event starts the worker threads.
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- self.read_coordination_events[expected_element].acquire()
- done_first_event = True
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContention(self):
- self._testTwoThreadsNoContention()
-
- def testTwoThreadsNoContentionSloppy(self):
- self._testTwoThreadsNoContention(sloppy=True)
-
- def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
- """Tests where all the workers race in producing elements.
-
- Note: this is in contrast with the previous test which carefully sequences
- the execution of the map functions.
-
- Args:
- sloppy: Whether to be sloppy or not.
- """
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 1)):
- if done_first_event: # First event starts the worker threads.
- self._allow_all_map_threads()
- self.read_coordination_events[expected_element].acquire()
- else:
- self.write_coordination_events[expected_element].set()
- time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- done_first_event = True
- self.assertTrue(
- self.read_coordination_events[expected_element].acquire(False))
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContentionWithRaces(self):
- self._testTwoThreadsNoContentionWithRaces()
-
- def testTwoThreadsNoContentionWithRacesSloppy(self):
- self._testTwoThreadsNoContentionWithRaces(sloppy=True)
-
- def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
- # num_threads > 1.
- # Explicit coordination should result in `Dataset.interleave()` behavior
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 2,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 2)):
- self.write_coordination_events[expected_element].set()
- if done_first_event: # First event starts the worker threads.
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- done_first_event = True
- self.read_coordination_events[expected_element].acquire()
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContentionBlockLength(self):
- self._testTwoThreadsNoContentionBlockLength()
-
- def testTwoThreadsNoContentionBlockLengthSloppy(self):
- self._testTwoThreadsNoContentionBlockLength(sloppy=True)
-
- def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
- """Tests where all the workers race in producing elements.
-
- Note: this is in contrast with the previous test which carefully sequences
- the execution of the map functions.
-
-
- Args:
- sloppy: Whether to be sloppy or not.
- """
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 2,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 2)):
- if done_first_event: # First event starts the worker threads.
- self._allow_all_map_threads()
- self.read_coordination_events[expected_element].acquire()
- else:
- self.write_coordination_events[expected_element].set()
- time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- done_first_event = True
- self.assertTrue(
- self.read_coordination_events[expected_element].acquire(False))
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testTwoThreadsNoContentionWithRacesAndBlocking(self):
- self._testTwoThreadsNoContentionWithRacesAndBlocking()
-
- def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
- self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
-
- def _testEmptyInput(self, sloppy=False):
- with self.cached_session() as sess:
- # Empty input.
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [],
- self.cycle_length: 2,
- self.block_length: 3,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testEmptyInput(self):
- self._testEmptyInput()
-
- def testEmptyInputSloppy(self):
- self._testEmptyInput(sloppy=True)
-
- def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
- # Non-empty input leading to empty output.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [0, 0, 0],
- self.cycle_length: 2,
- self.block_length: 3,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testNonEmptyInputIntoEmptyOutputs(self):
- self._testNonEmptyInputIntoEmptyOutputs()
-
- def testNonEmptyInputIntoEmptyOutputsSloppy(self):
- self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
-
- def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
- race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
- # Mixture of non-empty and empty interleaved datasets.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 0, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: prefetch_input_elements,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
- self.write_coordination_events[expected_element].set()
- # First event starts the worker threads. Additionally, when running the
- # sloppy case with prefetch_input_elements=0, we get stuck if we wait
- # for the read coordination event for certain event orderings in the
- # presence of finishing iterators.
- if done_first_event and not (sloppy and (i in race_indices)):
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event or (sloppy and (i in race_indices)):
- done_first_event = True
- self.read_coordination_events[expected_element].acquire()
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
-
- def testPartiallyEmptyOutputs(self):
- self._testPartiallyEmptyOutputs()
-
- def testPartiallyEmptyOutputsSloppy(self):
- self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
-
- def testDelayedOutputSloppy(self):
- # Explicitly control the sequence of events to ensure we correctly avoid
- # head-of-line blocking.
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: True,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
-
- mis_ordering = [
- 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6,
- 6, 5, 5, 5, 5, 6, 6
- ]
- for element in mis_ordering:
- self.write_coordination_events[element].set()
- self.assertEqual(element * element, sess.run(self.next_element))
- self.assertTrue(self.read_coordination_events[element].acquire(False))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testBlockLengthWithContentionSloppy(self):
- with self.cached_session() as sess:
- self._clear_coordination_events()
- done_first_event = False
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: True,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 1,
- })
- # Test against a generating sequence that differs from the uncontended
- # case, in order to prove sloppy correctness.
- for i, expected_element in enumerate(
- self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
- cycle_length=2,
- block_length=3)):
- self.write_coordination_events[expected_element].set()
- if done_first_event: # First event starts the worker threads.
- self.read_coordination_events[expected_element].acquire()
- actual_element = sess.run(self.next_element)
- if not done_first_event:
- self.read_coordination_events[expected_element].acquire()
- done_first_event = True
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def _testEarlyExit(self, sloppy=False):
- # Exiting without consuming all input should not block
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 3,
- self.block_length: 2,
- self.sloppy: sloppy,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- for i in range(4, 7):
- self.write_coordination_events[i].set()
- elem = sess.run(self.next_element) # Start all workers
- # Allow the one successful worker to progress beyond the py_func again.
- elem = int(math.sqrt(elem))
- self.write_coordination_events[elem].set()
- self.read_coordination_events[elem].acquire()
- # Allow the prefetch to succeed
- for i in range(4, 7):
- self.read_coordination_events[i].acquire()
- self.write_coordination_events[i].set()
-
- def testEarlyExit(self):
- self._testEarlyExit()
-
- def testEarlyExitSloppy(self):
- self._testEarlyExit(sloppy=True)
-
- def _testTooManyReaders(self, sloppy=False):
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
- return dataset
-
- dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
- dataset = dataset.repeat(self.repeat_count)
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
- iterator = dataset.make_one_shot_iterator()
-
- with self.cached_session() as sess:
- output_values = []
- for _ in range(30):
- output_values.append(sess.run(iterator.get_next()))
-
- expected_values = self._interleave(
- [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
- self.assertItemsEqual(output_values, expected_values)
-
- def testTooManyReaders(self):
- self._testTooManyReaders()
-
- def testTooManyReadersSloppy(self):
- self._testTooManyReaders(sloppy=True)
-
- def testSparse(self):
- def _map_fn(i):
- return sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- dataset = dataset_ops.Dataset.range(10).map(_map_fn)
- iterator = dataset.apply(
- interleave_ops.parallel_interleave(
- _interleave_fn, cycle_length=1)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(10):
- for j in range(2):
- expected = [i, 0] if j % 2 == 0 else [0, -i]
- self.assertAllEqual(expected, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testErrorsInOutputFn(self):
- with self.cached_session() as sess:
- self._clear_coordination_events()
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
-
- except_on_element_indices = set([3])
-
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
- 1)):
- if i in except_on_element_indices:
- self.error = ValueError()
- self.write_coordination_events[expected_element].set()
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
- else:
- self.write_coordination_events[expected_element].set()
- actual_element = sess.run(self.next_element)
- self.assertEqual(expected_element * expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testErrorsInInputFn(self):
-
- def map_py_fn(x):
- if x == 5:
- raise ValueError()
- return x
-
- def map_fn(x):
- return script_ops.py_func(map_py_fn, [x], x.dtype)
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- dataset = dataset.repeat(x)
- return dataset
-
- self.dataset = (
- dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
- .repeat(self.repeat_count).apply(
- interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
- self.block_length, self.sloppy,
- self.buffer_output_elements,
- self.prefetch_input_elements)))
-
- self.iterator = self.dataset.make_initializable_iterator()
- self.init_op = self.iterator.initializer
- self.next_element = self.iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
- if expected_element == 5:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
- else:
- actual_element = sess.run(self.next_element)
- self.assertEqual(expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testErrorsInInterleaveFn(self):
-
- def map_py_fn(x):
- if x == 5:
- raise ValueError()
- return x
-
- def interleave_fn(x):
- dataset = dataset_ops.Dataset.from_tensors(x)
- y = script_ops.py_func(map_py_fn, [x], x.dtype)
- dataset = dataset.repeat(y)
- return dataset
-
- self.dataset = (
- dataset_ops.Dataset.from_tensor_slices(self.input_values)
- .repeat(self.repeat_count).apply(
- interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
- self.block_length, self.sloppy,
- self.buffer_output_elements,
- self.prefetch_input_elements)))
-
- self.iterator = self.dataset.make_initializable_iterator()
- self.init_op = self.iterator.initializer
- self.next_element = self.iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(
- self.init_op,
- feed_dict={
- self.input_values: [4, 5, 6],
- self.cycle_length: 2,
- self.block_length: 1,
- self.sloppy: False,
- self.buffer_output_elements: 1,
- self.prefetch_input_elements: 0,
- })
- for i, expected_element in enumerate(
- self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
- if expected_element == 5:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(self.next_element)
- else:
- actual_element = sess.run(self.next_element)
- self.assertEqual(expected_element, actual_element,
- "At index %s: %s expected, got: %s" %
- (i, expected_element, actual_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(self.next_element)
-
- def testShutdownRace(self):
- dataset = dataset_ops.Dataset.range(20)
- map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- map_fn,
- cycle_length=3,
- sloppy=False,
- buffer_output_elements=1,
- prefetch_input_elements=0))
- dataset = dataset.batch(32)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- results = []
- with self.cached_session() as sess:
- for _ in range(2):
- elements = []
- sess.run(iterator.initializer)
- try:
- while True:
- elements.extend(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- results.append(elements)
-
- self.assertAllEqual(results[0], results[1])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
deleted file mode 100644
index 58a1d7c93b..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ /dev/null
@@ -1,125 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental iterator_ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import iterator_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.estimator import estimator
-from tensorflow.python.estimator import model_fn
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training import training_util
-
-
-class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
-
- @staticmethod
- def _model_fn(features, labels, mode, config):
- del labels
- del mode
- del config
- global_step = training_util.get_or_create_global_step()
- update_global_step_op = global_step.assign_add(1)
- latest_feature = variables.VariableV1(
- 0, name='latest_feature', dtype=dtypes.int64)
- store_latest_feature_op = latest_feature.assign(features)
- ops.add_to_collection('my_vars', global_step)
- ops.add_to_collection('my_vars', latest_feature)
- return model_fn.EstimatorSpec(
- mode='train',
- train_op=control_flow_ops.group(
- [update_global_step_op, store_latest_feature_op]),
- loss=constant_op.constant(2.0))
-
- def _read_vars(self, model_dir):
- """Returns (global_step, latest_feature)."""
- with ops.Graph().as_default() as g:
- ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
- meta_filename = ckpt_path + '.meta'
- saver_lib.import_meta_graph(meta_filename)
- saver = saver_lib.Saver()
- with self.session(graph=g) as sess:
- saver.restore(sess, ckpt_path)
- return sess.run(ops.get_collection('my_vars'))
-
- def _build_iterator_saver_hook(self, est):
- return iterator_ops.CheckpointInputPipelineHook(est)
-
- def testReturnDatasetFromInputFn(self):
-
- def _input_fn():
- return dataset_ops.Dataset.range(10)
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
-
- def testBuildIteratorInInputFn(self):
-
- def _input_fn():
- ds = dataset_ops.Dataset.range(10)
- iterator = ds.make_one_shot_iterator()
- return iterator.get_next()
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
-
- def testDoNotRestore(self):
-
- def _input_fn():
- return dataset_ops.Dataset.range(10)
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
- est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
- self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
- # Hook not provided, input pipeline was not restored.
- est.train(_input_fn, steps=2)
- self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
-
- def testRaiseErrorIfNoIterator(self):
-
- def _input_fn():
- return constant_op.constant(1, dtype=dtypes.int64)
-
- est = estimator.Estimator(model_fn=self._model_fn)
-
- with self.assertRaises(ValueError):
- est.train(
- _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
deleted file mode 100644
index 385c4ef6ea..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ /dev/null
@@ -1,359 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import hashlib
-import itertools
-import os
-import time
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-_NUMPY_RANDOM_SEED = 42
-
-
-class MapDatasetTest(test_base.DatasetTestBase):
-
- def testMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testParallelMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message"),
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testReadFileIgnoreError(self):
-
- def write_string_to_file(value, filename):
- with open(filename, "w") as f:
- f.write(value)
-
- filenames = [
- os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
- ]
- for filename in filenames:
- write_string_to_file(filename, filename)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(filenames).map(
- io_ops.read_file,
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # All of the files are present.
- sess.run(init_op)
- for filename in filenames:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Delete one of the files.
- os.remove(filenames[0])
-
- # Attempting to read filenames[0] will fail, but ignore_errors()
- # will catch the error.
- sess.run(init_op)
- for filename in filenames[1:]:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testCaptureResourceInMapFn(self):
-
- def _build_ds(iterator):
-
- def _map_fn(x):
- get_next = iterator.get_next()
- return x * get_next
-
- return dataset_ops.Dataset.range(10).map(_map_fn)
-
- def _build_graph():
- captured_iterator = dataset_ops.Dataset.range(
- 10).make_initializable_iterator()
- ds = _build_ds(captured_iterator)
- iterator = ds.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return captured_iterator.initializer, init_op, get_next
-
- with ops.Graph().as_default() as g:
- captured_init_op, init_op, get_next = _build_graph()
- with self.session(graph=g) as sess:
- sess.run(captured_init_op)
- sess.run(init_op)
- for i in range(10):
- self.assertEquals(i * i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-class MapDatasetBenchmark(test.Benchmark):
-
- # The purpose of this benchmark is to compare the performance of chaining vs
- # fusing of the map and batch transformations across various configurations.
- #
- # NOTE: It is recommended to build the benchmark with
- # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
- # and execute it on a machine with at least 32 CPU cores.
- def benchmarkMapAndBatch(self):
-
- # Sequential pipeline configurations.
- seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
- seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
-
- # Parallel pipeline configuration.
- par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
- par_batch_size_series = itertools.product([32], [32], [1],
- [128, 256, 512, 1024])
- par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
- par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
-
- def name(method, label, num_calls, inter_op, element_size, batch_size):
- return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
- method,
- hashlib.sha1(label).hexdigest(),
- num_calls,
- inter_op,
- element_size,
- batch_size,
- ))
-
- def benchmark(label, series):
-
- print("%s:" % label)
- for num_calls, inter_op, element_size, batch_size in series:
-
- num_iters = 1024 // (
- (element_size * batch_size) // min(num_calls, inter_op))
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
- element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
-
- chained_dataset = dataset.map(
- math_ops.matmul,
- num_parallel_calls=num_calls).batch(batch_size=batch_size)
- chained_iterator = chained_dataset.make_one_shot_iterator()
- chained_get_next = chained_iterator.get_next()
-
- chained_deltas = []
- with session.Session(
- config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op,
- use_per_session_threads=True)) as sess:
- for _ in range(5):
- sess.run(chained_get_next.op)
- for _ in range(num_iters):
- start = time.time()
- sess.run(chained_get_next.op)
- end = time.time()
- chained_deltas.append(end - start)
-
- fused_dataset = dataset.apply(
- batching.map_and_batch(
- math_ops.matmul,
- num_parallel_calls=num_calls,
- batch_size=batch_size))
- fused_iterator = fused_dataset.make_one_shot_iterator()
- fused_get_next = fused_iterator.get_next()
-
- fused_deltas = []
- with session.Session(
- config=config_pb2.ConfigProto(
- inter_op_parallelism_threads=inter_op,
- use_per_session_threads=True)) as sess:
-
- for _ in range(5):
- sess.run(fused_get_next.op)
- for _ in range(num_iters):
- start = time.time()
- sess.run(fused_get_next.op)
- end = time.time()
- fused_deltas.append(end - start)
-
- print(
- "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
- "element size: %d, num iters: %d\nchained wall time: %f (median), "
- "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
- "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
- "chained/fused: %.2fx (median), %.2fx (mean)" %
- (batch_size, num_calls, inter_op, element_size, num_iters,
- np.median(chained_deltas), np.mean(chained_deltas),
- np.std(chained_deltas), np.min(chained_deltas),
- np.max(chained_deltas), np.median(fused_deltas),
- np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
- np.max(fused_deltas),
- np.median(chained_deltas) / np.median(fused_deltas),
- np.mean(chained_deltas) / np.mean(fused_deltas)))
-
- self.report_benchmark(
- iters=num_iters,
- wall_time=np.median(chained_deltas),
- name=name("chained", label, num_calls, inter_op, element_size,
- batch_size))
-
- self.report_benchmark(
- iters=num_iters,
- wall_time=np.median(fused_deltas),
- name=name("fused", label, num_calls, inter_op, element_size,
- batch_size))
-
- print("")
-
- np.random.seed(_NUMPY_RANDOM_SEED)
- benchmark("Sequential element size evaluation", seq_elem_size_series)
- benchmark("Sequential batch size evaluation", seq_batch_size_series)
- benchmark("Parallel element size evaluation", par_elem_size_series)
- benchmark("Parallel batch size evaluation", par_batch_size_series)
- benchmark("Transformation parallelism evaluation", par_num_calls_series)
- benchmark("Threadpool size evaluation", par_inter_op_series)
-
- # This benchmark compares the performance of pipeline with multiple chained
- # maps with and without map fusion.
- def benchmarkChainOfMaps(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkChainOfMaps(chain_length, False)
- self._benchmarkChainOfMaps(chain_length, True)
-
- def _benchmarkChainOfMaps(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x)
- if optimize_dataset:
- dataset = dataset.apply(optimization.optimize(["map_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Map dataset {} chain length: {} Median wall time: {}".format(
- opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-class MapAndFilterBenchmark(test.Benchmark):
-
- # This benchmark compares the performance of pipeline with multiple chained
- # map + filter with and without map fusion.
- def benchmarkMapAndFilter(self):
- chain_lengths = [0, 1, 2, 5, 10, 20, 50]
- for chain_length in chain_lengths:
- self._benchmarkMapAndFilter(chain_length, False)
- self._benchmarkMapAndFilter(chain_length, True)
-
- def _benchmarkMapAndFilter(self, chain_length, optimize_dataset):
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset.map(lambda x: x + 5).filter(
- lambda x: math_ops.greater_equal(x - 5, 0))
- if optimize_dataset:
- dataset = dataset.apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(10):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- opt_mark = "opt" if optimize_dataset else "no-opt"
- print("Map and filter dataset {} chain length: {} Median wall time: {}".
- format(opt_mark, chain_length, median_wall_time))
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format(
- opt_mark, chain_length))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
deleted file mode 100644
index 751e6d5b30..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ /dev/null
@@ -1,281 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for MapDefunOp."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from tensorflow.contrib.data.python.ops import map_defun
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import data_flow_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapDefunTest(test_base.DatasetTestBase):
-
- def testMapDefunSimple(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2 + 3
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
- expected = elems * 2 + 3
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
-
- def testMapDefunMismatchedTypes(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return math_ops.cast(x, dtypes.float64)
-
- nums = [1, 2, 3, 4, 5, 6]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(r)
-
- def testMapDefunReduceDim(self):
- # Tests where the output has a different rank from the input
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return array_ops.gather(x, 0)
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
- expected = constant_op.constant([1, 3, 5])
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
-
- def testMapDefunMultipleOutputs(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
- (2,)])
- expected = [elems, elems * 2 + 3]
- self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
-
- def testMapDefunShapeInference(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return x
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
- self.assertEqual(result.get_shape(), (3, 2))
-
- def testMapDefunPartialShapeInference(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- return x
-
- elems = array_ops.placeholder(dtypes.int64, (None, 2))
- result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
- self.assertEqual(result[0].get_shape().as_list(), [None, 2])
-
- def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
-
- @function.Defun(dtypes.int32, dtypes.int32)
- def fn(x, y):
- return x, y
-
- elems1 = array_ops.placeholder(dtypes.int32)
- elems2 = array_ops.placeholder(dtypes.int32)
- result = map_defun.map_defun(fn, [elems1, elems2],
- [dtypes.int32, dtypes.int32], [(), ()])
- with self.cached_session() as sess:
- with self.assertRaisesWithPredicateMatch(
- errors.InvalidArgumentError,
- "All inputs must have the same dimension 0."):
- sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
-
- def testMapDefunRaisesDefunError(self):
-
- @function.Defun(dtypes.int32)
- def fn(x):
- with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
- return array_ops.identity(x)
-
- elems = constant_op.constant([0, 0, 0, 37, 0])
- result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(result)
-
- def testMapDefunCancelledCorrectly(self):
-
- @function.Defun(dtypes.int64)
- def defun(x):
- # x has leading dimension 5, this will raise an error
- return array_ops.gather(x, 10)
-
- c = array_ops.tile(
- array_ops.expand_dims(
- constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
- [100, 1])
- map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- r"indices = 10 is not in \[0, 5\)"):
- self.evaluate(map_defun_op)
-
- def testMapDefunWithUnspecifiedOutputShape(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- res = x * 2 + 3
- return (res, res + 1, res + 2)
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems],
- [dtypes.int32, dtypes.int32, dtypes.int32],
- [None, (None,), (2,)])
- expected = elems * 2 + 3
- self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
- self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
- self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
-
- def testMapDefunWithDifferentOutputShapeEachRun(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2 + 3
-
- elems = array_ops.placeholder(dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
- with session.Session() as sess:
- self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
- self.assertAllEqual(
- sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
-
- def testMapDefunWithWrongOutputShape(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2 + 3
-
- nums = [[1, 2], [3, 4], [5, 6]]
- elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
- r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
- with self.assertRaises(errors.InvalidArgumentError):
- self.evaluate(r)
-
- def testMapDefunWithInvalidInput(self):
-
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- return x * 2
-
- c = constant_op.constant(2)
- with self.assertRaises(ValueError):
- # Fails at graph construction time for inputs with known shapes.
- r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
- p = array_ops.placeholder(dtypes.int32)
- r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
- with session.Session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(r, feed_dict={p: 0})
-
- def _assert_op_cancelled(self, sess, map_defun_op):
- with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
- sess.run(map_defun_op)
-
- def testMapDefunWithParentCancellation(self):
- # Checks that a cancellation of the parent graph is threaded through to
- # MapDefunOp correctly.
- @function.Defun(dtypes.int32)
- def simple_fn(x):
- del x
- queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
- # Blocking
- return queue.dequeue_many(5)
-
- c = constant_op.constant([1, 2, 3, 4, 5])
- map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
-
- with self.cached_session() as sess:
- thread = self.checkedThread(
- self._assert_op_cancelled, args=(sess, map_defun_op))
- thread.start()
- time.sleep(0.1)
- sess.close()
- thread.join()
-
-
-class MapDefunBenchmark(test.Benchmark):
-
- def _run(self, op, name=None, num_iters=3000):
- with session.Session() as sess:
- # Warm up the session
- for _ in range(5):
- sess.run(op)
- start = time.time()
- for _ in range(num_iters):
- sess.run(op)
- end = time.time()
- mean_us = (end - start) * 1e6 / num_iters
- self.report_benchmark(
- name=name,
- iters=num_iters,
- wall_time=mean_us,
- extras={"examples_per_sec": num_iters / (end - start)})
-
- def benchmarkDefunVsMapFn(self):
- """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
-
- @function.Defun(dtypes.int32)
- def defun(x):
- return array_ops.identity(x)
-
- def map_fn(x):
- return array_ops.identity(x)
-
- base = math_ops.range(100)
- for input_size in [10, 100, 1000, 10000]:
- num_iters = 100000 // input_size
- map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
- map_fn_op = functional_ops.map_fn(map_fn, base)
-
- self._run(
- map_defun_op,
- "benchmarkMapDefun_size_%d" % input_size,
- num_iters=num_iters)
- self._run(
- map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
deleted file mode 100644
index d7b5edcd9a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ /dev/null
@@ -1,164 +0,0 @@
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_test(
- name = "assert_next_dataset_op_test",
- size = "medium",
- srcs = ["assert_next_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "hoist_random_uniform_test",
- size = "small",
- srcs = ["hoist_random_uniform_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "latency_all_edges_test",
- size = "small",
- srcs = ["latency_all_edges_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "map_vectorization_test",
- size = "small",
- srcs = ["map_vectorization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "map_and_filter_fusion_test",
- size = "medium",
- srcs = ["map_and_filter_fusion_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "map_parallelization_test",
- size = "small",
- srcs = ["map_parallelization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "model_dataset_op_test",
- size = "medium",
- srcs = ["model_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "noop_elimination_test",
- size = "small",
- srcs = ["noop_elimination_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "optimize_dataset_op_test",
- size = "small",
- srcs = ["optimize_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
deleted file mode 100644
index fe1b5280ba..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class AssertNextDatasetTest(test_base.DatasetTestBase):
-
- def testAssertNext(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(0, sess.run(get_next))
-
- def testAssertNextInvalid(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted Whoops transformation at offset 0 but encountered "
- "Map transformation instead."):
- sess.run(get_next)
-
- def testAssertNextShort(self):
- dataset = dataset_ops.Dataset.from_tensors(0).apply(
- optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Asserted next 2 transformations but encountered only 1."):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
deleted file mode 100644
index b43efb5c7c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for HostState optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def map_functions():
- plus_one = lambda x: x + 1
-
- def random(_):
- return random_ops.random_uniform([],
- minval=1,
- maxval=10,
- dtype=dtypes.float32,
- seed=42)
-
- def random_with_assert(x):
- y = random(x)
- assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
- with ops.control_dependencies([assert_op]):
- return y
-
- twice_random = lambda x: (random(x) + random(x)) / 2.
-
- tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
- ("RandomWithAssert", random_with_assert, True),
- ("TwiceRandom", twice_random, False)]
- return tuple(tests)
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testHoisting(self, function, will_optimize):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(
- ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
-
- dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
- self._testDataset(dataset)
-
- def testAdditionalInputs(self):
- a = constant_op.constant(1, dtype=dtypes.float32)
- b = constant_op.constant(0, dtype=dtypes.float32)
- some_tensor = math_ops.mul(a, b)
-
- def random_with_capture(_):
- return some_tensor + random_ops.random_uniform(
- [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
-
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(
- ["Zip[0]", "Map"])).map(random_with_capture).apply(
- optimization.optimize(["hoist_random_uniform"]))
- self._testDataset(dataset)
-
- def _testDataset(self, dataset):
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- previous_result = 0
- with self.cached_session() as sess:
- for _ in range(5):
- result = sess.run(get_next)
- self.assertLessEqual(1, result)
- self.assertLessEqual(result, 10)
- # This checks if the result is somehow random by checking if we are not
- # generating the same values.
- self.assertNotEqual(previous_result, result)
- previous_result = result
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
deleted file mode 100644
index e4f18222fd..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the LatencyAllEdges optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
-
- def testLatencyStatsOptimization(self):
-
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.from_tensors(1).apply(
- optimization.assert_next(
- ["LatencyStats", "Map", "LatencyStats", "Prefetch",
- "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- stats_ops.set_stats_aggregator(stats_aggregator)).apply(
- optimization.optimize(["latency_all_edges"]))
- iterator = dataset.make_initializable_iterator()
- get_next = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertEqual(1 * 1, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str,
- "record_latency_TensorDataset/_1", 1)
- self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
- 1)
- self._assertSummaryHasCount(summary_str,
- "record_latency_PrefetchDataset/_6", 1)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
deleted file mode 100644
index e9e3fc81e5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ /dev/null
@@ -1,225 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapAndFilterFusion optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def map_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- functions = [identity, increment, increment_and_square]
- tests = []
- for i, fun1 in enumerate(functions):
- for j, fun2 in enumerate(functions):
- tests.append((
- "Test{}{}".format(i, j),
- [fun1, fun2],
- ))
- for k, fun3 in enumerate(functions):
- tests.append((
- "Test{}{}{}".format(i, j, k),
- [fun1, fun2, fun3],
- ))
-
- swap = lambda x, n: (n, x)
- tests.append((
- "Swap1",
- [lambda x: (x, 42), swap],
- ))
- tests.append((
- "Swap2",
- [lambda x: (x, 42), swap, swap],
- ))
- return tuple(tests)
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testMapFusion(self, functions):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(["Map", "Prefetch"]))
- for function in functions:
- dataset = dataset.map(function)
-
- dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
- r = x
- for function in functions:
- if isinstance(r, tuple):
- r = function(*r) # Pass tuple as multiple arguments.
- else:
- r = function(r)
- self.assertAllEqual(r, result)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @staticmethod
- def map_and_filter_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
- minus_five = lambda x: x - 5
-
- def increment_and_square(x):
- y = x + 1
- return y * y
-
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
- is_odd = lambda x: math_ops.equal(x % 2, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
-
- functions = [identity, increment, minus_five, increment_and_square]
- filters = [take_all, is_zero, is_odd, greater]
- tests = []
-
- for x, fun in enumerate(functions):
- for y, predicate in enumerate(filters):
- tests.append(("Mixed{}{}".format(x, y), fun, predicate))
-
- # Multi output
- tests.append(("Multi1", lambda x: (x, x),
- lambda x, y: constant_op.constant(True)))
- tests.append(
- ("Multi2", lambda x: (x, 2),
- lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
- return tuple(tests)
-
- @parameterized.named_parameters(*map_and_filter_functions.__func__())
- def testMapFilterFusion(self, function, predicate):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map",
- "FilterByLastComponent"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
- self._testMapAndFilter(dataset, function, predicate)
-
- def _testMapAndFilter(self, dataset, function, predicate):
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- for x in range(10):
- r = function(x)
- if isinstance(r, tuple):
- b = predicate(*r) # Pass tuple as multiple arguments.
- else:
- b = predicate(r)
- if sess.run(b):
- result = sess.run(get_next)
- self.assertAllEqual(r, result)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testAdditionalInputs(self):
- a = constant_op.constant(3, dtype=dtypes.int64)
- b = constant_op.constant(4, dtype=dtypes.int64)
- some_tensor = math_ops.mul(a, b)
- function = lambda x: x * x
-
- def predicate(y):
- return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)
-
- # We are currently not supporting functions with additional inputs.
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Filter"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
- self._testMapAndFilter(dataset, function, predicate)
-
- @staticmethod
- def filter_functions():
- take_all = lambda x: constant_op.constant(True)
- is_zero = lambda x: math_ops.equal(x, 0)
- greater = lambda x: math_ops.greater(x + 5, 0)
-
- tests = []
- filters = [take_all, is_zero, greater]
- identity = lambda x: x
- for x, predicate_1 in enumerate(filters):
- for y, predicate_2 in enumerate(filters):
- tests.append(("Mixed{}{}".format(x, y), identity,
- [predicate_1, predicate_2]))
- for z, predicate_3 in enumerate(filters):
- tests.append(("Mixed{}{}{}".format(x, y, z), identity,
- [predicate_1, predicate_2, predicate_3]))
-
- take_all_multiple = lambda x, y: constant_op.constant(True)
- # Multi output
- tests.append(("Multi1", lambda x: (x, x),
- [take_all_multiple, take_all_multiple]))
- tests.append(("Multi2", lambda x: (x, 2), [
- take_all_multiple,
- lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
- ]))
- return tuple(tests)
-
- @parameterized.named_parameters(*filter_functions.__func__())
- def testFilterFusion(self, map_function, predicates):
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(["Map", "Filter",
- "Prefetch"])).map(map_function)
- for predicate in predicates:
- dataset = dataset.filter(predicate)
-
- dataset = dataset.prefetch(0).apply(
- optimization.optimize(["filter_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- for x in range(5):
- r = map_function(x)
- filtered = False
- for predicate in predicates:
- if isinstance(r, tuple):
- b = predicate(*r) # Pass tuple as multiple arguments.
- else:
- b = predicate(r)
- if not sess.run(b):
- filtered = True
- break
-
- if not filtered:
- result = sess.run(get_next)
- self.assertAllEqual(r, result)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
deleted file mode 100644
index f7907eb890..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapParallelization optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @staticmethod
- def map_functions():
- identity = lambda x: x
- increment = lambda x: x + 1
-
- def assert_greater(x):
- assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
- with ops.control_dependencies([assert_op]):
- return x
-
- def random(_):
- return random_ops.random_uniform([],
- minval=0,
- maxval=10,
- dtype=dtypes.int64,
- seed=42)
-
- def assert_with_random(x):
- x = assert_greater(x)
- return random(x)
-
- return (("Identity", identity, True), ("Increment", increment, True),
- ("AssertGreater", assert_greater, True), ("Random", random, False),
- ("AssertWithRandom", assert_with_random, False))
-
- @parameterized.named_parameters(*map_functions.__func__())
- def testMapParallelization(self, function, should_optimize):
- next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
- dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(next_nodes)).map(function).apply(
- optimization.optimize(["map_parallelization"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
- # No need to run the pipeline if it was not optimized. Also the results
- # might be hard to check because of random.
- if not should_optimize:
- return
- r = function(x)
- self.assertAllEqual(r, result)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
deleted file mode 100644
index a5ea85f454..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapVectorization optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.client import session
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def _get_test_datasets(self,
- base_dataset,
- map_fn,
- num_parallel_calls=None,
- expect_optimized=True):
- """Given base dataset and map fn, creates test datasets.
-
- Returns a tuple of (unoptimized, dataset, optimized dataset). The
- unoptimized dataset has the assertion that Batch follows Map. The optimized
- dataset has the assertion that Map follows Batch, and has the
- "map_vectorization" optimization applied.
-
- Args:
- base_dataset: Input dataset to map->batch
- map_fn: Map function to use
- num_parallel_calls: (Optional.) num_parallel_calls argument for map
- expect_optimized: (Optional.) Whether we expect the optimization to take
- place, in which case we will assert that Batch is followed by Map,
- otherwise Map followed by Batch. Defaults to True.
-
- Returns:
- Tuple of (unoptimized dataset, optimized dataset).
- """
- map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
- batch_size = 100
-
- def _make_dataset(node_names):
- return base_dataset.apply(optimization.assert_next(node_names)).map(
- map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
-
- unoptimized = _make_dataset([map_node_name, "Batch"])
- optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
- [map_node_name, "Batch"]).apply(
- optimization.optimize(["map_vectorization"]))
-
- return unoptimized, optimized
-
- @parameterized.named_parameters(
- ("Basic", lambda x: (x, x + 1), None),
- ("Parallel", lambda x: (x, x + 1), 12),
- ("Gather", lambda x: array_ops.gather(x, 0), 12),
- )
- def testOptimization(self, map_fn, num_parallel_calls):
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
- num_parallel_calls)
- self.assertDatasetsEqual(unoptimized, optimized)
-
- def testOptimizationBadMapFn(self):
- # Test map functions that give an error
- def map_fn(x):
- # x has leading dimension 5, this will raise an error
- return array_ops.gather(x, 10)
-
- base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
- 5, drop_remainder=True)
- _, optimized = self._get_test_datasets(base_dataset, map_fn)
- nxt = optimized.make_one_shot_iterator().get_next()
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- r"indices = 10 is not in \[0, 5\)"):
- self.evaluate(nxt)
-
- def testOptimizationWithCapturedInputs(self):
- # Tests that vectorization works with captured inputs
- def map_fn(x):
- return x + y
-
- y = constant_op.constant(1, shape=(2,))
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- # TODO(rachelim): when this optimization works, turn on expect_optimized
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsEqual(optimized, unoptimized)
-
- def testOptimizationIgnoreStateful(self):
-
- def map_fn(x):
- with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
- return array_ops.identity(x)
-
- base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
- [3, 4]]).repeat(5)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsRaiseSameError(
- unoptimized, optimized, errors.InvalidArgumentError,
- [("OneShotIterator", "OneShotIterator_1", 1),
- ("IteratorGetNext", "IteratorGetNext_1", 1)])
-
- def testOptimizationIgnoreRagged(self):
- # Make sure we ignore inputs that might not be uniformly sized
- def map_fn(x):
- return array_ops.gather(x, 0)
-
- # output_shape = (?,)
- base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsEqual(unoptimized, optimized)
-
- def testOptimizationIgnoreRaggedMap(self):
- # Don't optimize when the output of the map fn shapes are unknown.
- def map_fn(x):
- return array_ops.tile(x, x)
-
- base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
- unoptimized, optimized = self._get_test_datasets(
- base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsRaiseSameError(
- unoptimized, optimized, errors.InvalidArgumentError,
- [("OneShotIterator", "OneShotIterator_1", 1),
- ("IteratorGetNext", "IteratorGetNext_1", 1)])
-
-
-class MapVectorizationBenchmark(test.Benchmark):
- # TODO(rachelim): Add a benchmark for more expensive transformations, such as
- # vgg_preprocessing.
-
- def _run(self, x, num_iters=100, name=None):
- deltas = []
- with session.Session() as sess:
- for _ in range(5):
- # Warm up session...
- sess.run(x)
- for _ in range(num_iters):
- start = time.time()
- sess.run(x)
- end = time.time()
- deltas.append(end - start)
- median_time = np.median(deltas)
- self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
- return median_time
-
- def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
- num_elems = np.prod(input_size)
- name_template = "{}__batch_size_{}_input_size_{}_{}"
- unoptimized = input_dataset.map(map_fn).batch(batch_size)
- unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
-
- optimized = unoptimized.apply(optimization.optimize(["map_vectorization"]))
- optimized_op = optimized.make_one_shot_iterator().get_next()
-
- unoptimized_time = self._run(
- unoptimized_op,
- name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
- optimized_time = self._run(
- optimized_op,
- name=name_template.format(str_id, batch_size, num_elems, "optimized"))
-
- print("Batch size: {}\n"
- "Input size: {}\n"
- "Transformation: {}\n"
- "Speedup: {}\n".format(batch_size, input_size, str_id,
- (unoptimized_time / optimized_time)))
-
- # Known cheap functions
- def benchmarkIdentity(self):
- self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
- "identity")
-
- def benchmarkAddConst(self):
- self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
-
- def benchmarkSelect(self):
- self._benchmark_helper(lambda *args: args[0], "select")
-
- def benchmarkCast(self):
- self._benchmark_helper(
- lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
-
- def _benchmark_helper(self, map_fn, str_id):
- input_sizes = [(10, 10, 3), (10, 100, 300)]
- batch_size = 1000
- for input_size in input_sizes:
- input_dataset = dataset_ops.Dataset.from_tensor_slices(
- (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
- self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
deleted file mode 100644
index 33c250ab2a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class ModelDatasetTest(test_base.DatasetTestBase):
-
- def testModelMap(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(math_ops.matmul)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(100):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelParallelMap(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(
- math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(1000):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelMapAndBatch(self):
- batch_size = 16
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.apply(
- batching.map_and_batch(
- math_ops.matmul,
- num_parallel_calls=optimization.AUTOTUNE,
- batch_size=batch_size))
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(10):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelParallelInterleave(self):
- k = 1024 * 1024
- dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
- np.random.rand(4 * k,
- 1))).repeat()
- dataset = dataset.map(math_ops.matmul)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset,
- cycle_length=10,
- num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next.op)
- for _ in range(1000):
- start = time.time()
- sess.run(get_next.op)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
- def testModelNested(self):
- k = 1024 * 1024
- a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
- b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
- c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
- dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
-
- def f1(a, b, c):
- x, y = a
- return math_ops.matmul(x, y), b, c
-
- def f2(a, b, c):
- x, y = b
- return a, math_ops.matmul(x, y), c
-
- def f3(a, b, c):
- x, y = c
- return a, b, math_ops.matmul(x, y)
-
- dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=2)
-
- dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
- dataset = dataset_ops.Dataset.range(1).repeat().interleave(
- lambda _: dataset, cycle_length=2)
-
- dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
- get_next = iterator.get_next()
-
- deltas = []
- with self.cached_session() as sess:
- for _ in range(5):
- sess.run(get_next)
- for _ in range(100):
- start = time.time()
- sess.run(get_next)
- end = time.time()
- deltas.append(end - start)
-
- print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
- (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
- np.max(deltas)))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
deleted file mode 100644
index b9e60cfa4e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapParallelization optimization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class NoopEliminationTest(test_base.DatasetTestBase):
-
- def testNoopElimination(self):
- a = constant_op.constant(1, dtype=dtypes.int64)
- b = constant_op.constant(2, dtype=dtypes.int64)
- some_tensor = math_ops.mul(a, b)
-
- dataset = dataset_ops.Dataset.range(5)
- dataset = dataset.apply(
- optimization.assert_next(
- ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
- dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
- 0).repeat(1).prefetch(0)
- dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
-
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.test_session() as sess:
- for x in range(5):
- result = sess.run(get_next)
- self.assertAllEqual(result, x)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
deleted file mode 100644
index 04f499f8c5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.platform import test
-
-
-class OptimizeDatasetTest(test_base.DatasetTestBase):
-
- def testOptimizationDefault(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize())
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationEmpty(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize([]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationFusion(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationStatefulFunction(self):
- dataset = dataset_ops.Dataset.range(10).map(
- lambda _: random_ops.random_uniform([])).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(get_next)
-
- def testOptimizationLargeInputFromTensor(self):
- input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
- dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
- optimization.optimize())
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
- sess.run(get_next)
-
- def testOptimizationLargeInputFromTensorSlices(self):
- input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
- dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
- optimization.optimize())
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
deleted file mode 100644
index 66ccaceea5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ /dev/null
@@ -1,851 +0,0 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for tensorflow.ops.parsing_ops."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import copy
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors_impl
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
-from tensorflow.python.platform import tf_logging
-
-# Helpers for creating Example objects
-example = example_pb2.Example
-feature = feature_pb2.Feature
-features = lambda d: feature_pb2.Features(feature=d)
-bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
-int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
-float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
-# Helpers for creating SequenceExample objects
-feature_list = lambda l: feature_pb2.FeatureList(feature=l)
-feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
-sequence_example = example_pb2.SequenceExample
-
-
-def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
- flat_output):
- tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
-
- i = 0 # Index into the flattened output of session.run()
- for k, v in sorted(dict_tensors.items()):
- # TODO(shivaniagrawal): flat_output is same as v.
- expected_v = expected_tensors[k]
- tf_logging.info("Comparing key: %s", k)
- print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
- if sparse_tensor.is_sparse(v):
- # Three outputs for SparseTensor : indices, values, shape.
- tester.assertEqual([k, len(expected_v)], [k, 3])
- print("i", i, "flat_output", flat_output[i].indices, "expected_v",
- expected_v[0])
- tester.assertAllEqual(expected_v[0], flat_output[i].indices)
- tester.assertAllEqual(expected_v[1], flat_output[i].values)
- tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
- else:
- # One output for standard Tensor.
- tester.assertAllEqual(expected_v, flat_output[i])
- i += 1
-
-
-class ParseExampleTest(test_base.DatasetTestBase):
-
- def _test(self,
- input_tensor,
- feature_val,
- expected_values=None,
- expected_err=None):
-
- with self.cached_session() as sess:
- if expected_err:
- with self.assertRaisesWithPredicateMatch(expected_err[0],
- expected_err[1]):
- dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
- contrib_parsing_ops.parse_example_dataset(feature_val))
- get_next = dataset.make_one_shot_iterator().get_next()
- sess.run(get_next)
- return
- else:
- # Returns dict w/ Tensors and SparseTensors.
- # Check values.
- dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
- contrib_parsing_ops.parse_example_dataset(feature_val))
- get_next = dataset.make_one_shot_iterator().get_next()
- result = sess.run(get_next)
- flattened = nest.flatten(result)
- print("result", result, "expected_values", expected_values)
- _compare_output_to_expected(self, result, expected_values, flattened)
-
- # Check shapes; if serialized is a Tensor we need its size to
- # properly check.
- batch_size = (
- input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
- np.asarray(input_tensor).size)
- for k, f in feature_val.items():
- print("output_shapes as list ",
- tuple(dataset.output_shapes[k].as_list()))
- if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
- self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
- elif isinstance(f, parsing_ops.VarLenFeature):
- self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
-
- def testEmptySerializedWithAllDefaults(self):
- sparse_name = "st_a"
- a_name = "a"
- b_name = "b"
- c_name = "c:has_a_tricky_name"
- a_default = [0, 42, 0]
- b_default = np.random.rand(3, 3).astype(bytes)
- c_default = np.random.rand(2).astype(np.float32)
-
- expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
-
- expected_output = {
- sparse_name: expected_st_a,
- a_name: np.array(2 * [[a_default]]),
- b_name: np.array(2 * [b_default]),
- c_name: np.array(2 * [c_default]),
- }
-
- self._test(
- ops.convert_to_tensor(["", ""]), {
- sparse_name:
- parsing_ops.VarLenFeature(dtypes.int64),
- a_name:
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=a_default),
- b_name:
- parsing_ops.FixedLenFeature(
- (3, 3), dtypes.string, default_value=b_default),
- c_name:
- parsing_ops.FixedLenFeature(
- (2,), dtypes.float32, default_value=c_default),
- },
- expected_values=expected_output)
-
- def testEmptySerializedWithoutDefaultsShouldFail(self):
- input_features = {
- "st_a":
- parsing_ops.VarLenFeature(dtypes.int64),
- "a":
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=[0, 42, 0]),
- "b":
- parsing_ops.FixedLenFeature(
- (3, 3),
- dtypes.string,
- default_value=np.random.rand(3, 3).astype(bytes)),
- # Feature "c" is missing a default, this gap will cause failure.
- "c":
- parsing_ops.FixedLenFeature(
- (2,), dtype=dtypes.float32),
- }
-
- # Edge case where the key is there but the feature value is empty
- original = example(features=features({"c": feature()}))
- self._test(
- [original.SerializeToString()],
- input_features,
- expected_err=(errors_impl.InvalidArgumentError,
- "Feature: c \\(data type: float\\) is required"))
-
- # Standard case of missing key and value.
- self._test(
- ["", ""],
- input_features,
- expected_err=(errors_impl.InvalidArgumentError,
- "Feature: c \\(data type: float\\) is required"))
-
- def testDenseNotMatchingShapeShouldFail(self):
- original = [
- example(features=features({
- "a": float_feature([1, 1, 3]),
- })), example(features=features({
- "a": float_feature([-1, -1]),
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- self._test(
- ops.convert_to_tensor(serialized),
- {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
- expected_err=(errors_impl.InvalidArgumentError,
- "Key: a, Index: 1. Number of float values"))
-
- def testDenseDefaultNoShapeShouldFail(self):
- original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
-
- serialized = [m.SerializeToString() for m in original]
-
- self._test(
- ops.convert_to_tensor(serialized),
- {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
- expected_err=(ValueError, "Missing shape for feature a"))
-
- def testSerializedContainingSparse(self):
- original = [
- example(features=features({
- "st_c": float_feature([3, 4])
- })),
- example(features=features({
- "st_c": float_feature([]), # empty float list
- })),
- example(features=features({
- "st_d": feature(), # feature with nothing in it
- })),
- example(features=features({
- "st_c": float_feature([1, 2, -1]),
- "st_d": bytes_feature([b"hi"])
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_st_c = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
- [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
- [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
-
- expected_st_d = ( # indices, values, shape
- np.array(
- [[3, 0]], dtype=np.int64), np.array(
- ["hi"], dtype=bytes), np.array(
- [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
-
- expected_output = {
- "st_c": expected_st_c,
- "st_d": expected_st_d,
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "st_c": parsing_ops.VarLenFeature(dtypes.float32),
- "st_d": parsing_ops.VarLenFeature(dtypes.string)
- },
- expected_values=expected_output)
-
- def testSerializedContainingSparseFeature(self):
- original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx":
- int64_feature([0, 9, 3]) # unsorted
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_sp = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
- np.array(
- [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
- [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
-
- expected_output = {"sp": expected_sp,}
-
- self._test(
- ops.convert_to_tensor(serialized),
- {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
- expected_values=expected_output)
-
- def testSerializedContainingSparseFeatureReuse(self):
- original = [
- example(features=features({
- "val1": float_feature([3, 4]),
- "val2": float_feature([5, 6]),
- "idx": int64_feature([5, 10])
- })),
- example(features=features({
- "val1": float_feature([]), # empty float list
- "idx": int64_feature([])
- })),
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_sp1 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [3.0, 4.0], dtype=np.float32), np.array(
- [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
-
- expected_sp2 = ( # indices, values, shape
- np.array(
- [[0, 5], [0, 10]], dtype=np.int64), np.array(
- [5.0, 6.0], dtype=np.float32), np.array(
- [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13
-
- expected_output = {
- "sp1": expected_sp1,
- "sp2": expected_sp2,
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "sp1":
- parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
- "sp2":
- parsing_ops.SparseFeature(
- "idx", "val2", dtypes.float32, size=7, already_sorted=True)
- },
- expected_values=expected_output)
-
- def testSerializedContaining3DSparseFeature(self):
- original = [
- example(features=features({
- "val": float_feature([3, 4]),
- "idx0": int64_feature([5, 10]),
- "idx1": int64_feature([0, 2]),
- })),
- example(features=features({
- "val": float_feature([]), # empty float list
- "idx0": int64_feature([]),
- "idx1": int64_feature([]),
- })),
- example(features=features({
- "val": feature(), # feature with nothing in it
- # missing idx feature
- })),
- example(features=features({
- "val": float_feature([1, 2, -1]),
- "idx0": int64_feature([0, 9, 3]), # unsorted
- "idx1": int64_feature([1, 0, 2]),
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_sp = (
- # indices
- np.array(
- [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
- dtype=np.int64),
- # values
- np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
- # shape batch == 4, max_elems = 13
- np.array([4, 13, 3], dtype=np.int64))
-
- expected_output = {"sp": expected_sp,}
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "sp":
- parsing_ops.SparseFeature(["idx0", "idx1"], "val",
- dtypes.float32, [13, 3])
- },
- expected_values=expected_output)
-
- def testSerializedContainingDense(self):
- aname = "a"
- bname = "b*has+a:tricky_name"
- original = [
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- })), example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b""]),
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
- bname:
- np.array(
- ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
- }
-
- # No defaults, values required
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
- },
- expected_values=expected_output)
-
- # This test is identical as the previous one except
- # for the creation of 'serialized'.
- def testSerializedContainingDenseWithConcat(self):
- aname = "a"
- bname = "b*has+a:tricky_name"
- # TODO(lew): Feature appearing twice should be an error in future.
- original = [
- (example(features=features({
- aname: float_feature([10, 10]),
- })), example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str"]),
- }))),
- (
- example(features=features({
- bname: bytes_feature([b"b100"]),
- })),
- example(features=features({
- aname: float_feature([-1, -1]),
- bname: bytes_feature([b"b1"]),
- })),),
- ]
-
- serialized = [
- m.SerializeToString() + n.SerializeToString() for (m, n) in original
- ]
-
- expected_output = {
- aname:
- np.array(
- [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
- bname:
- np.array(
- ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
- }
-
- # No defaults, values required
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
- },
- expected_values=expected_output)
-
- def testSerializedContainingDenseScalar(self):
- original = [
- example(features=features({
- "a": float_feature([1]),
- })), example(features=features({}))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- "a":
- np.array(
- [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "a":
- parsing_ops.FixedLenFeature(
- (1,), dtype=dtypes.float32, default_value=-1),
- },
- expected_values=expected_output)
-
- def testSerializedContainingDenseWithDefaults(self):
- original = [
- example(features=features({
- "a": float_feature([1, 1]),
- })),
- example(features=features({
- "b": bytes_feature([b"b1"]),
- })),
- example(features=features({
- "b": feature()
- })),
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- "a":
- np.array(
- [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
- 1),
- "b":
- np.array(
- ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
- 1),
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "a":
- parsing_ops.FixedLenFeature(
- (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
- "b":
- parsing_ops.FixedLenFeature(
- (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
- },
- expected_values=expected_output)
-
- def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
- expected_st_a = ( # indices, values, shape
- np.empty(
- (0, 2), dtype=np.int64), # indices
- np.empty(
- (0,), dtype=np.int64), # sp_a is DT_INT64
- np.array(
- [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
- expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "c"], dtype="|S"), np.array(
- [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
-
- original = [
- example(features=features({
- "c": float_feature([3, 4]),
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "c": float_feature([1, 2]),
- "val": bytes_feature([b"c"]),
- "idx": int64_feature([7])
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- a_default = [1, 2, 3]
- b_default = np.random.rand(3, 3).astype(bytes)
- expected_output = {
- "st_a": expected_st_a,
- "sp": expected_sp,
- "a": np.array(2 * [[a_default]]),
- "b": np.array(2 * [b_default]),
- "c": np.array(
- [[3, 4], [1, 2]], dtype=np.float32),
- }
-
- self._test(
- ops.convert_to_tensor(serialized),
- {
- "st_a":
- parsing_ops.VarLenFeature(dtypes.int64),
- "sp":
- parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
- "a":
- parsing_ops.FixedLenFeature(
- (1, 3), dtypes.int64, default_value=a_default),
- "b":
- parsing_ops.FixedLenFeature(
- (3, 3), dtypes.string, default_value=b_default),
- # Feature "c" must be provided, since it has no default_value.
- "c":
- parsing_ops.FixedLenFeature((2,), dtypes.float32),
- },
- expected_values=expected_output)
-
- def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
- expected_idx = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
- np.array([0, 3, 7, 1]), np.array(
- [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2
-
- expected_sp = ( # indices, values, shape
- np.array(
- [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
- ["a", "b", "d", "c"], dtype="|S"), np.array(
- [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
-
- original = [
- example(features=features({
- "val": bytes_feature([b"a", b"b"]),
- "idx": int64_feature([0, 3])
- })), example(features=features({
- "val": bytes_feature([b"c", b"d"]),
- "idx": int64_feature([7, 1])
- }))
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- "idx": expected_idx,
- "sp": expected_sp,
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- "idx":
- parsing_ops.VarLenFeature(dtypes.int64),
- "sp":
- parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
- },
- expected_values=expected_output)
-
- def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
- # During parsing, data read from the serialized proto is stored in buffers.
- # For small batch sizes, a buffer will contain one minibatch entry.
- # For larger batch sizes, a buffer may contain several minibatch
- # entries. This test identified a bug where the code that copied
- # data out of the buffers and into the output tensors assumed each
- # buffer only contained one minibatch entry. The bug has since been fixed.
- truth_int = [i for i in range(batch_size)]
- truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
- for i in range(batch_size)]
-
- expected_str = copy.deepcopy(truth_str)
-
- # Delete some intermediate entries
- for i in range(batch_size):
- col = 1
- if np.random.rand() < 0.25:
- # w.p. 25%, drop out the second entry
- expected_str[i][col] = b"default"
- col -= 1
- truth_str[i].pop()
- if np.random.rand() < 0.25:
- # w.p. 25%, drop out the second entry (possibly again)
- expected_str[i][col] = b"default"
- truth_str[i].pop()
-
- expected_output = {
- # Batch size batch_size, 1 time step.
- "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
- # Batch size batch_size, 2 time steps.
- "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
- }
-
- original = [
- example(features=features(
- {"a": int64_feature([truth_int[i]]),
- "b": bytes_feature(truth_str[i])}))
- for i in range(batch_size)
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- self._test(
- ops.convert_to_tensor(serialized, dtype=dtypes.string), {
- "a":
- parsing_ops.FixedLenSequenceFeature(
- shape=(),
- dtype=dtypes.int64,
- allow_missing=True,
- default_value=-1),
- "b":
- parsing_ops.FixedLenSequenceFeature(
- shape=[],
- dtype=dtypes.string,
- allow_missing=True,
- default_value="default"),
- },
- expected_values=expected_output)
-
- def testSerializedContainingVarLenDenseLargerBatch(self):
- np.random.seed(3456)
- for batch_size in (1, 10, 20, 100, 256):
- self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
-
- def testSerializedContainingVarLenDense(self):
- aname = "a"
- bname = "b"
- cname = "c"
- dname = "d"
- original = [
- example(features=features({
- cname: int64_feature([2]),
- })),
- example(features=features({
- aname: float_feature([1, 1]),
- bname: bytes_feature([b"b0_str", b"b1_str"]),
- })),
- example(features=features({
- aname: float_feature([-1, -1, 2, 2]),
- bname: bytes_feature([b"b1"]),
- })),
- example(features=features({
- aname: float_feature([]),
- cname: int64_feature([3]),
- })),
- ]
-
- serialized = [m.SerializeToString() for m in original]
-
- expected_output = {
- aname:
- np.array(
- [
- [0, 0, 0, 0],
- [1, 1, 0, 0],
- [-1, -1, 2, 2],
- [0, 0, 0, 0],
- ],
- dtype=np.float32).reshape(4, 2, 2, 1),
- bname:
- np.array(
- [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
- dtype=bytes).reshape(4, 2, 1, 1, 1),
- cname:
- np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
- dname:
- np.empty(shape=(4, 0), dtype=bytes),
- }
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=True),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- },
- expected_values=expected_output)
-
- # Test with padding values.
- expected_output_custom_padding = dict(expected_output)
- expected_output_custom_padding[aname] = np.array(
- [
- [-2, -2, -2, -2],
- [1, 1, -2, -2],
- [-1, -1, 2, 2],
- [-2, -2, -2, -2],
- ],
- dtype=np.float32).reshape(4, 2, 2, 1)
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1),
- dtype=dtypes.float32,
- allow_missing=True,
- default_value=-2.0),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=True),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- }, expected_output_custom_padding)
-
- # Change number of required values so the inputs are not a
- # multiple of this size.
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1, 1), dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(
- errors_impl.OpError, "Key: b, Index: 2. "
- "Number of bytes values is not a multiple of stride length."))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1),
- dtype=dtypes.float32,
- allow_missing=True,
- default_value=[]),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1, 1), dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(ValueError,
- "Cannot reshape a tensor with 0 elements to shape"))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1, 1), dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(ValueError,
- "First dimension of shape for feature a unknown. "
- "Consider using FixedLenSequenceFeature."))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- cname:
- parsing_ops.FixedLenFeature(
- (1, None), dtype=dtypes.int64, default_value=[[1]]),
- },
- expected_err=(ValueError,
- "All dimensions of shape for feature c need to be known "
- r"but received \(1, None\)."))
-
- self._test(
- ops.convert_to_tensor(serialized), {
- aname:
- parsing_ops.FixedLenSequenceFeature(
- (2, 1), dtype=dtypes.float32, allow_missing=True),
- bname:
- parsing_ops.FixedLenSequenceFeature(
- (1, 1, 1), dtype=dtypes.string, allow_missing=True),
- cname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.int64, allow_missing=False),
- dname:
- parsing_ops.FixedLenSequenceFeature(
- shape=[], dtype=dtypes.string, allow_missing=True),
- },
- expected_err=(ValueError,
- "Unsupported: FixedLenSequenceFeature requires "
- "allow_missing to be True."))
-
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
deleted file mode 100644
index 7a6a7a709a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ /dev/null
@@ -1,948 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import threading
-
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.compat import compat
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.platform import test
-
-
-class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
-
- def setUp(self):
- self._event = threading.Event()
-
- def _create_ds_and_iterator(self, device0, initializable=False):
-
- def gen():
- for i in range(1, 10):
- yield [float(i)]
- if i == 6:
- self._event.set()
-
- with ops.device(device0):
- ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
- if initializable:
- ds_iterator = ds.make_initializable_iterator()
- else:
- ds_iterator = ds.make_one_shot_iterator()
- return (ds, ds_iterator)
-
- def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.float32],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name=buffer_name)
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.float32])
- reset_op = prefetching_ops.function_buffering_resource_reset(
- function_buffer_resource=buffer_resource_handle)
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- return (prefetch_op, reset_op, destroy_op)
-
- def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
- prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
- device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testSameDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("same_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:0")
-
- def testDifferentDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("diff_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:1")
-
- def testDifferentDeviceCPUGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- self._prefetch_fn_helper_one_shot("cpu_gpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/gpu:0")
-
- def testReinitialization(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- # Lets reset the function buffering resource and reinitialize the
- # iterator. Should be able to go through this again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
- # Try fetching after its over twice to test out end of sequence.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- # Now reset everything and try it out again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
- # Try fetching after its over twice to test out end of sequence.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
- def testStringsGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/gpu:0"
-
- ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
- ds_iterator = ds.make_one_shot_iterator()
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.string],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name="strings")
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.string])
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- with self.cached_session() as sess:
- self.assertEqual([b"a"], sess.run(prefetch_op))
- self.assertEqual([b"b"], sess.run(prefetch_op))
- self.assertEqual([b"c"], sess.run(prefetch_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
-
-class PrefetchToDeviceTest(test_base.DatasetTestBase):
-
- def testPrefetchToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToSameDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device(
- "/job:localhost/replica:0/task:0/device:CPU:0"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchDictToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchSparseTensorsToDevice(self):
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceWithReInit(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_initializable_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpuWithReInit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
-
-class CopyToDeviceTest(test_base.DatasetTestBase):
-
- def testCopyToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceInt32(self):
- host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int32, next_element.dtype)
- self.assertEqual((4,), next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToSameDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:0"))
-
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceWithPrefetch(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyDictToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyDictToDeviceWithPrefetch(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopySparseTensorsToDevice(self):
-
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
-
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopySparseTensorsToDeviceWithPrefetch(self):
-
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
-
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuWithPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuInt32(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuInt32AndPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuStrings(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuStringsAndPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDevicePingPongCPUGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- with compat.forward_compatibility_horizon(2018, 8, 4):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
- back_to_cpu_dataset = device_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
-
- with ops.device("/cpu:0"):
- iterator = back_to_cpu_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceWithReInit(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1"))
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceWithReInitAndPrefetch(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
-
- with ops.device("/cpu:1"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuWithReInit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testCopyToDeviceGpuWithReInitAndPrefetch(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
-
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testIteratorGetNextAsOptionalOnGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(3)
- device_dataset = host_dataset.apply(
- prefetching_ops.copy_to_device("/gpu:0"))
- with ops.device("/gpu:0"):
- iterator = device_dataset.make_initializable_iterator()
- next_elem = iterator_ops.get_next_as_optional(iterator)
- elem_has_value_t = next_elem.has_value()
- elem_value_t = next_elem.get_value()
-
- with self.cached_session() as sess:
- # Before initializing the iterator, evaluating the optional fails with
- # a FailedPreconditionError.
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(elem_has_value_t)
- with self.assertRaises(errors.FailedPreconditionError):
- sess.run(elem_value_t)
-
- # For each element of the dataset, assert that the optional evaluates to
- # the expected value.
- sess.run(iterator.initializer)
- for i in range(3):
- elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
- self.assertTrue(elem_has_value)
- self.assertEqual(i, elem_value)
-
- # After exhausting the iterator, `next_elem.has_value()` will evaluate to
- # false, and attempting to get the value will fail.
- for _ in range(2):
- self.assertFalse(sess.run(elem_has_value_t))
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(elem_value_t)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
deleted file mode 100644
index 2e901587f4..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test RangeDataset."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import counter
-from tensorflow.contrib.data.python.ops import enumerate_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.platform import test
-
-
-class RangeDatasetTest(test_base.DatasetTestBase):
-
- def testEnumerateDataset(self):
- components = (["a", "b"], [1, 2], [37.0, 38])
- start = constant_op.constant(20, dtype=dtypes.int64)
-
- iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply(
- enumerate_ops.enumerate_dataset(start)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual(dtypes.int64, get_next[0].dtype)
- self.assertEqual((), get_next[0].shape)
- self.assertEqual([tensor_shape.TensorShape([])] * 3,
- [t.shape for t in get_next[1]])
-
- with self.cached_session() as sess:
- sess.run(init_op)
- self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
- self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testCounter(self):
- """Test dataset construction using `count`."""
- iterator = (counter.Counter(start=3, step=4)
- .make_one_shot_iterator())
- get_next = iterator.get_next()
- self.assertEqual([], get_next.shape.as_list())
- self.assertEqual(dtypes.int64, get_next.dtype)
-
- negative_iterator = (counter.Counter(start=0, step=-1)
- .make_one_shot_iterator())
- negative_get_next = negative_iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(3, sess.run(get_next))
- self.assertEqual(3 + 4, sess.run(get_next))
- self.assertEqual(3 + 2 * 4, sess.run(get_next))
-
- self.assertEqual(0, sess.run(negative_get_next))
- self.assertEqual(-1, sess.run(negative_get_next))
- self.assertEqual(-2, sess.run(negative_get_next))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
deleted file mode 100644
index 66ed547b6d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ /dev/null
@@ -1,1083 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class ReadBatchFeaturesTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 0.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 0,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 1.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[1],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 1,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, num_epochs=num_epochs)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testReadWithEquivalentDataset(self):
- features = {
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- }
- dataset = (
- core_readers.TFRecordDataset(self.test_filenames)
- .map(lambda x: parsing_ops.parse_single_example(x, features))
- .repeat(10).batch(2))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
- range(self._num_files), 2, 10):
- actual_batch = sess.run(next_element)
- self.assertAllEqual(file_batch, actual_batch["file"])
- self.assertAllEqual(record_batch, actual_batch["record"])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testReadWithFusedShuffleRepeatDataset(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- for batch_size in [1, 2]:
- # Test that shuffling with same seed produces the same result.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- self.assertAllEqual(batch1[i], batch2[i])
-
- # Test that shuffling with different seeds produces a different order.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=15).make_one_shot_iterator().get_next()
- all_equal = True
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
- self.assertFalse(all_equal)
-
- def testParallelReadersAndParsers(self):
- num_epochs = 5
- for batch_size in [1, 2]:
- for reader_num_threads in [2, 4]:
- for parser_num_threads in [2, 4]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default():
- # Basic test: read from file 0.
- outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- drop_final_batch=True).make_one_shot_iterator().get_next()
- for tensor in nest.flatten(outputs):
- if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
- self.assertEqual(tensor.shape[0], batch_size)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=None,
- batch_size=32)
- for shape, clazz in zip(nest.flatten(dataset.output_shapes),
- nest.flatten(dataset.output_classes)):
- if issubclass(clazz, ops.Tensor):
- self.assertEqual(32, shape[0])
-
-
-class MakeCsvDatasetTest(test_base.DatasetTestBase):
-
- def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
- return readers.make_csv_dataset(
- filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs)
-
- def _setup_files(self, inputs, linebreak="\n", compression_type=None):
- filenames = []
- for i, ip in enumerate(inputs):
- fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i)
- contents = linebreak.join(ip).encode("utf-8")
- if compression_type is None:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
- filenames.append(fn)
- return filenames
-
- def _next_expected_batch(self, expected_output, expected_keys, batch_size,
- num_epochs):
- features = {k: [] for k in expected_keys}
- for _ in range(num_epochs):
- for values in expected_output:
- for n, key in enumerate(expected_keys):
- features[key].append(values[n])
- if len(features[expected_keys[0]]) == batch_size:
- yield features
- features = {k: [] for k in expected_keys}
- if features[expected_keys[0]]: # Leftover from the last batch
- yield features
-
- def _verify_output(
- self,
- sess,
- dataset,
- batch_size,
- num_epochs,
- label_name,
- expected_output,
- expected_keys,
- ):
- nxt = dataset.make_one_shot_iterator().get_next()
-
- for expected_features in self._next_expected_batch(
- expected_output,
- expected_keys,
- batch_size,
- num_epochs,
- ):
- actual_features = sess.run(nxt)
-
- if label_name is not None:
- expected_labels = expected_features.pop(label_name)
- self.assertAllEqual(expected_labels, actual_features[1])
- actual_features = actual_features[0]
-
- for k in expected_features.keys():
- # Compare features
- self.assertAllEqual(expected_features[k], actual_features[k])
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(nxt)
-
- def _test_dataset(self,
- inputs,
- expected_output,
- expected_keys,
- batch_size=1,
- num_epochs=1,
- label_name=None,
- **kwargs):
- """Checks that elements produced by CsvDataset match expected output."""
- # Convert str type because py3 tf strings are bytestrings
- filenames = self._setup_files(
- inputs, compression_type=kwargs.get("compression_type", None))
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = self._make_csv_dataset(
- filenames,
- batch_size=batch_size,
- num_epochs=num_epochs,
- label_name=label_name,
- **kwargs)
- self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
- expected_output, expected_keys)
-
- def testMakeCSVDataset(self):
- """Tests making a CSV dataset with keys and defaults provided."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withBatchSizeAndEpochs(self):
- """Tests making a CSV dataset with keys and defaults provided."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=3,
- num_epochs=10,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withCompressionType(self):
- """Tests `compression_type` argument."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- for compression_type in ("GZIP", "ZLIB"):
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- compression_type=compression_type,
- )
-
- def testMakeCSVDataset_withBadInputs(self):
- """Tests that exception is raised when input is malformed.
- """
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- filenames = self._setup_files(inputs)
-
- # Duplicate column names
- with self.assertRaises(ValueError):
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=record_defaults,
- label_name="col0",
- column_names=column_names * 2)
-
- # Label key not one of column names
- with self.assertRaises(ValueError):
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=record_defaults,
- label_name="not_a_real_label",
- column_names=column_names)
-
- def testMakeCSVDataset_withNoLabel(self):
- """Tests making a CSV dataset with no label provided."""
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withNoHeader(self):
- """Tests that datasets can be created from CSV files with no header line.
- """
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=False,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withTypes(self):
- """Tests that defaults can be a dtype instead of a Tensor for required vals.
- """
- record_defaults = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
- dtypes.string
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"],
- [
- ",".join(x[0] for x in column_names), "10,11,12,13,14",
- "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withNoColNames(self):
- """Tests that datasets can be created when column names are not specified.
-
- In that case, we should infer the column names from the header lines.
- """
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
- [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- column_defaults=record_defaults,
- )
-
- def testMakeCSVDataset_withTypeInferenceMismatch(self):
- # Test that error is thrown when num fields doesn't match columns
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- filenames = self._setup_files(inputs)
- with self.assertRaises(ValueError):
- self._make_csv_dataset(
- filenames,
- column_names=column_names + ["extra_name"],
- column_defaults=None,
- batch_size=2,
- num_epochs=10)
-
- def testMakeCSVDataset_withTypeInference(self):
- """Tests that datasets can be created when no defaults are specified.
-
- In that case, we should infer the types from the first N records.
- """
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- "0,%s,2.0,3e50,rabbit" % str_int32_max
- ]]
- expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- )
-
- def testMakeCSVDataset_withTypeInferenceFallthrough(self):
- """Tests that datasets can be created when no defaults are specified.
-
- Tests on a deliberately tricky file.
- """
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- ",,,,",
- "0,0,0.0,0.0,0.0",
- "0,%s,2.0,3e50,rabbit" % str_int32_max,
- ",,,,",
- ]]
- expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"],
- [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]]
- label = "col0"
-
- self._test_dataset(
- inputs,
- expected_output=expected_output,
- expected_keys=column_names,
- column_names=column_names,
- label_name=label,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- )
-
- def testMakeCSVDataset_withSelectCols(self):
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- "0,%s,2.0,3e50,rabbit" % str_int32_max
- ]]
- expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
-
- select_cols = [1, 3, 4]
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- column_names=column_names,
- column_defaults=[record_defaults[i] for i in select_cols],
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=select_cols,
- )
-
- # Can still do inference without provided defaults
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- column_names=column_names,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=select_cols,
- )
-
- # Can still do column name inference
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=select_cols,
- )
-
- # Can specify column names instead of indices
- self._test_dataset(
- inputs,
- expected_output=[[x[i] for i in select_cols] for x in expected_output],
- expected_keys=[column_names[i] for i in select_cols],
- column_names=column_names,
- batch_size=1,
- num_epochs=1,
- shuffle=False,
- header=True,
- select_columns=[column_names[i] for i in select_cols],
- )
-
- def testMakeCSVDataset_withSelectColsError(self):
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
- column_names = ["col%d" % i for i in range(5)]
- str_int32_max = str(2**33)
- inputs = [[
- ",".join(x for x in column_names),
- "0,%s,2.0,3e50,rabbit" % str_int32_max
- ]]
-
- select_cols = [1, 3, 4]
- filenames = self._setup_files(inputs)
-
- with self.assertRaises(ValueError):
- # Mismatch in number of defaults and number of columns selected,
- # should raise an error
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=record_defaults,
- column_names=column_names,
- select_columns=select_cols)
-
- with self.assertRaises(ValueError):
- # Invalid column name should raise an error
- self._make_csv_dataset(
- filenames,
- batch_size=1,
- column_defaults=[[0]],
- column_names=column_names,
- label_name=None,
- select_columns=["invalid_col_name"])
-
- def testMakeCSVDataset_withShuffle(self):
- record_defaults = [
- constant_op.constant([], dtypes.int32),
- constant_op.constant([], dtypes.int64),
- constant_op.constant([], dtypes.float32),
- constant_op.constant([], dtypes.float64),
- constant_op.constant([], dtypes.string)
- ]
-
- def str_series(st):
- return ",".join(str(i) for i in range(st, st + 5))
-
- column_names = ["col%d" % i for i in range(5)]
- inputs = [
- [",".join(x for x in column_names)
- ] + [str_series(5 * i) for i in range(15)],
- [",".join(x for x in column_names)] +
- [str_series(5 * i) for i in range(15, 20)],
- ]
-
- filenames = self._setup_files(inputs)
-
- total_records = 20
- for batch_size in [1, 2]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Test that shuffling with the same seed produces the same result
- dataset1 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=5,
- num_epochs=2,
- )
- dataset2 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=5,
- num_epochs=2,
- )
- outputs1 = dataset1.make_one_shot_iterator().get_next()
- outputs2 = dataset2.make_one_shot_iterator().get_next()
- for _ in range(total_records // batch_size):
- batch1 = nest.flatten(sess.run(outputs1))
- batch2 = nest.flatten(sess.run(outputs2))
- for i in range(len(batch1)):
- self.assertAllEqual(batch1[i], batch2[i])
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Test that shuffling with a different seed produces different results
- dataset1 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=5,
- num_epochs=2,
- )
- dataset2 = self._make_csv_dataset(
- filenames,
- column_defaults=record_defaults,
- column_names=column_names,
- batch_size=batch_size,
- header=True,
- shuffle=True,
- shuffle_seed=6,
- num_epochs=2,
- )
- outputs1 = dataset1.make_one_shot_iterator().get_next()
- outputs2 = dataset2.make_one_shot_iterator().get_next()
- all_equal = False
- for _ in range(total_records // batch_size):
- batch1 = nest.flatten(sess.run(outputs1))
- batch2 = nest.flatten(sess.run(outputs2))
- for i in range(len(batch1)):
- all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
- self.assertFalse(all_equal)
-
- def testIndefiniteRepeatShapeInference(self):
- column_names = ["col%d" % i for i in range(5)]
- inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
- ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
- ]]
- filenames = self._setup_files(inputs)
- dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None)
- for shape in nest.flatten(dataset.output_shapes):
- self.assertEqual(32, shape[0])
-
-
-class MakeTFRecordDatasetTest(
- reader_dataset_ops_test_base.TFRecordDatasetTestBase):
-
- def _interleave(self, iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length,
- drop_final_batch,
- use_parser_fn):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i
-
- def _next_record_interleaved(file_indices, cycle_length):
- return self._interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- record_batch = []
- batch_index = 0
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for f, r in next_records:
- record = self._record(f, r)
- if use_parser_fn:
- record = record[1:]
- record_batch.append(record)
- batch_index += 1
- if len(record_batch) == batch_size:
- yield record_batch
- record_batch = []
- batch_index = 0
- if record_batch and not drop_final_batch:
- yield record_batch
-
- def _verify_records(self,
- sess,
- outputs,
- batch_size,
- file_index,
- num_epochs,
- interleave_cycle_length,
- drop_final_batch,
- use_parser_fn):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length,
- drop_final_batch, use_parser_fn):
- actual_batch = sess.run(outputs)
- self.assertAllEqual(expected_batch, actual_batch)
-
- def _read_test(self, batch_size, num_epochs, file_index=None,
- num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
- if file_index is None:
- file_pattern = self.test_filenames
- else:
- file_pattern = self.test_filenames[file_index]
-
- if parser_fn:
- fn = lambda x: string_ops.substr(x, 1, 999)
- else:
- fn = None
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs = readers.make_tf_record_dataset(
- file_pattern=file_pattern,
- num_epochs=num_epochs,
- batch_size=batch_size,
- parser_fn=fn,
- num_parallel_reads=num_parallel_reads,
- drop_final_batch=drop_final_batch,
- shuffle=False).make_one_shot_iterator().get_next()
- self._verify_records(
- sess, outputs, batch_size, file_index, num_epochs=num_epochs,
- interleave_cycle_length=num_parallel_reads,
- drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(outputs)
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- # Basic test: read from file 0.
- self._read_test(batch_size, num_epochs, 0)
-
- # Basic test: read from file 1.
- self._read_test(batch_size, num_epochs, 1)
-
- # Basic test: read from both files.
- self._read_test(batch_size, num_epochs)
-
- # Basic test: read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2, 10]:
- for num_epochs in [1, 3]:
- # Read from file 0.
- self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
-
- # Read from both files.
- self._read_test(batch_size, num_epochs, drop_final_batch=True)
-
- # Read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- drop_final_batch=True)
-
- def testParserFn(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for drop_final_batch in [False, True]:
- self._read_test(batch_size, num_epochs, parser_fn=True,
- drop_final_batch=drop_final_batch)
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- parser_fn=True, drop_final_batch=drop_final_batch)
-
- def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
- seed=None):
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- num_parallel_reads=num_parallel_reads,
- shuffle=True,
- shuffle_seed=seed)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- sess.run(iterator.initializer)
- first_batches = []
- try:
- while True:
- first_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- sess.run(iterator.initializer)
- second_batches = []
- try:
- while True:
- second_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- self.assertEqual(len(first_batches), len(second_batches))
- if seed is not None:
- # if you set a seed, should get the same results
- for i in range(len(first_batches)):
- self.assertAllEqual(first_batches[i], second_batches[i])
-
- expected = []
- for f in range(self._num_files):
- for r in range(self._num_records):
- expected.extend([self._record(f, r)] * num_epochs)
-
- for batches in (first_batches, second_batches):
- actual = []
- for b in batches:
- actual.extend(b)
- self.assertAllEqual(sorted(expected), sorted(actual))
-
- def testShuffle(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for num_parallel_reads in [1, 2]:
- # Test that all expected elements are produced
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
- # Test that elements are produced in a consistent order if
- # you specify a seed.
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
- seed=21345)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
- for shape in nest.flatten(dataset.output_shapes):
- self.assertEqual(32, shape[0])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
deleted file mode 100644
index f443b5501b..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ /dev/null
@@ -1,353 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing reader datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.core.example import example_pb2
-from tensorflow.core.example import feature_pb2
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.util import compat
-
-
-class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing FixedLengthRecordDataset."""
-
- def setUp(self):
- super(FixedLengthRecordDatasetTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self._header_bytes = 5
- self._record_bytes = 3
- self._footer_bytes = 2
-
- def _record(self, f, r):
- return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
- filenames.append(fn)
- with open(fn, "wb") as f:
- f.write(b"H" * self._header_bytes)
- for j in range(self._num_records):
- f.write(self._record(i, j))
- f.write(b"F" * self._footer_bytes)
- return filenames
-
-
-class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing `make_batched_feature_dataset`."""
-
- def setUp(self):
- super(ReadBatchFeaturesTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
- self.test_filenames = self._createFiles()
-
- def make_batch_feature(self,
- filenames,
- num_epochs,
- batch_size,
- label_key=None,
- reader_num_threads=1,
- parser_num_threads=1,
- shuffle=False,
- shuffle_seed=None,
- drop_final_batch=False):
- self.filenames = filenames
- self.num_epochs = num_epochs
- self.batch_size = batch_size
-
- return readers.make_batched_features_dataset(
- file_pattern=self.filenames,
- batch_size=self.batch_size,
- features={
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- "keywords": parsing_ops.VarLenFeature(dtypes.string),
- "label": parsing_ops.FixedLenFeature([], dtypes.string),
- },
- label_key=label_key,
- reader=core_readers.TFRecordDataset,
- num_epochs=self.num_epochs,
- shuffle=shuffle,
- shuffle_seed=shuffle_seed,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads,
- drop_final_batch=drop_final_batch)
-
- def _record(self, f, r, l):
- example = example_pb2.Example(
- features=feature_pb2.Features(
- feature={
- "file":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[f])),
- "record":
- feature_pb2.Feature(
- int64_list=feature_pb2.Int64List(value=[r])),
- "keywords":
- feature_pb2.Feature(
- bytes_list=feature_pb2.BytesList(
- value=self._get_keywords(f, r))),
- "label":
- feature_pb2.Feature(
- bytes_list=feature_pb2.BytesList(
- value=[compat.as_bytes(l)]))
- }))
- return example.SerializeToString()
-
- def _get_keywords(self, f, r):
- num_keywords = 1 + (f + r) % 2
- keywords = []
- for index in range(num_keywords):
- keywords.append(compat.as_bytes("keyword%d" % index))
- return keywords
-
- def _sum_keywords(self, num_files):
- sum_keywords = 0
- for i in range(num_files):
- for j in range(self._num_records):
- sum_keywords += 1 + (i + j) % 2
- return sum_keywords
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j, "fake-label"))
- writer.close()
- return filenames
-
- def _run_actual_batch(self, outputs, sess, label_key_provided=False):
- if label_key_provided:
- # outputs would be a tuple of (feature dict, label)
- label_op = outputs[1]
- features_op = outputs[0]
- else:
- features_op = outputs
- label_op = features_op["label"]
- file_op = features_op["file"]
- keywords_indices_op = features_op["keywords"].indices
- keywords_values_op = features_op["keywords"].values
- keywords_dense_shape_op = features_op["keywords"].dense_shape
- record_op = features_op["record"]
- return sess.run([
- file_op, keywords_indices_op, keywords_values_op,
- keywords_dense_shape_op, record_op, label_op
- ])
-
- def _next_actual_batch(self, sess, label_key_provided=False):
- return self._run_actual_batch(self.outputs, sess, label_key_provided)
-
- def _interleave(self, iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length=1):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i, compat.as_bytes("fake-label")
-
- def _next_record_interleaved(file_indices, cycle_length):
- return self._interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- label_batch = []
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for record in next_records:
- f = record[0]
- r = record[1]
- label_batch.append(record[2])
- file_batch.append(f)
- record_batch.append(r)
- keywords = self._get_keywords(f, r)
- keywords_batch_values.extend(keywords)
- keywords_batch_indices.extend(
- [[batch_index, i] for i in range(len(keywords))])
- batch_index += 1
- keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
- if len(file_batch) == batch_size:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [batch_size, keywords_batch_max_len], record_batch, label_batch
- ]
- file_batch = []
- keywords_batch_indices = []
- keywords_batch_values = []
- keywords_batch_max_len = 0
- record_batch = []
- batch_index = 0
- label_batch = []
- if file_batch:
- yield [
- file_batch, keywords_batch_indices, keywords_batch_values,
- [len(file_batch), keywords_batch_max_len], record_batch, label_batch
- ]
-
- def verify_records(self,
- sess,
- batch_size,
- file_index=None,
- num_epochs=1,
- label_key_provided=False,
- interleave_cycle_length=1):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices,
- batch_size,
- num_epochs,
- cycle_length=interleave_cycle_length):
- actual_batch = self._next_actual_batch(
- sess, label_key_provided=label_key_provided)
- for i in range(len(expected_batch)):
- self.assertAllEqual(expected_batch[i], actual_batch[i])
-
-
-class TextLineDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing TextLineDataset."""
-
- def _lineText(self, f, l):
- return compat.as_bytes("%d: %d" % (f, l))
-
- def _createFiles(self,
- num_files,
- num_lines,
- crlf=False,
- compression_type=None):
- filenames = []
- for i in range(num_files):
- fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
- filenames.append(fn)
- contents = []
- for j in range(num_lines):
- contents.append(self._lineText(i, j))
- # Always include a newline after the record unless it is
- # at the end of the file, in which case we include it
- if j + 1 != num_lines or i == 0:
- contents.append(b"\r\n" if crlf else b"\n")
- contents = b"".join(contents)
-
- if not compression_type:
- with open(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "GZIP":
- with gzip.GzipFile(fn, "wb") as f:
- f.write(contents)
- elif compression_type == "ZLIB":
- contents = zlib.compress(contents)
- with open(fn, "wb") as f:
- f.write(contents)
- else:
- raise ValueError("Unsupported compression_type", compression_type)
-
- return filenames
-
-
-class TFRecordDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing TFRecordDataset."""
-
- def setUp(self):
- super(TFRecordDatasetTestBase, self).setUp()
- self._num_files = 2
- self._num_records = 7
-
- self.test_filenames = self._createFiles()
-
- self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
- self.num_epochs = array_ops.placeholder_with_default(
- constant_op.constant(1, dtypes.int64), shape=[])
- self.compression_type = array_ops.placeholder_with_default("", shape=[])
- self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- repeat_dataset = core_readers.TFRecordDataset(
- self.filenames, self.compression_type).repeat(self.num_epochs)
- batch_dataset = repeat_dataset.batch(self.batch_size)
-
- iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
- self.init_op = iterator.make_initializer(repeat_dataset)
- self.init_batch_op = iterator.make_initializer(batch_dataset)
- self.get_next = iterator.get_next()
-
- def _record(self, f, r):
- return compat.as_bytes("Record %d of file %d" % (r, f))
-
- def _createFiles(self):
- filenames = []
- for i in range(self._num_files):
- fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
- filenames.append(fn)
- writer = python_io.TFRecordWriter(fn)
- for j in range(self._num_records):
- writer.write(self._record(i, j))
- writer.close()
- return filenames
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
index cc22ea1df7..e7281d5318 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
@@ -25,49 +25,11 @@ from tensorflow.contrib.data.python.ops import grouping
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("Zero", 0, 1),
- ("Five", 5, 1),
- ("Ten", 10, 1),
- ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
- ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
- "Dataset had more than one element."),
- )
- def testGetSingleElement(self, skip, take, error=None, error_msg=None):
- skip_t = array_ops.placeholder(dtypes.int64, shape=[])
- take_t = array_ops.placeholder(dtypes.int64, shape=[])
-
- def make_sparse(x):
- x_1d = array_ops.reshape(x, [1])
- x_2d = array_ops.reshape(x, [1, 1])
- return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
-
- dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
- lambda x: (x * x, make_sparse(x))).take(take_t)
- element = get_single_element.get_single_element(dataset)
-
- with self.cached_session() as sess:
- if error is None:
- dense_val, sparse_val = sess.run(
- element, feed_dict={
- skip_t: skip,
- take_t: take
- })
- self.assertEqual(skip * skip, dense_val)
- self.assertAllEqual([[skip]], sparse_val.indices)
- self.assertAllEqual([skip], sparse_val.values)
- self.assertAllEqual([skip], sparse_val.dense_shape)
- else:
- with self.assertRaisesRegexp(error, error_msg):
- sess.run(element, feed_dict={skip_t: skip, take_t: take})
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("SumZero", 0),
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py
deleted file mode 100644
index 32474bd411..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import time
-
-from absl.testing import parameterized
-import numpy as np
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-from tensorflow.contrib.data.python.ops import resampling
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-def _time_resampling(
- test_obj, data_np, target_dist, init_dist, num_to_sample):
- dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
-
- # Reshape distribution via rejection sampling.
- dataset = dataset.apply(
- resampling.rejection_resample(
- class_func=lambda x: x,
- target_dist=target_dist,
- initial_dist=init_dist,
- seed=142))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with test_obj.test_session() as sess:
- start_time = time.time()
- for _ in xrange(num_to_sample):
- sess.run(get_next)
- end_time = time.time()
-
- return end_time - start_time
-
-
-class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("InitialDistributionKnown", True),
- ("InitialDistributionUnknown", False))
- def testDistribution(self, initial_known):
- classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
- target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
- initial_dist = [0.2] * 5 if initial_known else None
- classes = math_ops.to_int64(classes) # needed for Windows build.
- dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
- 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
-
- get_next = dataset.apply(
- resampling.rejection_resample(
- target_dist=target_dist,
- initial_dist=initial_dist,
- class_func=lambda c, _: c,
- seed=27)).make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- returned = []
- while len(returned) < 4000:
- returned.append(sess.run(get_next))
-
- returned_classes, returned_classes_and_data = zip(*returned)
- _, returned_data = zip(*returned_classes_and_data)
- self.assertAllEqual([compat.as_bytes(str(c))
- for c in returned_classes], returned_data)
- total_returned = len(returned_classes)
- class_counts = np.array([
- len([True for v in returned_classes if v == c])
- for c in range(5)])
- returned_dist = class_counts / total_returned
- self.assertAllClose(target_dist, returned_dist, atol=1e-2)
-
- @parameterized.named_parameters(
- ("OnlyInitial", True),
- ("NotInitial", False))
- def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
- init_dist = [0.5, 0.5]
- target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
- num_classes = len(init_dist)
- # We don't need many samples to test that this works.
- num_samples = 100
- data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
-
- # Reshape distribution.
- dataset = dataset.apply(
- resampling.rejection_resample(
- class_func=lambda x: x,
- target_dist=target_dist,
- initial_dist=init_dist))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- returned = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- returned.append(sess.run(get_next))
-
- def testRandomClasses(self):
- init_dist = [0.25, 0.25, 0.25, 0.25]
- target_dist = [0.0, 0.0, 0.0, 1.0]
- num_classes = len(init_dist)
- # We don't need many samples to test a dirac-delta target distribution.
- num_samples = 100
- data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
-
- # Apply a random mapping that preserves the data distribution.
- def _remap_fn(_):
- return math_ops.cast(random_ops.random_uniform([1]) * num_classes,
- dtypes.int32)[0]
- dataset = dataset.map(_remap_fn)
-
- # Reshape distribution.
- dataset = dataset.apply(
- resampling.rejection_resample(
- class_func=lambda x: x,
- target_dist=target_dist,
- initial_dist=init_dist))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- returned = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- returned.append(sess.run(get_next))
-
- classes, _ = zip(*returned)
- bincount = np.bincount(
- np.array(classes),
- minlength=num_classes).astype(np.float32) / len(classes)
-
- self.assertAllClose(target_dist, bincount, atol=1e-2)
-
-
-class ResampleDatasetBenchmark(test.Benchmark):
-
- def benchmarkResamplePerformance(self):
- init_dist = [0.25, 0.25, 0.25, 0.25]
- target_dist = [0.0, 0.0, 0.0, 1.0]
- num_classes = len(init_dist)
- # We don't need many samples to test a dirac-delta target distribution
- num_samples = 1000
- data_np = np.random.choice(num_classes, num_samples, p=init_dist)
-
- resample_time = _time_resampling(
- self, data_np, target_dist, init_dist, num_to_sample=1000)
-
- self.report_benchmark(
- iters=1000, wall_time=resample_time, name="benchmark_resample")
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
deleted file mode 100644
index bdf80eae4e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ /dev/null
@@ -1,172 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import itertools
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class ScanDatasetTest(test_base.DatasetTestBase):
-
- def _counting_dataset(self, start, scan_fn):
- return dataset_ops.Dataset.from_tensors(0).repeat().apply(
- scan_ops.scan(start, scan_fn))
-
- def testCount(self):
- def make_scan_fn(step):
- return lambda state, _: (state + step, state)
-
- start = array_ops.placeholder(dtypes.int32, shape=[])
- step = array_ops.placeholder(dtypes.int32, shape=[])
- take = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = self._counting_dataset(
- start, make_scan_fn(step)).take(take).make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
-
- for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
- (10, 2, 10), (10, -1, 10),
- (10, -2, 10)]:
- sess.run(iterator.initializer,
- feed_dict={start: start_val, step: step_val, take: take_val})
- for expected, _ in zip(
- itertools.count(start_val, step_val), range(take_val)):
- self.assertEqual(expected, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- @test_util.run_in_graph_and_eager_modes
- def testFibonacci(self):
- iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
- scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
- ).make_one_shot_iterator()
-
- if context.executing_eagerly():
- next_element = iterator.get_next
- else:
- get_next = iterator.get_next()
- next_element = lambda: get_next
-
- self.assertEqual(1, self.evaluate(next_element()))
- self.assertEqual(1, self.evaluate(next_element()))
- self.assertEqual(2, self.evaluate(next_element()))
- self.assertEqual(3, self.evaluate(next_element()))
- self.assertEqual(5, self.evaluate(next_element()))
- self.assertEqual(8, self.evaluate(next_element()))
-
- def testSparseCount(self):
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1])),
- dense_shape=np.array([1, 1]))
-
- def make_scan_fn(step):
- return lambda state, _: (_sparse(state.values[0] + step), state)
-
- start = array_ops.placeholder(dtypes.int32, shape=[])
- step = array_ops.placeholder(dtypes.int32, shape=[])
- take = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = self._counting_dataset(
- _sparse(start),
- make_scan_fn(step)).take(take).make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
-
- for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
- (10, 2, 10), (10, -1, 10),
- (10, -2, 10)]:
- sess.run(iterator.initializer,
- feed_dict={start: start_val, step: step_val, take: take_val})
- for expected, _ in zip(
- itertools.count(start_val, step_val), range(take_val)):
- self.assertEqual(expected, sess.run(next_element).values[0])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testChangingStateShape(self):
- # Test the fixed-point shape invariant calculations: start with
- # initial values with known shapes, and use a scan function that
- # changes the size of the state on each element.
- def _scan_fn(state, input_value):
- # Statically known rank, but dynamic length.
- ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
- # Statically unknown rank.
- ret_larger_rank = array_ops.expand_dims(state[1], 0)
- return (ret_longer_vector, ret_larger_rank), (state, input_value)
-
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
- scan_ops.scan(([0], 1), _scan_fn))
- self.assertEqual([None], dataset.output_shapes[0][0].as_list())
- self.assertIs(None, dataset.output_shapes[0][1].ndims)
- self.assertEqual([], dataset.output_shapes[1].as_list())
-
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(5):
- (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
- self.assertAllEqual([0] * (2**i), longer_vector_val)
- self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testIncorrectStateType(self):
-
- def _scan_fn(state, _):
- return constant_op.constant(1, dtype=dtypes.int64), state
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The element types for the new state must match the initial state."):
- dataset.apply(
- scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-
- def testIncorrectReturnType(self):
-
- def _scan_fn(unused_state, unused_input_value):
- return constant_op.constant(1, dtype=dtypes.int64)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The scan function must return a pair comprising the new state and the "
- "output value."):
- dataset.apply(
- scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
deleted file mode 100644
index aa89674c6e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ /dev/null
@@ -1,555 +0,0 @@
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load("//tensorflow:tensorflow.bzl", "py_test")
-
-py_library(
- name = "dataset_serialization_test_base",
- srcs = [
- "dataset_serialization_test_base.py",
- ],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:lookup_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:training",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:iterator_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "batch_dataset_serialization_test",
- size = "medium",
- srcs = ["batch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "cache_dataset_serialization_test",
- size = "small",
- srcs = ["cache_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "concatenate_dataset_serialization_test",
- size = "small",
- srcs = ["concatenate_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "csv_dataset_serialization_test",
- size = "small",
- srcs = ["csv_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- ],
-)
-
-py_test(
- name = "dataset_constructor_serialization_test",
- size = "medium",
- srcs = ["dataset_constructor_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "filter_dataset_serialization_test",
- size = "medium",
- srcs = ["filter_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "fixed_length_record_dataset_serialization_test",
- size = "medium",
- srcs = ["fixed_length_record_dataset_serialization_test.py"],
- shard_count = 4,
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "flat_map_dataset_serialization_test",
- size = "medium",
- srcs = ["flat_map_dataset_serialization_test.py"],
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "group_by_reducer_serialization_test",
- size = "medium",
- srcs = ["group_by_reducer_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "group_by_window_serialization_test",
- size = "medium",
- srcs = ["group_by_window_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "ignore_errors_serialization_test",
- size = "small",
- srcs = ["ignore_errors_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "interleave_dataset_serialization_test",
- size = "medium",
- srcs = ["interleave_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "map_and_batch_dataset_serialization_test",
- size = "medium",
- srcs = ["map_and_batch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "map_dataset_serialization_test",
- size = "medium",
- srcs = ["map_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "optimize_dataset_serialization_test",
- size = "small",
- srcs = ["optimize_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "padded_batch_dataset_serialization_test",
- size = "medium",
- srcs = ["padded_batch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "parallel_interleave_dataset_serialization_test",
- size = "medium",
- srcs = ["parallel_interleave_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "parallel_map_dataset_serialization_test",
- size = "medium",
- srcs = ["parallel_map_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "parse_example_dataset_serialization_test",
- size = "medium",
- srcs = ["parse_example_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "prefetch_dataset_serialization_test",
- size = "small",
- srcs = ["prefetch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "range_dataset_serialization_test",
- size = "small",
- srcs = ["range_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "sample_from_datasets_serialization_test",
- size = "medium",
- srcs = ["sample_from_datasets_serialization_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "scan_dataset_serialization_test",
- size = "small",
- srcs = ["scan_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:scan_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "sequence_dataset_serialization_test",
- size = "medium",
- srcs = ["sequence_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "serialization_integration_test",
- size = "small",
- srcs = ["serialization_integration_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "shuffle_and_repeat_dataset_serialization_test",
- size = "medium",
- srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "shuffle_dataset_serialization_test",
- size = "medium",
- srcs = ["shuffle_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "sql_dataset_serialization_test",
- size = "small",
- srcs = ["sql_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- ],
-)
-
-py_test(
- name = "stats_dataset_serialization_test",
- size = "medium",
- srcs = ["stats_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "textline_dataset_serialization_test",
- size = "medium",
- srcs = ["textline_dataset_serialization_test.py"],
- shard_count = 4,
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "tf_record_dataset_serialization_test",
- size = "medium",
- srcs = ["tf_record_dataset_serialization_test.py"],
- shard_count = 4,
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "unbatch_dataset_serialization_test",
- size = "medium",
- srcs = ["unbatch_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "unique_dataset_serialization_test",
- size = "small",
- srcs = ["unique_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "zip_dataset_serialization_test",
- size = "small",
- srcs = ["zip_dataset_serialization_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":dataset_serialization_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
deleted file mode 100644
index af87d8b608..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the BatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class BatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
- components = (
- np.arange(tensor_slice_len),
- np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(tensor_slice_len))
-
- return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
-
- def testCore(self):
- tensor_slice_len = 8
- batch_size = 2
- num_outputs = tensor_slice_len // batch_size
- self.run_core_tests(
- lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
- lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
- num_outputs)
-
- def _build_dataset_dense_to_sparse(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
-
- def testDenseToSparseBatchDatasetCore(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
-
- num_outputs = len(components) // 4
- self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
- lambda: self._build_dataset_dense_to_sparse(diff_comp),
- num_outputs)
-
- def _sparse(self, i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- def _build_dataset_sparse(self, batch_size=5):
- return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
-
- def testSparseCore(self):
- self.run_core_tests(self._build_dataset_sparse,
- lambda: self._build_dataset_sparse(2), 2)
-
- def _build_dataset_nested_sparse(self):
- return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
-
- def testNestedSparseCore(self):
- self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
deleted file mode 100644
index 1b6059ccbc..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the CacheDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from absl.testing import parameterized
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class CacheDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase,
- parameterized.TestCase):
-
- def setUp(self):
- self.range_size = 10
- self.num_repeats = 3
- self.num_outputs = self.range_size * self.num_repeats
- self.cache_file_prefix = 'test'
-
- def make_dataset_fn(self, is_memory):
- if is_memory:
- filename = ''
- else:
- filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)
-
- def ds_fn():
- return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
- self.num_repeats)
-
- return ds_fn
-
- def expected_outputs(self):
- return list(range(self.range_size)) * self.num_repeats
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointBeforeOneEpoch(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 5 entries from iterator and save checkpoint.
- outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
- self.assertSequenceEqual(outputs, range(5))
-
- # Restore from checkpoint and produce the rest of the elements from the
- # iterator.
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, self.expected_outputs())
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 8 entries from iterator but save checkpoint after producing 5.
- outputs = self.gen_outputs(
- ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
- self.assertSequenceEqual(outputs, range(8))
-
- if is_memory:
- outputs = outputs[:5]
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, self.expected_outputs())
- else:
- # Restoring from checkpoint and running GetNext should return
- # `AlreadExistsError` now because the lockfile already exists.
- with self.assertRaises(errors.AlreadyExistsError):
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointAfterOneEpoch(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 15 entries from iterator and save checkpoint.
- outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
-
- # Restore from checkpoint and produce the rest of the elements from the
- # iterator.
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 15,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, self.expected_outputs())
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 18 entries from iterator but save checkpoint after producing 15.
- outputs = self.gen_outputs(
- ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
-
- outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 15,
- ckpt_saved=True,
- verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Generate 13 entries from iterator but save checkpoint after producing 5.
- outputs = self.gen_outputs(
- ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
-
- # Since we ran for more than one epoch, the cache was completely written.
- # The ckpt was saved when the iterator was in cache-write mode. Test that
- # the iterator falls back to read mode after restoring if the cache has
- # been completely written.
-
- outputs = list(range(5)) + self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointUnusedWriterIterator(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Checkpoint before get_next is called even once.
- outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
- self.assertSequenceEqual(outputs, [])
-
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Produce 5 elements and checkpoint.
- outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
- self.assertSequenceEqual(outputs, range(5))
-
- # Restore from checkpoint, then produce no elements and checkpoint.
- outputs.extend(
- self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
- self.assertSequenceEqual(outputs, range(5))
-
- # Restore from checkpoint and produce rest of the elements.
- outputs.extend(
- self.gen_outputs(
- ds_fn, [],
- self.num_outputs - 5,
- ckpt_saved=True,
- verify_exhausted=False))
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testUnusedCheckpointError(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Produce 5 elements and save ckpt.
- outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
- self.assertSequenceEqual(outputs, range(5))
-
- if is_memory:
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, verify_exhausted=False)
- self.assertSequenceEqual(outputs, self.expected_outputs())
- else:
- # Since the complete cache has not been written, a new iterator which does
- # not restore the checkpoint will throw an error since there is a partial
- # cache shard.
- with self.assertRaises(errors.AlreadyExistsError):
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, verify_exhausted=False)
-
- @parameterized.named_parameters(
- ('Memory', True),
- ('File', False),
- )
- def testIgnoreCheckpointIfCacheWritten(self, is_memory):
- ds_fn = self.make_dataset_fn(is_memory)
-
- # Produce 15 elements and save ckpt. This will write the complete cache.
- outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
-
- # Build the iterator again but do not restore from ckpt. Since the cache
- # has already been written we should be able to use it.
- outputs = self.gen_outputs(
- ds_fn, [], self.num_outputs, verify_exhausted=False)
- self.assertSequenceEqual(outputs, list(range(10)) * 3)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
deleted file mode 100644
index 96f13d75a3..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
+++ /dev/null
@@ -1,49 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ConcatenateDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ConcatenateDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_concatenate_dataset(self, var_array):
- input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 4))
- to_concatenate_components = (np.tile(
- np.array([[5], [6], [7], [8], [9]]), 20), var_array)
-
- return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate(
- dataset_ops.Dataset.from_tensor_slices(to_concatenate_components))
-
- def testConcatenateCore(self):
- num_outputs = 9
- array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15)
- diff_array = np.array([[1], [2], [3], [4], [5]])
- self.run_core_tests(lambda: self._build_concatenate_dataset(array),
- lambda: self._build_concatenate_dataset(diff_array),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
deleted file mode 100644
index 247f2046ea..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the CsvDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.platform import test
-
-
-class CsvDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._num_cols = 7
- self._num_rows = 10
- self._num_epochs = 14
- self._num_outputs = self._num_rows * self._num_epochs
-
- inputs = [
- ",".join(str(self._num_cols * j + i)
- for i in range(self._num_cols))
- for j in range(self._num_rows)
- ]
- contents = "\n".join(inputs).encode("utf-8")
-
- self._filename = os.path.join(self.get_temp_dir(), "file.csv")
- self._compressed = os.path.join(self.get_temp_dir(),
- "comp.csv") # GZip compressed
-
- with open(self._filename, "wb") as f:
- f.write(contents)
- with gzip.GzipFile(self._compressed, "wb") as f:
- f.write(contents)
-
- def ds_func(self, **kwargs):
- compression_type = kwargs.get("compression_type", None)
- if compression_type == "GZIP":
- filename = self._compressed
- elif compression_type is None:
- filename = self._filename
- else:
- raise ValueError("Invalid compression type:", compression_type)
-
- return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
-
- def testSerializationCore(self):
- defs = [[0]] * self._num_cols
- self.run_core_tests(
- lambda: self.ds_func(record_defaults=defs, buffer_size=2),
- lambda: self.ds_func(record_defaults=defs, buffer_size=12),
- self._num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
deleted file mode 100644
index 2139b5c33d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the dataset constructors serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.platform import test
-
-
-class FromTensorsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_tensor_dataset(self, variable_array):
- components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
-
- return dataset_ops.Dataset.from_tensors(components)
-
- def testFromTensorsCore(self):
- # Equal length components
- arr = np.array(1)
- num_outputs = 1
- diff_arr = np.array(2)
- self.run_core_tests(lambda: self._build_tensor_dataset(arr),
- lambda: self._build_tensor_dataset(diff_arr),
- num_outputs)
-
-
-class FromTensorSlicesSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_tensor_slices_dataset(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components)
-
- def testFromTensorSlicesCore(self):
- # Equal length components
- components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array([37.0, 38.0, 39.0, 40.0]))
-
- diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[5], [6], [7], [8]]), 22),
- np.array([1.0, 2.0, 3.0, 4.0]))
-
- dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
-
- self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
- lambda: self._build_tensor_slices_dataset(diff_comp), 4)
- self.run_core_tests(
- lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
-
-
-class FromSparseTensorSlicesSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_sparse_tensor_slice_dataset(self, slices):
- indices = np.array(
- [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))],
- dtype=np.int64)
- values = np.array([val for s in slices for val in s], dtype=np.float64)
- dense_shape = np.array(
- [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
- sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape)
- return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)
-
- def testFromSparseTensorSlicesCore(self):
- slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
- diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
-
- self.run_core_tests(
- lambda: self._build_sparse_tensor_slice_dataset(slices),
- lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
- 9,
- sparse_tensors=True)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
deleted file mode 100644
index 595cecef4d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ /dev/null
@@ -1,692 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing serializable datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import lookup_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.util import nest
-
-
-def remove_variants(get_next_op):
- # TODO(b/72408568): Remove this once session.run can get
- # variant tensors.
- """Remove variants from a nest structure, so sess.run will execute."""
-
- def _remove_variant(x):
- if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
- return ()
- else:
- return x
-
- return nest.map_structure(_remove_variant, get_next_op)
-
-
-class DatasetSerializationTestBase(test.TestCase):
- """Base class for testing serializable datasets."""
-
- def tearDown(self):
- self._delete_ckpt()
-
- # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
- # (deprecated) saveable `SparseTensorSliceDataset`, once the API
- # `from_sparse_tensor_slices()`and related tests are deleted.
- def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
- """Runs the core tests.
-
- Args:
- ds_fn1: 0-argument function that returns a Dataset.
- ds_fn2: 0-argument function that returns a Dataset different from
- ds_fn1. If None, verify_restore_in_modified_graph test is not run.
- num_outputs: Total number of outputs expected from this Dataset.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_unused_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_fully_used_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_exhausted_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_init_before_restore(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_multiple_breaks(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_reset_restored_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_restore_in_empty_graph(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- if ds_fn2:
- self.verify_restore_in_modified_graph(
- ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_unused_iterator(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that saving and restoring an unused iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [0],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_fully_used_iterator(self, ds_fn, num_outputs,
- sparse_tensors=False):
- """Verifies that saving and restoring a fully used iterator works.
-
- Note that this only checks saving and restoring an iterator from which
- `num_outputs` items have been produced but does not check for an
- exhausted iterator, i.e., one from which an OutOfRange error has been
- returned.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
- """Verifies that saving and restoring an exhausted iterator works.
-
- An exhausted iterator is one which has returned an OutOfRange error.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.gen_outputs(
- ds_fn, [],
- num_outputs,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- actual = self.gen_outputs(
- ds_fn, [],
- 0,
- ckpt_saved=True,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- self.assertEqual(len(actual), 0)
-
- def verify_init_before_restore(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that restoring into an already initialized iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs),
- num_outputs,
- init_before_restore=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_multiple_breaks(self,
- ds_fn,
- num_outputs,
- num_breaks=10,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to save/restore at multiple break points.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- num_breaks: The number of break points. These are uniformly spread in
- [0, num_outputs] both inclusive.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs, num_breaks),
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_reset_restored_iterator(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to re-initialize a restored iterator.
-
- This is useful when restoring a training checkpoint during validation.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Collect ground truth containing all outputs.
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Skip some items and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Restore from checkpoint and then run init_op.
- with ops.Graph().as_default() as g:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- self._initialize(init_op, sess)
- for _ in range(num_outputs):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- self.match(expected, actual)
-
- def verify_restore_in_modified_graph(self,
- ds_fn1,
- ds_fn2,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in a modified graph.
-
- Builds an input pipeline using ds_fn1, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new graph using ds_fn2, restores
- the checkpoint from ds_fn1 and verifies that the restore is successful.
-
- Args:
- ds_fn1: See `run_core_tests`.
- ds_fn2: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn1
- # in `expected`.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn1, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn1 and save checkpoint.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build graph for ds_fn2 but load checkpoint for ds_fn1.
- with ops.Graph().as_default() as g:
- _, get_next_op, saver = self._build_graph(
- ds_fn2, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_restore_in_empty_graph(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in an empty graph.
-
- Builds an input pipeline using ds_fn, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new empty graph, restores
- the checkpoint from ds_fn and verifies that the restore is successful.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn
- # in `expected`.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build an empty graph but load checkpoint for ds_fn.
- with ops.Graph().as_default() as g:
- get_next_op, saver = self._build_empty_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_error_on_save(self,
- ds_fn,
- num_outputs,
- error,
- break_point=None,
- sparse_tensors=False):
- """Attempts to save a non-saveable iterator.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- error: Declared error when trying to save iterator.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
-
- break_point = num_outputs // 2 if not break_point else break_point
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._initialize(init_op, sess)
- for _ in range(break_point):
- sess.run(get_next_op)
- with self.assertRaises(error):
- self._save(sess, saver)
-
- def verify_run_with_breaks(self,
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that ds_fn() produces the same outputs with and without breaks.
-
- 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- *without* stopping at break points.
- 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- with stopping at break points.
-
- Deep matches outputs from 1 and 2.
-
- Args:
- ds_fn: See `gen_outputs`.
- break_points: See `gen_outputs`.
- num_outputs: See `gen_outputs`.
- init_before_restore: See `gen_outputs`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- actual = self.gen_outputs(
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- self.match(expected, actual)
-
- def gen_outputs(self,
- ds_fn,
- break_points,
- num_outputs,
- ckpt_saved=False,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True,
- save_checkpoint_at_end=True):
- """Generates elements from input dataset while stopping at break points.
-
- Produces `num_outputs` outputs and saves the state of the iterator in the
- Saver checkpoint.
-
- Args:
- ds_fn: 0-argument function that returns the dataset.
- break_points: A list of integers. For each `break_point` in
- `break_points`, we produce outputs till `break_point` number of items
- have been produced and then checkpoint the state. The current graph
- and session are destroyed and a new graph and session are used to
- produce outputs till next checkpoint or till `num_outputs` elements
- have been produced. `break_point` must be <= `num_outputs`.
- num_outputs: The total number of outputs to produce from the iterator.
- ckpt_saved: Whether a checkpoint already exists. If False, we build the
- graph from ds_fn.
- init_before_restore: Whether init should be called before saver.restore.
- This is just so that we can verify that restoring an already initialized
- iterator works.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
- verify_exhausted: Whether to verify that the iterator has been exhausted
- after producing `num_outputs` elements.
- save_checkpoint_at_end: Whether to save a checkpoint after producing all
- outputs. If False, checkpoints are saved each break point but not at the
- end. Note that checkpoints overwrite each other so there is always only
- a single checkpoint available. Defaults to True.
-
- Returns:
- A list of `num_outputs` items.
- """
- outputs = []
-
- def get_ops():
- if ckpt_saved:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- else:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- return init_op, get_next_op, saver
-
- for i in range(len(break_points) + 1):
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = get_ops()
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- if ckpt_saved:
- if init_before_restore:
- self._initialize(init_op, sess)
- self._restore(saver, sess)
- else:
- self._initialize(init_op, sess)
- start = break_points[i - 1] if i > 0 else 0
- end = break_points[i] if i < len(break_points) else num_outputs
- num_iters = end - start
- for _ in range(num_iters):
- outputs.append(sess.run(get_next_op))
- if i == len(break_points) and verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- if save_checkpoint_at_end or i < len(break_points):
- self._save(sess, saver)
- ckpt_saved = True
-
- return outputs
-
- def match(self, expected, actual):
- """Matches nested structures.
-
- Recursively matches shape and values of `expected` and `actual`.
- Handles scalars, numpy arrays and other python sequence containers
- e.g. list, dict.
-
- Args:
- expected: Nested structure 1.
- actual: Nested structure 2.
-
- Raises:
- AssertionError if matching fails.
- """
- if isinstance(expected, np.ndarray):
- expected = expected.tolist()
- if isinstance(actual, np.ndarray):
- actual = actual.tolist()
- self.assertEqual(type(expected), type(actual))
-
- if nest.is_sequence(expected):
- self.assertEqual(len(expected), len(actual))
- if isinstance(expected, dict):
- for key1, key2 in zip(sorted(expected), sorted(actual)):
- self.assertEqual(key1, key2)
- self.match(expected[key1], actual[key2])
- else:
- for item1, item2 in zip(expected, actual):
- self.match(item1, item2)
- else:
- self.assertEqual(expected, actual)
-
- def does_not_match(self, expected, actual):
- with self.assertRaises(AssertionError):
- self.match(expected, actual)
-
- def gen_break_points(self, num_outputs, num_samples=10):
- """Generates `num_samples` breaks points in [0, num_outputs]."""
- return np.linspace(0, num_outputs, num_samples, dtype=int)
-
- def _build_graph(self, ds_fn, sparse_tensors=False):
- iterator = ds_fn().make_initializable_iterator()
-
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- init_op = iterator.initializer
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
- sparse_tensors)
- saver = saver_lib.Saver(allow_empty=True)
- return init_op, get_next, saver
-
- def _build_empty_graph(self, ds_fn, sparse_tensors=False):
- iterator = iterator_ops.Iterator.from_structure(
- self._get_output_types(ds_fn),
- output_shapes=self._get_output_shapes(ds_fn),
- output_classes=self._get_output_classes(ds_fn))
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- saver = saver_lib.Saver(allow_empty=True)
- return get_next, saver
-
- def _add_iterator_ops_to_collection(self,
- init_op,
- get_next,
- ds_fn,
- sparse_tensors=False):
- ops.add_to_collection("iterator_ops", init_op)
- # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
- # do not support tuples we flatten the tensors and restore the shape in
- # `_get_iterator_ops_from_collection`.
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- ops.add_to_collection("iterator_ops", get_next.indices)
- ops.add_to_collection("iterator_ops", get_next.values)
- ops.add_to_collection("iterator_ops", get_next.dense_shape)
- return
-
- get_next_list = nest.flatten(get_next)
- for i, output_class in enumerate(
- nest.flatten(self._get_output_classes(ds_fn))):
- if output_class is sparse_tensor.SparseTensor:
- ops.add_to_collection("iterator_ops", get_next_list[i].indices)
- ops.add_to_collection("iterator_ops", get_next_list[i].values)
- ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
- else:
- ops.add_to_collection("iterator_ops", get_next_list[i])
-
- def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
- all_ops = ops.get_collection("iterator_ops")
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- init_op, indices, values, dense_shape = all_ops
- return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- get_next_list = []
- i = 1
- for output_class in nest.flatten(self._get_output_classes(ds_fn)):
- if output_class is sparse_tensor.SparseTensor:
- indices, values, dense_shape = all_ops[i:i + 3]
- i += 3
- get_next_list.append(
- sparse_tensor.SparseTensor(indices, values, dense_shape))
- else:
- get_next_list.append(all_ops[i])
- i += 1
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), get_next_list)
-
- def _get_output_types(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_types
-
- def _get_output_shapes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_shapes
-
- def _get_output_classes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_classes
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _latest_ckpt(self):
- return checkpoint_management.latest_checkpoint(self.get_temp_dir())
-
- def _save(self, sess, saver):
- saver.save(sess, self._ckpt_path())
-
- def _restore(self, saver, sess):
- sess.run(lookup_ops.tables_initializer())
- saver.restore(sess, self._latest_ckpt())
-
- def _initialize(self, init_op, sess):
- sess.run(variables.global_variables_initializer())
- sess.run(lookup_ops.tables_initializer())
- sess.run(init_op)
-
- def _import_meta_graph(self):
- meta_file_path = self._ckpt_path() + ".meta"
- return saver_lib.import_meta_graph(meta_file_path)
-
- def _delete_ckpt(self):
- # Remove all checkpoint files.
- prefix = self._ckpt_path()
- pattern = prefix + "*"
- files = gfile.Glob(pattern)
- map(gfile.Remove, files)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
deleted file mode 100644
index 7c170078a1..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the FilterDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class FilterDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_filter_range_graph(self, div):
- return dataset_ops.Dataset.range(100).filter(
- lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))
-
- def testFilterCore(self):
- div = 3
- num_outputs = np.sum([x % 3 != 2 for x in range(100)])
- self.run_core_tests(lambda: self._build_filter_range_graph(div),
- lambda: self._build_filter_range_graph(div * 2),
- num_outputs)
-
- def _build_filter_dict_graph(self):
- return dataset_ops.Dataset.range(10).map(
- lambda x: {"foo": x * 2, "bar": x ** 2}).filter(
- lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
- lambda d: d["foo"] + d["bar"])
-
- def testFilterDictCore(self):
- num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)])
- self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)
-
- def _build_sparse_filter(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensor(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- def _filter_fn(_, i):
- return math_ops.equal(i % 2, 0)
-
- return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
- lambda x, i: x)
-
- def testSparseCore(self):
- num_outputs = 5
- self.run_core_tests(self._build_sparse_filter, None, num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
deleted file mode 100644
index 34392d88d4..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the FixedLengthRecordDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.platform import test
-
-
-class FixedLengthRecordDatasetSerializationTest(
- reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, num_epochs, compression_type=None):
- filenames = self._createFiles()
- return core_readers.FixedLengthRecordDataset(
- filenames, self._record_bytes, self._header_bytes,
- self._footer_bytes).repeat(num_epochs)
-
- def testFixedLengthRecordCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
- lambda: self._build_iterator_graph(num_epochs * 2),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
deleted file mode 100644
index 16051ffd3f..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the FlatMapDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class FlatMapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testCore(self):
- # Complicated way of saying range(start, start+25).
- def build_ds(start):
-
- def map_fn(x):
- return dataset_ops.Dataset.range(x, x + 5)
-
- return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn)
-
- self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25)
-
- def testMapThenFlatMap(self):
-
- def build_ds():
-
- def flat_map_fn(_):
-
- def map_fn(y):
- return 10 * math_ops.to_int32(y)
-
- return dataset_ops.Dataset.range(100).map(map_fn)
-
- return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
-
- self.run_core_tests(build_ds, None, 500)
-
- def testCaptureDefunInMapFn(self):
-
- def build_ds():
-
- def map_fn(x):
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)])
-
- return dataset_ops.Dataset.range(100).flat_map(map_fn)
-
- self.run_core_tests(build_ds, None, 100)
-
- def testDisallowVariableCapture(self):
-
- def build_ds():
- test_var = variable_scope.get_variable(
- name="test_var", shape=(), use_resource=True)
- return dataset_ops.Dataset.range(5).flat_map(
- lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))
-
- self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError)
-
- def testDisallowCapturingStatefulOps(self):
-
- def build_ds():
-
- def flat_map_fn(_):
-
- def map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(map_fn)
-
- return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
-
- self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError)
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _flat_map_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_ds():
- return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
-
- self.run_core_tests(_build_ds, None, 20)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
deleted file mode 100644
index 571e0899bb..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the GroupByReducer serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class GroupByReducerSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, components):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- return dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_reducer(lambda x: x % 5, reducer))
-
- def testCoreGroupByReducer(self):
- components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
- self.verify_unused_iterator(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_init_before_restore(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_multiple_breaks(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_reset_restored_iterator(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- self.verify_restore_in_empty_graph(
- lambda: self._build_dataset(components), 5, verify_exhausted=True)
- diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
- self.verify_restore_in_modified_graph(
- lambda: self._build_dataset(components),
- lambda: self._build_dataset(diff_components),
- 5,
- verify_exhausted=True)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
deleted file mode 100644
index f86af4084e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the GroupByWindow serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class GroupByWindowSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
-
- def testCoreGroupByWindow(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- self.verify_unused_iterator(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_init_before_restore(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_multiple_breaks(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_reset_restored_iterator(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- self.verify_restore_in_empty_graph(
- lambda: self._build_dataset(components), 12, verify_exhausted=False)
- diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
- self.verify_restore_in_modified_graph(
- lambda: self._build_dataset(components),
- lambda: self._build_dataset(diff_components),
- 12,
- verify_exhausted=False)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
deleted file mode 100644
index 65ae9923b8..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the IgnoreErrors input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class IgnoreErrorsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_ds(self, components):
- return dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors())
-
- def testIgnoreErrorsCore(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
- diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
- num_outputs = 4
- self.run_core_tests(lambda: self._build_ds(components),
- lambda: self._build_ds(diff_components), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
deleted file mode 100644
index 243f6405a1..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the InterleaveDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class InterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase,
- parameterized.TestCase):
-
- def _build_iterator_graph(self, input_values, cycle_length, block_length,
- num_parallel_calls):
- repeat_count = 2
- return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
- repeat_count).interleave(
- lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
- cycle_length, block_length, num_parallel_calls)
-
- @parameterized.named_parameters(
- ("1", 2, 3, None),
- ("2", 2, 3, 1),
- ("3", 2, 3, 2),
- ("4", 1, 3, None),
- ("5", 1, 3, 1),
- ("6", 2, 1, None),
- ("7", 2, 1, 1),
- ("8", 2, 1, 2),
- )
- def testSerializationCore(self, cycle_length, block_length,
- num_parallel_calls):
- input_values = np.array([4, 5, 6], dtype=np.int64)
- num_outputs = np.sum(input_values) * 2
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self._build_iterator_graph(
- input_values, cycle_length, block_length, num_parallel_calls),
- lambda: self._build_iterator_graph(
- input_values, cycle_length * 2, block_length, num_parallel_calls),
- num_outputs)
- # pylint: enable=g-long-lambda
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_dataset():
- return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
- _interleave_fn, cycle_length=1)
-
- self.run_core_tests(_build_dataset, None, 20)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
deleted file mode 100644
index c9cd211328..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapAndBatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class MapAndBatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testNumParallelBatches(self):
- range_size = 11
- num_repeats = 2
- batch_size = 5
- total_outputs = range_size * num_repeats
- num_outputs_drop_remainder = total_outputs // batch_size
- num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
- num_parallel_batches = 2
-
- def build_ds(range_start, drop_remainder=False):
-
- def _map_fn(x):
- return math_ops.square(x)
-
- return dataset_ops.Dataset.range(
- range_start, range_start + range_size).repeat(num_repeats).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_batches=num_parallel_batches,
- drop_remainder=drop_remainder))
-
- self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
- num_outputs_keep_remainder)
- self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
- num_outputs_drop_remainder)
-
- def testNumParallelCalls(self):
- range_size = 11
- num_repeats = 2
- batch_size = 5
- total_outputs = range_size * num_repeats
- num_outputs_drop_remainder = total_outputs // batch_size
- num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
- num_parallel_calls = 7
-
- def build_ds(range_start, drop_remainder=False):
-
- def _map_fn(x):
- return math_ops.square(x)
-
- return dataset_ops.Dataset.range(
- range_start, range_start + range_size).repeat(num_repeats).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- drop_remainder=drop_remainder))
-
- self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
- num_outputs_keep_remainder)
- self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
- num_outputs_drop_remainder)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
deleted file mode 100644
index ab783e5cce..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
+++ /dev/null
@@ -1,140 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the MapDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class MapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._tensor_slice_len = 7
- self._num_epochs = 14
- self._num_outputs = self._tensor_slice_len * self._num_epochs
-
- def _build_ds(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (
- dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
- .repeat(self._num_epochs))
-
- def testSaveRestoreCore(self):
- self.run_core_tests(
- self._build_ds,
- lambda: self._build_ds(multiplier=15.0),
- self._num_outputs)
-
- def testSaveStatefulFunction(self):
-
- def _build_ds():
-
- def _map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(_map_fn)
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureVariableInMapFn(self):
-
- def _build_ds():
- counter_var = variable_scope.get_variable(
- "counter", (), dtypes.int32, use_resource=True)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1)))
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureConstantInMapFn(self):
-
- def _build_ds():
- constant_var = constant_op.constant(5)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var))
-
- self.run_core_tests(_build_ds, None, 10)
-
- def testCaptureDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testBuildDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
-
- @function.Defun(dtypes.int32)
- def defun_fn_deep(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
-
- return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testSparseCore(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1])),
- dense_shape=np.array([1, 1]))
-
- def _build_ds(num_outputs):
- return dataset_ops.Dataset.range(num_outputs).map(_sparse)
-
- num_outputs = 10
- self.run_core_tests(lambda: _build_ds(num_outputs),
- lambda: _build_ds(int(num_outputs / 2)), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
deleted file mode 100644
index 9ac42a461a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the PaddedBatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class PaddedBatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testPaddedBatch(self):
-
- def build_dataset(seq_lens):
- return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
- lambda x: array_ops.fill([x], x)).padded_batch(
- 4, padded_shapes=[-1])
-
- seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
- seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
- self.run_core_tests(lambda: build_dataset(seq_lens1),
- lambda: build_dataset(seq_lens2), 8)
-
- def testPaddedBatchNonDefaultPadding(self):
-
- def build_dataset(seq_lens):
-
- def fill_tuple(x):
- filled = array_ops.fill([x], x)
- return (filled, string_ops.as_string(filled))
-
- padded_shape = [-1]
- return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
- fill_tuple).padded_batch(
- 4,
- padded_shapes=(padded_shape, padded_shape),
- padding_values=(-1, "<end>"))
-
- seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
- seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
- self.run_core_tests(lambda: build_dataset(seq_lens1),
- lambda: build_dataset(seq_lens2), 8)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
deleted file mode 100644
index 1f8a584df9..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
+++ /dev/null
@@ -1,101 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ParallelInterleaveDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class ParallelInterleaveDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self.input_values = np.array([4, 5, 6], dtype=np.int64)
- self.num_repeats = 2
- self.num_outputs = np.sum(self.input_values) * 2
-
- def _build_ds(self, cycle_length, block_length, sloppy=False):
- return (dataset_ops.Dataset.from_tensor_slices(
- self.input_values).repeat(self.num_repeats).apply(
- interleave_ops.parallel_interleave(
- lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
- cycle_length, block_length, sloppy)))
-
- def testSerializationCore(self):
- # cycle_length > 1, block_length > 1
- cycle_length = 2
- block_length = 3
- self.run_core_tests(
- lambda: self._build_ds(cycle_length, block_length),
- lambda: self._build_ds(cycle_length * 2, block_length * 1),
- self.num_outputs)
- # cycle_length = 1
- cycle_length = 1
- block_length = 3
- self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
- None, self.num_outputs)
- # block_length = 1
- cycle_length = 2
- block_length = 1
- self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
- None, self.num_outputs)
-
- def testSerializationWithSloppy(self):
- break_points = self.gen_break_points(self.num_outputs, 10)
- expected_outputs = np.repeat(
- np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
- self.num_repeats).tolist()
-
- def run_test(cycle_length, block_length):
- actual = self.gen_outputs(
- lambda: self._build_ds(cycle_length, block_length, True),
- break_points, self.num_outputs)
- self.assertSequenceEqual(sorted(actual), expected_outputs)
-
- # cycle_length > 1, block_length > 1
- run_test(2, 3)
- # cycle_length = 1
- run_test(1, 3)
- # block_length = 1
- run_test(2, 1)
-
- def testSparseCore(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
-
- def _interleave_fn(x):
- return dataset_ops.Dataset.from_tensor_slices(
- sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
-
- def _build_dataset():
- return dataset_ops.Dataset.range(10).map(_map_fn).apply(
- interleave_ops.parallel_interleave(_interleave_fn, 1))
-
- self.run_core_tests(_build_dataset, None, 20)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
deleted file mode 100644
index 3fb7605be1..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ParallelMapDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
-
-
-class ParallelMapDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def setUp(self):
- self._tensor_slice_len = 7
- self._num_epochs = 1
- self._num_outputs = self._tensor_slice_len * self._num_epochs
-
- def _build_ds(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn, num_parallel_calls=3).repeat(self._num_epochs))
-
- def _build_ds_with_prefetch(self, multiplier=37.0):
- components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
- np.arange(self._tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(self._tensor_slice_len))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- return (dataset_ops.Dataset.from_tensor_slices(components).map(
- _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5))
-
- def testSaveRestoreCore(self):
- for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
- self.run_core_tests(
- ds_fn,
- lambda: ds_fn(multiplier=15.0),
- self._num_outputs)
-
- def testSaveStatefulFunction(self):
-
- def _build_ds():
-
- def _map_fn(x):
- return random_ops.random_uniform(
- (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(100).map(
- _map_fn, num_parallel_calls=2).prefetch(2)
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureVariableInMapFn(self):
-
- def _build_ds():
- counter_var = variable_scope.get_variable(
- "counter", (), dtypes.int32, use_resource=True)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda _: counter_var.assign_add(1),
- num_parallel_calls=2).prefetch(2))
-
- self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
-
- def testCaptureConstantInMapFn(self):
-
- def _build_ds():
- constant_var = constant_op.constant(5)
- return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
- lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
-
- self.run_core_tests(_build_ds, None, 10)
-
- def testCaptureDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return dataset_ops.Dataset.range(num_outputs).map(
- defun_fn, num_parallel_calls=2).prefetch(2)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
- def testBuildDefunInMapFn(self):
- num_outputs = 100
-
- def _build_ds():
-
- @function.Defun(dtypes.int64)
- def defun_fn(x):
-
- @function.Defun(dtypes.int32)
- def defun_fn_deep(x):
- return constant_op.constant(1000) + math_ops.to_int32(x)
-
- return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
-
- return dataset_ops.Dataset.range(num_outputs).map(
- defun_fn, num_parallel_calls=2).prefetch(2)
-
- self.run_core_tests(_build_ds, None, num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
deleted file mode 100644
index d3fa84e74c..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ParseExampleDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.platform import test
-
-
-class ParseExampleDatasetSerializationTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def ParseExampleDataset(self, num_repeat, batch_size):
- return self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_repeat,
- batch_size=batch_size,
- reader_num_threads=5,
- parser_num_threads=10)
-
- def testSerializationCore(self):
- num_repeat = 5
- batch_size = 2
- num_outputs = self._num_records * self._num_files * num_repeat // batch_size
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self.ParseExampleDataset(
- num_repeat=num_repeat, batch_size=batch_size),
- lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
deleted file mode 100644
index c802402461..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the PrefetchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class PrefetchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, seed):
- return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
- buffer_size=10, seed=seed, reshuffle_each_iteration=False)
-
- def testCore(self):
- num_outputs = 100
- self.run_core_tests(lambda: self.build_dataset(10),
- lambda: self.build_dataset(20), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
deleted file mode 100644
index 6341190847..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the RangeDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import test
-
-
-class RangeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _iterator_checkpoint_prefix_local(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _save_op(self, iterator_resource):
- iterator_state_variant = gen_dataset_ops.serialize_iterator(
- iterator_resource)
- save_op = io_ops.write_file(
- self._iterator_checkpoint_prefix_local(),
- parsing_ops.serialize_tensor(iterator_state_variant))
- return save_op
-
- def _restore_op(self, iterator_resource):
- iterator_state_variant = parsing_ops.parse_tensor(
- io_ops.read_file(self._iterator_checkpoint_prefix_local()),
- dtypes.variant)
- restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
- iterator_state_variant)
- return restore_op
-
- def testSaveRestore(self):
-
- def _build_graph(start, stop):
- iterator = dataset_ops.Dataset.range(start,
- stop).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- save_op = self._save_op(iterator._iterator_resource)
- restore_op = self._restore_op(iterator._iterator_resource)
- return init_op, get_next, save_op, restore_op
-
- # Saving and restoring in different sessions.
- start = 2
- stop = 10
- break_point = 5
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, _ = _build_graph(start, stop)
- with self.session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
-
- with ops.Graph().as_default() as g:
- init_op, get_next, _, restore_op = _build_graph(start, stop)
- with self.session(graph=g) as sess:
- sess.run(init_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Saving and restoring in same session.
- with ops.Graph().as_default() as g:
- init_op, get_next, save_op, restore_op = _build_graph(start, stop)
- with self.session(graph=g) as sess:
- sess.run(variables.global_variables_initializer())
- sess.run(init_op)
- for i in range(start, break_point):
- self.assertEqual(i, sess.run(get_next))
- sess.run(save_op)
- sess.run(restore_op)
- for i in range(break_point, stop):
- self.assertEqual(i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def _build_range_dataset(self, start, stop):
- return dataset_ops.Dataset.range(start, stop)
-
- def testRangeCore(self):
- start = 2
- stop = 10
- stop_1 = 8
- self.run_core_tests(lambda: self._build_range_dataset(start, stop),
- lambda: self._build_range_dataset(start, stop_1),
- stop - start)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
deleted file mode 100644
index fdb35ea624..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
+++ /dev/null
@@ -1,46 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the SampleFromDatasets serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class SampleFromDatasetsSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, probs, num_samples):
- dataset = interleave_ops.sample_from_datasets(
- [
- dataset_ops.Dataset.from_tensors(i).repeat(None)
- for i in range(len(probs))
- ],
- probs,
- seed=1813)
- return dataset.take(num_samples)
-
- def testSerializationCore(self):
- self.run_core_tests(
- lambda: self._build_dataset([0.5, 0.5], 100),
- lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
deleted file mode 100644
index af9ef48c0f..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ScanDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ScanDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, num_elements):
- return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
- scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
-
- def testScanCore(self):
- num_output = 5
- self.run_core_tests(lambda: self._build_dataset(num_output),
- lambda: self._build_dataset(2), num_output)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
deleted file mode 100644
index 2afebca0f5..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
+++ /dev/null
@@ -1,129 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the sequence datasets serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class SkipDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_skip_dataset(self, count):
- components = (np.arange(10),)
- return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
-
- def testSkipFewerThanInputs(self):
- count = 4
- num_outputs = 10 - count
- self.run_core_tests(lambda: self._build_skip_dataset(count),
- lambda: self._build_skip_dataset(count + 2),
- num_outputs)
-
- def testSkipVarious(self):
- # Skip more than inputs
- self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
- # Skip exactly the input size
- self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
- self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
- # Skip nothing
- self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
-
- def testInvalidSkip(self):
- with self.assertRaisesRegexp(ValueError,
- 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
-
-
-class TakeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_take_dataset(self, count):
- components = (np.arange(10),)
- return dataset_ops.Dataset.from_tensor_slices(components).take(count)
-
- def testTakeFewerThanInputs(self):
- count = 4
- self.run_core_tests(
- lambda: self._build_take_dataset(count),
- lambda: self._build_take_dataset(count + 2),
- count,
- )
-
- def testTakeVarious(self):
- # Take more than inputs
- self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
- # Take exactly the input size
- self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
- # Take all
- self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
- # Take nothing
- self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
-
- def testInvalidTake(self):
- with self.assertRaisesRegexp(ValueError,
- 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
-
-
-class RepeatDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_repeat_dataset(self, count, take_count=3):
- components = (np.arange(10),)
- return dataset_ops.Dataset.from_tensor_slices(components).take(
- take_count).repeat(count)
-
- def testFiniteRepeat(self):
- count = 10
- self.run_core_tests(lambda: self._build_repeat_dataset(count),
- lambda: self._build_repeat_dataset(count + 2),
- 3 * count)
-
- def testEmptyRepeat(self):
- self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
-
- def testInfiniteRepeat(self):
- self.verify_unused_iterator(
- lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
- self.verify_init_before_restore(
- lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
- self.verify_multiple_breaks(
- lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
- self.verify_reset_restored_iterator(
- lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
- self.verify_restore_in_modified_graph(
- lambda: self._build_repeat_dataset(-1),
- lambda: self._build_repeat_dataset(2),
- 20,
- verify_exhausted=False)
- # Test repeat empty dataset
- self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
-
- def testInvalidRepeat(self):
- with self.assertRaisesRegexp(
- ValueError, 'Shape must be rank 0 but is rank 1'):
- self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0),
- None, 0)
-
-
-if __name__ == '__main__':
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
deleted file mode 100644
index 6aac50ecd9..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Integration test for dataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import saver as saver_lib
-
-
-class SerializationIntegrationTest(test.TestCase):
-
- def _build_input_pipeline(self, name, num_outputs):
- with ops.name_scope(name):
- ds = dataset_ops.Dataset.range(num_outputs).shuffle(
- 10, reshuffle_each_iteration=False).prefetch(10)
- iterator = ds.make_initializable_iterator()
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- return iterator.initializer, iterator.get_next()
-
- def _build_graph(self, num_pipelines, num_outputs):
- init_ops = []
- get_next_ops = []
- for i in range(num_pipelines):
- name = "input_pipeline_%d" % i
- init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
- init_ops.append(init_op)
- get_next_ops.append(get_next_op)
- saver = saver_lib.Saver()
- return init_ops, get_next_ops, saver
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def testConcurrentSaves(self):
- num_pipelines = 100
- num_outputs = 100
- break_point = 10
- all_outputs = [[] for _ in range(num_pipelines)]
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- sess.run(init_ops)
- for _ in range(break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
- saver.save(sess, self._ckpt_path())
-
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- saver.restore(sess, self._ckpt_path())
- for _ in range(num_outputs - break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
-
- for output in all_outputs:
- self.assertSequenceEqual(sorted(output), range(num_outputs))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
deleted file mode 100644
index f199ec835e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ShuffleAndRepeatDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ShuffleAndRepeatSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_ds(self, seed):
- return dataset_ops.Dataset.range(20).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
-
- def testCore(self):
- self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
- 100)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
deleted file mode 100644
index a59fa94d66..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ /dev/null
@@ -1,148 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ShuffleDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import saver as saver_lib
-
-
-class ShuffleDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_shuffle_dataset(
- self,
- range_limit=10,
- num_repeats=5,
- buffer_size=5,
- seed=None,
- reshuffle_each_iteration=None,
- ):
- return dataset_ops.Dataset.range(range_limit).shuffle(
- buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
-
- def testShuffleCore(self):
-
- seed = 55
- range_limit = 5
- num_repeats = 2
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 5, 8, 10]
- # pylint: disable=cell-var-from-loop
- # pylint: disable=g-long-lambda
- for reshuffle_each_iteration in [True, False]:
- for buffer_size in buffer_sizes:
- self.run_core_tests(
- lambda: self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=seed,
- reshuffle_each_iteration=reshuffle_each_iteration),
- lambda: self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=10,
- reshuffle_each_iteration=reshuffle_each_iteration),
- num_outputs)
- # pylint: enable=cell-var-from-loop
- # pylint: enable=g-long-lambda
-
- def testNonDeterministicSeeding(self):
-
- range_limit = 5
- num_repeats = 2
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 5, 8, 10]
- for reshuffle_each_iteration in [True, False]:
- for buffer_size in buffer_sizes:
-
- def ds_fn():
- # pylint: disable=cell-var-from-loop
- return self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=None, # Iterator seeds are generated non-deterministically.
- reshuffle_each_iteration=reshuffle_each_iteration)
- # pylint: enable=cell-var-from-loop
-
- # We checkpoint the initial state of the Dataset so that we can restore
- # the seeds in the next run. Since the seeding is non-deterministic
- # the dataset gets initialized with different seeds each time.
- expected = self.gen_outputs(
- ds_fn,
- break_points=[0],
- num_outputs=num_outputs,
- ckpt_saved=False,
- verify_exhausted=False,
- save_checkpoint_at_end=False)
- actual = self.gen_outputs(
- ds_fn,
- break_points=self.gen_break_points(num_outputs),
- num_outputs=num_outputs,
- ckpt_saved=True,
- verify_exhausted=False)
- self.match(expected, actual)
-
- def testMultipleIterators(self):
- range_limit = 5
- num_repeats = 2
- num_outputs = range_limit * num_repeats
- buffer_sizes = [1, 3, 5, 8, 10]
-
- for reshuffle_each_iteration in [True, False]:
- for buffer_size in buffer_sizes:
-
- def ds_fn():
- # pylint: disable=cell-var-from-loop
- return self._build_shuffle_dataset(
- range_limit=range_limit,
- num_repeats=num_repeats,
- buffer_size=buffer_size,
- seed=None, # Iterator seeds are generated non-deterministically.
- reshuffle_each_iteration=reshuffle_each_iteration)
- # pylint: enable=cell-var-from-loop
-
- with ops.Graph().as_default() as g:
- ds = ds_fn()
- iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
- get_next_ops = [it.get_next() for it in iterators]
- saveables = [
- contrib_iterator_ops.make_saveable_from_iterator(it)
- for it in iterators
- ]
- for saveable in saveables:
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- saver = saver_lib.Saver(allow_empty=True)
- with self.session(graph=g) as sess:
- self._save(sess, saver)
- expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
- self._restore(saver, sess)
- actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
- self.match(expected, actual)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
deleted file mode 100644
index 93b26ed58a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the SqlDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class SqlDatasetSerializationTest(
- sql_dataset_op_test_base.SqlDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, num_repeats):
- data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- driver_name = array_ops.placeholder_with_default(
- array_ops.constant("sqlite", dtypes.string), shape=[])
- query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
- "first_name DESC")
- output_types = (dtypes.string, dtypes.string, dtypes.string)
- return readers.SqlDataset(driver_name, data_source_name, query,
- output_types).repeat(num_repeats)
-
- def testSQLSaveable(self):
- num_repeats = 4
- num_outputs = num_repeats * 2
- self.run_core_tests(lambda: self._build_dataset(num_repeats),
- lambda: self._build_dataset(num_repeats // 2),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
deleted file mode 100644
index a10f85263a..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the StatsDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the
-# transformation `stats_ops.set_stats_aggregator`, since we don't support
-# serializing StatsAggregator yet.
-class StatsDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset_bytes_stats(self, num_elements):
- return dataset_ops.Dataset.range(num_elements).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
- stats_ops.bytes_produced_stats("bytes_produced"))
-
- def test_bytes_produced_stats_invalid_tag_shape(self):
- with self.assertRaisesRegexp(
- ValueError, "Shape must be rank 0 but is rank 1"):
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: dataset_ops.Dataset.range(100).apply(
- stats_ops.bytes_produced_stats(["bytes_produced"])),
- None, 100)
- # pylint: enable=g-long-lambda
-
- def testBytesStatsDatasetSaveableCore(self):
- num_outputs = 100
- self.run_core_tests(
- lambda: self._build_dataset_bytes_stats(num_outputs),
- lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
-
- def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
- return dataset_ops.Dataset.range(num_elements).apply(
- stats_ops.latency_stats(tag))
-
- def _build_dataset_multiple_tags(self,
- num_elements,
- tag1="record_latency",
- tag2="record_latency_2"):
- return dataset_ops.Dataset.range(num_elements).apply(
- stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
-
- def test_latency_stats_invalid_tag_shape(self):
- with self.assertRaisesRegexp(
- ValueError, "Shape must be rank 0 but is rank 1"):
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats(["record_latency", "record_latency_2"])),
- None, 100)
- # pylint: enable=g-long-lambda
-
- def testLatencyStatsDatasetSaveableCore(self):
- num_outputs = 100
-
- self.run_core_tests(
- lambda: self._build_dataset_latency_stats(num_outputs),
- lambda: self._build_dataset_latency_stats(num_outputs // 10),
- num_outputs)
-
- self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
- None, num_outputs)
-
- tag1 = "record_latency"
- tag2 = "record_latency"
- self.run_core_tests(
- lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
- None, num_outputs)
-
- def _build_dataset_stats_aggregator(self):
- stats_aggregator = stats_ops.StatsAggregator()
- return dataset_ops.Dataset.range(10).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
-
- def test_set_stats_aggregator_not_support_checkpointing(self):
- with self.assertRaisesRegexp(errors.UnimplementedError,
- "does not support checkpointing"):
- self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
deleted file mode 100644
index 2483787f44..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the TextLineDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.platform import test
-
-
-class TextLineDatasetSerializationTest(
- reader_dataset_ops_test_base.TextLineDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self, test_filenames, compression_type=None):
- return core_readers.TextLineDataset(
- test_filenames, compression_type=compression_type, buffer_size=10)
-
- def testTextLineCore(self):
- compression_types = [None, "GZIP", "ZLIB"]
- num_files = 5
- lines_per_file = 5
- num_outputs = num_files * lines_per_file
- for compression_type in compression_types:
- test_filenames = self._createFiles(
- num_files,
- lines_per_file,
- crlf=True,
- compression_type=compression_type)
- # pylint: disable=cell-var-from-loop
- self.run_core_tests(
- lambda: self._build_iterator_graph(test_filenames, compression_type),
- lambda: self._build_iterator_graph(test_filenames), num_outputs)
- # pylint: enable=cell-var-from-loop
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
deleted file mode 100644
index 55a6257a27..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
+++ /dev/null
@@ -1,99 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the TFRecordDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gzip
-import os
-import zlib
-
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.platform import test
-
-
-class TFRecordDatasetSerializationTest(
- reader_dataset_ops_test_base.TFRecordDatasetTestBase,
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_iterator_graph(self,
- num_epochs,
- batch_size=1,
- compression_type=None,
- buffer_size=None):
- filenames = self._createFiles()
- if compression_type == "ZLIB":
- zlib_files = []
- for i, fn in enumerate(filenames):
- with open(fn, "rb") as f:
- cdata = zlib.compress(f.read())
- zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
- with open(zfn, "wb") as f:
- f.write(cdata)
- zlib_files.append(zfn)
- filenames = zlib_files
-
- elif compression_type == "GZIP":
- gzip_files = []
- for i, fn in enumerate(self.test_filenames):
- with open(fn, "rb") as f:
- gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
- with gzip.GzipFile(gzfn, "wb") as gzf:
- gzf.write(f.read())
- gzip_files.append(gzfn)
- filenames = gzip_files
-
- return core_readers.TFRecordDataset(
- filenames, compression_type,
- buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
-
- def testTFRecordWithoutBufferCore(self):
- num_epochs = 5
- batch_size = num_epochs
- num_outputs = num_epochs * self._num_files * self._num_records // batch_size
- # pylint: disable=g-long-lambda
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, batch_size,
- buffer_size=0),
- lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
- num_outputs)
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
- num_outputs * batch_size)
- # pylint: enable=g-long-lambda
-
- def testTFRecordWithBufferCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
- lambda: self._build_iterator_graph(num_epochs * 2),
- num_outputs)
-
- def testTFRecordWithCompressionCore(self):
- num_epochs = 5
- num_outputs = num_epochs * self._num_files * self._num_records
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
- lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
- self.run_core_tests(
- lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
- lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
deleted file mode 100644
index b2a5a8a20d..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the UnbatchDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class UnbatchDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
- components = (
- np.arange(tensor_slice_len),
- np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
- np.array(multiplier) * np.arange(tensor_slice_len))
-
- return dataset_ops.Dataset.from_tensor_slices(components).batch(
- batch_size).apply(batching.unbatch())
-
- def testCore(self):
- tensor_slice_len = 8
- batch_size = 2
- num_outputs = tensor_slice_len
- self.run_core_tests(
- lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
- lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
- num_outputs)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
deleted file mode 100644
index 22f15b8846..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the UniqueDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import unique
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class UniqueDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testUnique(self):
-
- def build_dataset(num_elements, unique_elem_range):
- return dataset_ops.Dataset.range(num_elements).map(
- lambda x: x % unique_elem_range).apply(unique.unique())
-
- self.run_core_tests(lambda: build_dataset(200, 100),
- lambda: build_dataset(40, 100), 100)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
deleted file mode 100644
index 340a6ff72e..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
+++ /dev/null
@@ -1,54 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the ZipDataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-
-class ZipDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def _build_dataset(self, arr):
- components = [
- np.tile(np.array([[1], [2], [3], [4]]), 20),
- np.tile(np.array([[12], [13], [14], [15]]), 22),
- np.array(arr)
- ]
- datasets = [
- dataset_ops.Dataset.from_tensor_slices(component)
- for component in components
- ]
- return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
-
- def testCore(self):
- # Equal length components
- arr = [37.0, 38.0, 39.0, 40.0]
- num_outputs = len(arr)
- self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
- # Variable length components
- diff_size_arr = [1.0, 2.0]
- self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
- lambda: self._build_dataset(arr), 2)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
deleted file mode 100644
index c97002a255..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import shuffle_ops
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-
-
-class ShuffleAndRepeatTest(test_base.DatasetTestBase):
-
- def _build_ds(self, seed, count=5, num_elements=20):
- return dataset_ops.Dataset.range(num_elements).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
-
- def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
- get_next = ds_fn().make_one_shot_iterator().get_next()
- outputs = []
- with self.cached_session() as sess:
- for _ in range(num_outputs):
- outputs.append(sess.run(get_next))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- return outputs
-
- def testCorrectOutput(self):
- output = self._gen_outputs(lambda: self._build_ds(10), 100)
- self.assertSequenceEqual(
- sorted(output), sorted(
- np.array([range(20) for _ in range(5)]).flatten()))
- for i in range(5):
- self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
-
- def testReshuffling(self):
- # Check that the output orders of different epochs are indeed different.
- output = self._gen_outputs(lambda: self._build_ds(10), 100)
- for i in range(4):
- epoch1 = output[i * 20:(i + 1) * 20]
- epoch2 = output[(i + 1) * 20:(i + 2) * 20]
- self.assertNotEqual(epoch1, epoch2)
-
- def testSameOrderForSameSeeds(self):
- output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
- output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
- self.assertEqual(output1, output2)
-
- def testDifferentOrderForDifferentSeeds(self):
- output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
- output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
- self.assertNotEqual(output1, output2)
- self.assertEqual(sorted(output1), sorted(output2))
-
- def testCountNone(self):
- output1 = self._gen_outputs(
- lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
- output2 = self._gen_outputs(
- lambda: self._build_ds(20, count=None), 100, verify_exhausted=False)
- self.assertNotEqual(output1, output2)
- self.assertEqual(sorted(output1), sorted(output2))
-
- def testCountMinusOne(self):
- output1 = self._gen_outputs(
- lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
- output2 = self._gen_outputs(
- lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False)
- self.assertNotEqual(output1, output2)
- self.assertEqual(sorted(output1), sorted(output2))
-
- def testInfiniteOutputs(self):
- # Asserting the iterator is exhausted after producing 100 items should fail.
- with self.assertRaises(AssertionError):
- self._gen_outputs(lambda: self._build_ds(10, count=None), 100)
- with self.assertRaises(AssertionError):
- self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
-
- def testInfiniteEmpty(self):
- with self.assertRaises(errors.OutOfRangeError):
- self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
- 100)
- with self.assertRaises(errors.OutOfRangeError):
- self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
- 100)
-
- def testLargeBufferSize(self):
- with ops.Graph().as_default() as g:
- ds = dataset_ops.Dataset.range(20).apply(
- shuffle_ops.shuffle_and_repeat(buffer_size=21))
- get_next_op = ds.make_one_shot_iterator().get_next()
- with self.session(graph=g) as sess:
- sess.run(get_next_op)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
deleted file mode 100644
index 52823d3fca..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ /dev/null
@@ -1,590 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for experimental sql input op."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
-
- # Test that SqlDataset can read from a database table.
- def testReadResultSet(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string), 2)
- with self.cached_session() as sess:
- for _ in range(2): # Run twice to verify statelessness of db operations.
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
- for _ in range(2): # Dataset is repeated. See setUp.
- self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
- self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that SqlDataset works on a join query.
- def testReadResultSetJoinQuery(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT students.first_name, state, motto FROM students "
- "INNER JOIN people "
- "ON students.first_name = people.first_name "
- "AND students.last_name = people.last_name"
- })
- self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that SqlDataset can read a database entry with a null-terminator
- # in the middle of the text and place the entry in a `string` tensor.
- def testReadResultSetNullTerminator(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, favorite_nonsense_word "
- "FROM students ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
- self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that SqlDataset works when used on two different queries.
- # Because the output types of the dataset must be determined at graph-creation
- # time, the two queries must have the same number and types of columns.
- def testReadResultSetReuseSqlDataset(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
- self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, state FROM people "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
- self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that an `OutOfRangeError` is raised on the first call to
- # `get_next_str_only` if result set is empty.
- def testReadEmptyResultSet(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name, motto FROM students "
- "WHERE first_name = 'Nonexistent'"
- })
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that an error is raised when `driver_name` is invalid.
- def testReadResultSetWithInvalidDriverName(self):
- init_op = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))[0]
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(
- init_op,
- feed_dict={
- self.driver_name: "sqlfake",
- self.query: "SELECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
-
- # Test that an error is raised when a column name in `query` is nonexistent
- def testReadResultSetWithInvalidColumnName(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, fake_column FROM students "
- "ORDER BY first_name DESC"
- })
- with self.assertRaises(errors.UnknownError):
- sess.run(get_next)
-
- # Test that an error is raised when there is a syntax error in `query`.
- def testReadResultSetOfQueryWithSyntaxError(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELEmispellECT first_name, last_name, motto FROM students "
- "ORDER BY first_name DESC"
- })
- with self.assertRaises(errors.UnknownError):
- sess.run(get_next)
-
- # Test that an error is raised when the number of columns in `query`
- # does not match the length of `output_types`.
- def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, last_name FROM students "
- "ORDER BY first_name DESC"
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- # Test that no results are returned when `query` is an insert query rather
- # than a select query. In particular, the error refers to the number of
- # output types passed to the op not matching the number of columns in the
- # result set of the query (namely, 0 for an insert statement.)
- def testReadResultSetOfInsertQuery(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.string))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "INSERT INTO students (first_name, last_name, motto) "
- "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in an `int8` tensor.
- def testReadResultSetInt8(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int8` tensor.
- def testReadResultSetInt8NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
- dtypes.int8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income, favorite_negative_number "
- "FROM students "
- "WHERE first_name = 'John' ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0, -2), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int8` tensor.
- def testReadResultSetInt8MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT desk_number, favorite_negative_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((9, -2), sess.run(get_next))
- # Max and min values of int8
- self.assertEqual((127, -128), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in an `int16` tensor.
- def testReadResultSetInt16(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int16` tensor.
- def testReadResultSetInt16NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
- dtypes.int16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income, favorite_negative_number "
- "FROM students "
- "WHERE first_name = 'John' ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0, -2), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int16` tensor.
- def testReadResultSetInt16MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, favorite_medium_sized_number "
- "FROM students ORDER BY first_name DESC"
- })
- # Max value of int16
- self.assertEqual((b"John", 32767), sess.run(get_next))
- # Min value of int16
- self.assertEqual((b"Jane", -32768), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in an `int32` tensor.
- def testReadResultSetInt32(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int32` tensor.
- def testReadResultSetInt32NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0), sess.run(get_next))
- self.assertEqual((b"Jane", -20000), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int32` tensor.
- def testReadResultSetInt32MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, favorite_number FROM students "
- "ORDER BY first_name DESC"
- })
- # Max value of int32
- self.assertEqual((b"John", 2147483647), sess.run(get_next))
- # Min value of int32
- self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
- # table and place it in an `int32` tensor.
- def testReadResultSetInt32VarCharColumnAsInt(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, school_id FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 123), sess.run(get_next))
- self.assertEqual((b"Jane", 1000), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table
- # and place it in an `int64` tensor.
- def testReadResultSetInt64(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a negative or 0-valued integer from a
- # SQLite database table and place it in an `int64` tensor.
- def testReadResultSetInt64NegativeAndZero(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, income FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 0), sess.run(get_next))
- self.assertEqual((b"Jane", -20000), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a large (positive or negative) integer from
- # a SQLite database table and place it in an `int64` tensor.
- def testReadResultSetInt64MaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, favorite_big_number FROM students "
- "ORDER BY first_name DESC"
- })
- # Max value of int64
- self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
- # Min value of int64
- self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table and
- # place it in a `uint8` tensor.
- def testReadResultSetUInt8(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read the minimum and maximum uint8 values from a
- # SQLite database table and place them in `uint8` tensors.
- def testReadResultSetUInt8MinAndMaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, brownie_points FROM students "
- "ORDER BY first_name DESC"
- })
- # Min value of uint8
- self.assertEqual((b"John", 0), sess.run(get_next))
- # Max value of uint8
- self.assertEqual((b"Jane", 255), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer from a SQLite database table
- # and place it in a `uint16` tensor.
- def testReadResultSetUInt16(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, desk_number FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", 9), sess.run(get_next))
- self.assertEqual((b"Jane", 127), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read the minimum and maximum uint16 values from a
- # SQLite database table and place them in `uint16` tensors.
- def testReadResultSetUInt16MinAndMaxValues(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, account_balance FROM students "
- "ORDER BY first_name DESC"
- })
- # Min value of uint16
- self.assertEqual((b"John", 0), sess.run(get_next))
- # Max value of uint16
- self.assertEqual((b"Jane", 65535), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
- # SQLite database table and place them as `True` and `False` respectively
- # in `bool` tensors.
- def testReadResultSetBool(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, registration_complete FROM students "
- "ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", True), sess.run(get_next))
- self.assertEqual((b"Jane", False), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
- # from a SQLite database table and place it as `True` in a `bool` tensor.
- def testReadResultSetBoolNotZeroOrOne(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query: "SELECT first_name, favorite_medium_sized_number "
- "FROM students ORDER BY first_name DESC"
- })
- self.assertEqual((b"John", True), sess.run(get_next))
- self.assertEqual((b"Jane", True), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a float from a SQLite database table
- # and place it in a `float64` tensor.
- def testReadResultSetFloat64(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.float64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, victories FROM townspeople "
- "ORDER BY first_name"
- })
- self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
- self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a float from a SQLite database table beyond
- # the precision of 64-bit IEEE, without throwing an error. Test that
- # `SqlDataset` identifies such a value as equal to itself.
- def testReadResultSetFloat64OverlyPrecise(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.float64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, accolades FROM townspeople "
- "ORDER BY first_name"
- })
- self.assertEqual(
- (b"George", b"Washington",
- 1331241.321342132321324589798264627463827647382647382643874),
- sess.run(get_next))
- self.assertEqual(
- (b"John", b"Adams",
- 1331241321342132321324589798264627463827647382647382643874.0),
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Test that `SqlDataset` can read a float from a SQLite database table,
- # representing the largest integer representable as a 64-bit IEEE float
- # such that the previous integer is also representable as a 64-bit IEEE float.
- # Test that `SqlDataset` can distinguish these two numbers.
- def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
- init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
- dtypes.float64))
- with self.cached_session() as sess:
- sess.run(
- init_op,
- feed_dict={
- self.query:
- "SELECT first_name, last_name, triumphs FROM townspeople "
- "ORDER BY first_name"
- })
- self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
- sess.run(get_next))
- self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
deleted file mode 100644
index 319a2ea263..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing SqlDataset."""
-
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import sqlite3
-
-from tensorflow.contrib.data.python.ops import readers
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class SqlDatasetTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing SqlDataset."""
-
- def _createSqlDataset(self, output_types, num_repeats=1):
- dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
- self.query, output_types).repeat(num_repeats)
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return init_op, get_next
-
- def setUp(self):
- self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
- self.driver_name = array_ops.placeholder_with_default(
- array_ops.constant("sqlite", dtypes.string), shape=[])
- self.query = array_ops.placeholder(dtypes.string, shape=[])
-
- conn = sqlite3.connect(self.data_source_name)
- c = conn.cursor()
- c.execute("DROP TABLE IF EXISTS students")
- c.execute("DROP TABLE IF EXISTS people")
- c.execute("DROP TABLE IF EXISTS townspeople")
- c.execute(
- "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
- "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
- "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
- "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
- "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
- "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
- "account_balance INTEGER, registration_complete INTEGER)")
- c.executemany(
- "INSERT INTO students (first_name, last_name, motto, school_id, "
- "favorite_nonsense_word, desk_number, income, favorite_number, "
- "favorite_big_number, favorite_negative_number, "
- "favorite_medium_sized_number, brownie_points, account_balance, "
- "registration_complete) "
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
- [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
- 9223372036854775807, -2, 32767, 0, 0, 1),
- ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
- -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
- c.execute(
- "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
- "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
- c.executemany(
- "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
- [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
- "California")])
- c.execute(
- "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
- "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
- "FLOAT, accolades FLOAT, triumphs FLOAT)")
- c.executemany(
- "INSERT INTO townspeople (first_name, last_name, victories, "
- "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
- [("George", "Washington", 20.00,
- 1331241.321342132321324589798264627463827647382647382643874,
- 9007199254740991.0),
- ("John", "Adams", -19.95,
- 1331241321342132321324589798264627463827647382647382643874.0,
- 9007199254740992.0)])
- conn.commit()
- conn.close()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
deleted file mode 100644
index be8ae5e955..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ /dev/null
@@ -1,253 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline statistics gathering ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import test
-
-
-class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
-
- def testBytesProduced(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
- stats_ops.bytes_produced_stats("bytes_produced")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- expected_sum = 0.0
- for i in range(100):
- self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
- expected_sum += i * 8.0
- self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
- self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
-
- def testLatencyStats(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
-
- def testPrefetchBufferUtilization(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
- -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
- float(i + 1))
- self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
- self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
- self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
- 0, 1)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- summary_str = sess.run(summary_t)
- self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
- 100)
-
- def testPrefetchBufferScalars(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(10).map(
- lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
- 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertAllEqual(
- np.array([i] * i, dtype=np.int64), sess.run(next_element))
- summary_str = sess.run(summary_t)
- self._assertSummaryHasScalarValue(summary_str,
- "Prefetch::buffer_capacity", 0)
- self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
- 0)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testFilteredElementsStats(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(101).filter(
- lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.test_session() as sess:
- sess.run(iterator.initializer)
- for i in range(34):
- self.assertEqual(i * 3, sess.run(next_element))
- if i is not 0:
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::dropped_elements", 67.0)
- self._assertSummaryHasScalarValue(
- sess.run(summary_t), "Filter::filtered_elements", 34.0)
-
- def testReinitialize(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- for j in range(5):
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", (j + 1) * 100.0)
-
- def testNoAggregatorRegistered(self):
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency"))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMultipleTags(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.latency_stats("record_latency_2")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(i + 1))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency_2", float(i + 1))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency_2", 100.0)
-
- def testRepeatedTags(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(100):
- self.assertEqual(i, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(2 * (i + 1)))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-
- def testMultipleIteratorsSameAggregator(self):
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = dataset_ops.Dataset.range(100).apply(
- stats_ops.latency_stats("record_latency")).apply(
- stats_ops.set_stats_aggregator(stats_aggregator))
- iterator_0 = dataset.make_initializable_iterator()
- iterator_1 = dataset.make_initializable_iterator()
- next_element = iterator_0.get_next() + iterator_1.get_next()
- summary_t = stats_aggregator.get_summary()
-
- with self.cached_session() as sess:
- sess.run([iterator_0.initializer, iterator_1.initializer])
- for i in range(100):
- self.assertEqual(i * 2, sess.run(next_element))
- self._assertSummaryHasCount(
- sess.run(summary_t), "record_latency", float(2 * (i + 1)))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
- self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
deleted file mode 100644
index 80f2625927..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ /dev/null
@@ -1,71 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing the input pipeline statistics gathering ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-
-from tensorflow.core.framework import summary_pb2
-from tensorflow.python.data.kernel_tests import test_base
-
-
-class StatsDatasetTestBase(test_base.DatasetTestBase):
- """Base class for testing statistics gathered in `StatsAggregator`."""
-
- def _assertSummaryContains(self, summary_str, tag):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasCount(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.num)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertLessEqual(min_value, value.histo.min)
- self.assertGreaterEqual(max_value, value.histo.max)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasSum(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.histo.sum)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
-
- def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
- summary_proto = summary_pb2.Summary()
- summary_proto.ParseFromString(summary_str)
- for value in summary_proto.value:
- if tag == value.tag:
- self.assertEqual(expected_value, value.simple_value)
- return
- self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
deleted file mode 100644
index 08de3a9143..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline statistics gathering ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import threading
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import script_ops
-from tensorflow.python.platform import test
-
-
-class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
- parameterized.TestCase):
-
- @parameterized.named_parameters(
- ("1", 1, None),
- ("2", 2, None),
- ("3", 4, None),
- ("4", 8, None),
- ("5", 16, None),
- ("6", 4, -1),
- ("7", 4, 0),
- ("8", 4, 1),
- ("9", 4, 4),
- )
- def testNumThreads(self, num_threads, max_intra_op_parallelism):
-
- def get_thread_id(_):
- # Python creates a dummy thread object to represent the current
- # thread when called from an "alien" thread (such as a
- # `PrivateThreadPool` thread in this case). It does not include
- # the TensorFlow-given display name, but it has a unique
- # identifier that maps one-to-one with the underlying OS thread.
- return np.array(threading.current_thread().ident).astype(np.int64)
-
- dataset = (
- dataset_ops.Dataset.range(1000).map(
- lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
- num_parallel_calls=32).apply(unique.unique()))
-
- dataset = threadpool.override_threadpool(
- dataset,
- threadpool.PrivateThreadPool(
- num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name="private_thread_pool_%d" % num_threads))
-
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- thread_ids = []
- try:
- while True:
- thread_ids.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
- self.assertEqual(len(thread_ids), len(set(thread_ids)))
- self.assertGreater(len(thread_ids), 0)
- # NOTE(mrry): We don't control the thread pool scheduling, and
- # so cannot guarantee that all of the threads in the pool will
- # perform work.
- self.assertLessEqual(len(thread_ids), num_threads)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
deleted file mode 100644
index 8856ce5afb..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.data.python.ops import unique
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class UniqueDatasetTest(test_base.DatasetTestBase):
-
- def _testSimpleHelper(self, dtype, test_cases):
- """Test the `unique()` transformation on a list of test cases.
-
- Args:
- dtype: The `dtype` of the elements in each test case.
- test_cases: A list of pairs of lists. The first component is the test
- input that will be passed to the transformation; the second component
- is the expected sequence of outputs from the transformation.
- """
-
- # The `current_test_case` will be updated when we loop over `test_cases`
- # below; declare it here so that the generator can capture it once.
- current_test_case = []
- dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
- dtype).apply(unique.unique())
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_case, expected in test_cases:
- current_test_case = test_case
- sess.run(iterator.initializer)
- for element in expected:
- if dtype == dtypes.string:
- element = compat.as_bytes(element)
- self.assertAllEqual(element, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testSimpleInt(self):
- for dtype in [dtypes.int32, dtypes.int64]:
- self._testSimpleHelper(dtype, [
- ([], []),
- ([1], [1]),
- ([1, 1, 1, 1, 1, 1, 1], [1]),
- ([1, 2, 3, 4], [1, 2, 3, 4]),
- ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
- ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
- ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
- ])
-
- def testSimpleString(self):
- self._testSimpleHelper(dtypes.string, [
- ([], []),
- (["hello"], ["hello"]),
- (["hello", "hello", "hello"], ["hello"]),
- (["hello", "world"], ["hello", "world"]),
- (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
- ])
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
deleted file mode 100644
index 79134c7bc6..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ /dev/null
@@ -1,527 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def _structuredDataset(self, structure, shape, dtype):
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredElement(self, structure, shape, dtype):
- if structure is None:
- return array_ops.zeros(shape, dtype=dtype)
- else:
- return tuple([
- self._structuredElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- def _assertEqual(self, xs, ys):
- self.assertEqual(type(xs), type(ys))
- if isinstance(xs, tuple) and isinstance(ys, tuple):
- self.assertEqual(len(xs), len(ys))
- for x, y in zip(xs, ys):
- self._assertEqual(x, y)
- elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray):
- self.assertAllEqual(xs, ys)
- else:
- self.assertEqual(xs, ys)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetFlatMap(self, structure, shape, dtype):
- """Tests windowing by chaining it with flat map.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return args[0]
- return dataset_ops.Dataset.zip(
- tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).flat_map(fn)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(self._structuredElement(structure, shape, dtype))
- for _ in range(5):
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchDense(self, structure, shape, dtype):
- """Tests batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredElement(structure, np.concatenate(
- ([5], shape), axis=0), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchDenseDynamicShape(self, shape):
- """Tests batching of dynamically shaped dense tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape_t)).repeat(5).apply(
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredElement(None, np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _make_dense_to_sparse_fn(self, is_scalar):
-
- def dense_to_sparse_scalar(tensor):
- indices = [[]]
- values = array_ops.expand_dims(tensor, 0)
- shape = []
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- def dense_to_sparse_non_scalar(tensor):
- indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool))
- values = array_ops.gather_nd(tensor, indices)
- shape = array_ops.shape(tensor, out_type=dtypes.int64)
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- if is_scalar:
- return dense_to_sparse_scalar
- return dense_to_sparse_non_scalar
-
- def _structuredSparseDataset(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- dense_to_sparse(array_ops.zeros(shape, dtype=dtype)))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredSparseDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredSparseElement(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- else:
- return tuple([
- self._structuredSparseElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchSparse(self, structure, shape, dtype):
- """Tests batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredSparseDataset(
- structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredSparseElement(structure,
- np.concatenate(([5], shape), axis=0),
- dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchSparseDynamicShape(self, shape):
- """Tests batching of dynamically shaped sparse tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map(
- self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredSparseElement(None,
- np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _structuredRaggedDataset(self, structure, shapes, dtype):
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- @parameterized.named_parameters(
- ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
- )
- def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- structure,
- np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1], [2], [3]]), [-1]),
- ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
- """Tests padded batching of dynamically shaped dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- None, np.concatenate((np.int32([len(shapes)]), expected_shape)),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1]]), np.int32([0])),
- ("2", np.int32([[10], [20]]), np.int32([15])),
- )
- def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def _structuredRaggedSparseDataset(self, structure, shapes, dtype):
-
- def map_fn(shape):
- dense_to_sparse = self._make_dense_to_sparse_fn(False)
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn)
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedSparseDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- def _structuredRaggedSparseElement(self, structure, shapes, dtype,
- padded_shape):
- if structure is None:
- dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- values = []
- for shape in shapes:
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- padded_sparse = sparse_tensor.SparseTensor(sparse.indices,
- sparse.values, dense_shape)
- reshaped_sparse = sparse_ops.sparse_reshape(
- padded_sparse,
- array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0))
- values.append(reshaped_sparse)
- return sparse_ops.sparse_concat(0, values)
- else:
- return tuple([
- self._structuredRaggedSparseElement(substructure, shapes, dtype,
- padded_shape)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
- )
- def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedSparseDataset(
- structure, shapes, dtype).apply(grouping.window_dataset(
- len(shapes))).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredRaggedSparseElement(structure, shapes, dtype,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1], [2], [3]]), [-1]),
- ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
- padded_shape):
- """Tests padded batching of dynamically shaped sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected = sess.run(
- self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1]]), [0]),
- ("2", np.int64([[10], [20]]), [15]),
- )
- def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
deleted file mode 100644
index fca546a570..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.contrib.data.python.ops import writers
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.framework import dtypes
-from tensorflow.python.lib.io import python_io
-from tensorflow.python.lib.io import tf_record
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class TFRecordWriterTest(test_base.DatasetTestBase):
-
- def setUp(self):
- super(TFRecordWriterTest, self).setUp()
- self._num_records = 7
- self.filename = array_ops.placeholder(dtypes.string, shape=[])
- self.compression_type = array_ops.placeholder_with_default("", shape=[])
-
- input_dataset = readers.TFRecordDataset([self.filename],
- self.compression_type)
- self.writer = writers.TFRecordWriter(
- self._outputFilename(), self.compression_type).write(input_dataset)
-
- def _record(self, i):
- return compat.as_bytes("Record %d" % (i))
-
- def _createFile(self, options=None):
- filename = self._inputFilename()
- writer = python_io.TFRecordWriter(filename, options)
- for i in range(self._num_records):
- writer.write(self._record(i))
- writer.close()
- return filename
-
- def _inputFilename(self):
- return os.path.join(self.get_temp_dir(), "tf_record.in.txt")
-
- def _outputFilename(self):
- return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
-
- def testWrite(self):
- with self.cached_session() as sess:
- sess.run(
- self.writer, feed_dict={
- self.filename: self._createFile(),
- })
- for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
- self.assertAllEqual(self._record(i), r)
-
- def testWriteZLIB(self):
- options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
- with self.cached_session() as sess:
- sess.run(
- self.writer,
- feed_dict={
- self.filename: self._createFile(options),
- self.compression_type: "ZLIB",
- })
- for i, r in enumerate(
- tf_record.tf_record_iterator(self._outputFilename(), options=options)):
- self.assertAllEqual(self._record(i), r)
-
- def testWriteGZIP(self):
- options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
- with self.cached_session() as sess:
- sess.run(
- self.writer,
- feed_dict={
- self.filename: self._createFile(options),
- self.compression_type: "GZIP",
- })
- for i, r in enumerate(
- tf_record.tf_record_iterator(self._outputFilename(), options=options)):
- self.assertAllEqual(self._record(i), r)
-
- def testFailDataset(self):
- with self.assertRaises(TypeError):
- writers.TFRecordWriter(self._outputFilename(),
- self.compression_type).write("whoops")
-
- def testFailDType(self):
- input_dataset = dataset_ops.Dataset.from_tensors(10)
- with self.assertRaises(TypeError):
- writers.TFRecordWriter(self._outputFilename(),
- self.compression_type).write(input_dataset)
-
- def testFailShape(self):
- input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
- with self.assertRaises(TypeError):
- writers.TFRecordWriter(self._outputFilename(),
- self.compression_type).write(input_dataset)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 5cd1ed542b..34dc2379d0 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -16,10 +16,7 @@ py_library(
srcs = ["counter.py"],
srcs_version = "PY2AND3",
deps = [
- ":scan_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:counter",
],
)
@@ -28,12 +25,7 @@ py_library(
srcs = ["get_single_element.py"],
srcs_version = "PY2AND3",
deps = [
- ":grouping",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
],
)
@@ -44,10 +36,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
],
)
@@ -58,15 +47,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:random_ops",
],
)
@@ -79,7 +60,6 @@ py_library(
deps = [
":batching",
":interleave_ops",
- ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
@@ -91,6 +71,7 @@ py_library(
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:convert",
@@ -106,7 +87,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
],
)
@@ -125,6 +106,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
@@ -138,8 +120,7 @@ py_library(
srcs = ["enumerate_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
],
)
@@ -148,10 +129,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:error_ops",
],
)
@@ -160,16 +138,7 @@ py_library(
srcs = ["grouping.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:grouping",
],
)
@@ -178,30 +147,7 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":random_ops",
- "//tensorflow/contrib/stateless",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
- name = "optimization",
- srcs = ["optimization.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
],
)
@@ -210,25 +156,7 @@ py_library(
srcs = ["parsing_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_library(
- name = "map_defun",
- srcs = ["map_defun.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
],
)
@@ -237,18 +165,7 @@ py_library(
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
deps = [
- ":batching",
- ":interleave_ops",
- ":scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:logging_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:resampling",
],
)
@@ -257,12 +174,7 @@ py_library(
srcs = ["scan_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
],
)
@@ -282,31 +194,11 @@ py_library(
)
py_library(
- name = "stats_ops",
- srcs = ["stats_ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
name = "threadpool",
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python/data/experimental/ops:threadpool",
],
)
@@ -317,11 +209,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:unique",
],
)
@@ -332,20 +220,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "indexed_dataset_ops",
- srcs = ["indexed_dataset_ops.py"],
- deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:writers",
],
)
@@ -353,11 +228,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
],
)
@@ -370,17 +241,14 @@ py_library(
":error_ops",
":get_single_element",
":grouping",
- ":indexed_dataset_ops",
":interleave_ops",
- ":map_defun",
- ":optimization",
":prefetching_ops",
+ ":random_ops",
":readers",
":resampling",
":scan_ops",
":shuffle_ops",
":sliding",
- ":stats_ops",
":threadpool",
":unique",
":writers",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 7a0f221284..8c60459ca8 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,134 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import get_single_element
-from tensorflow.contrib.data.python.ops import grouping
from tensorflow.contrib.framework import with_shape
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
-def batch_window(dataset):
- """Batches a window of tensors.
-
- Args:
- dataset: the input dataset.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
- """
- if isinstance(dataset.output_classes, tuple):
- raise TypeError("Input dataset expected to have a single component")
- if dataset.output_classes is ops.Tensor:
- return _batch_dense_window(dataset)
- elif dataset.output_classes is sparse_tensor.SparseTensor:
- return _batch_sparse_window(dataset)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _batch_dense_window(dataset):
- """Batches a window of dense tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return array_ops.shape(first_element)
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, array_ops.shape(value))
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat([[0], shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _batch_sparse_window(dataset):
- """Batches a window of sparse tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return first_element.dense_shape
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, value.dense_shape)
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64),
- math_ops.cast(shape, dtypes.int64)], 0))
-
- def batch_reduce_fn(state, value):
- return sparse_ops.sparse_concat(0, [state, value])
-
- def reshape_fn(value):
- return sparse_ops.sparse_reshape(
- value,
- array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(reshape_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.dense_to_sparse_batch(...)`.")
def dense_to_sparse_batch(batch_size, row_shape):
"""A transformation that batches ragged elements into `tf.SparseTensor`s.
@@ -187,201 +67,10 @@ def dense_to_sparse_batch(batch_size, row_shape):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
-
- return _apply_fn
-
-
-def padded_batch_window(dataset, padded_shape, padding_value=None):
- """Batches a window of tensors with padding.
-
- Args:
- dataset: the input dataset.
- padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
- object representing the shape to which the input elements should be padded
- prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
- `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
- maximum size of that dimension in each batch.
- padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
- padding value to use. Defaults are `0` for numeric types and the empty
- string for string types. If `dataset` contains `tf.SparseTensor`, this
- value is ignored.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
-
- Raises:
- ValueError: if invalid arguments are provided.
- """
- if not issubclass(dataset.output_classes,
- (ops.Tensor, sparse_tensor.SparseTensor)):
- raise TypeError("Input dataset expected to have a single tensor component")
- if issubclass(dataset.output_classes, (ops.Tensor)):
- return _padded_batch_dense_window(dataset, padded_shape, padding_value)
- elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
- if padding_value is not None:
- raise ValueError("Padding value not allowed for sparse tensors")
- return _padded_batch_sparse_window(dataset, padded_shape)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
- """Batches a window of dense tensors with padding."""
-
- padded_shape = math_ops.cast(
- convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return padded_shape
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(array_ops.shape(value), padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ",
- array_ops.shape(value), padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, array_ops.shape(value))
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- if padding_value is None:
- if dataset.output_types == dtypes.string:
- padding_value = ""
- elif dataset.output_types == dtypes.bool:
- padding_value = False
- elif dataset.output_types == dtypes.variant:
- raise TypeError("Unable to create padding for field of type 'variant'")
- else:
- padding_value = 0
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat(
- [np.array([0], dtype=np.int32), padded_shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- def pad_fn(value):
- shape = array_ops.shape(value)
- left = array_ops.zeros_like(shape)
- right = padded_shape - shape
- return array_ops.pad(
- value, array_ops.stack([left, right], 1), constant_values=padding_value)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(pad_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _padded_batch_sparse_window(dataset, padded_shape):
- """Batches a window of sparse tensors with padding."""
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return convert.partial_shape_to_tensor(padded_shape)
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(value.dense_shape, padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ", value.dense_shape,
- padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, value.dense_shape)
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
- 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64), padded_shape], 0))
-
- def batch_reduce_fn(state, value):
- padded_value = sparse_tensor.SparseTensor(
- indices=value.indices, values=value.values, dense_shape=padded_shape)
- reshaped_value = sparse_ops.sparse_reshape(
- padded_value,
- array_ops.concat(
- [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
- return sparse_ops.sparse_concat(0, [state, reshaped_value])
-
- reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
- """A dataset that splits the elements of its input into multiple elements."""
-
- def __init__(self, input_dataset):
- """See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__(input_dataset)
- flat_shapes = nest.flatten(input_dataset.output_shapes)
- if any(s.ndims == 0 for s in flat_shapes):
- raise ValueError("Cannot unbatch an input with scalar components.")
- known_batch_dim = tensor_shape.Dimension(None)
- for s in flat_shapes:
- try:
- known_batch_dim = known_batch_dim.merge_with(s[0])
- except ValueError:
- raise ValueError("Cannot unbatch an input whose components have "
- "different batch sizes.")
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.unbatch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda s: s[1:],
- self._input_dataset.output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return batching.dense_to_sparse_batch(batch_size, row_shape)
+@deprecation.deprecated(None, "Use `tf.data.experimental.unbatch()`.")
def unbatch():
"""Splits elements of a dataset into multiple elements on the batch dimension.
@@ -403,39 +92,7 @@ def unbatch():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- if not sparse.any_sparse(dataset.output_classes):
- return _UnbatchDataset(dataset)
-
- # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
- # are normalized to the rank-1 dense representation, so that the
- # sparse-oblivious unbatching logic will slice them
- # appropriately. This leads to a somewhat inefficient re-encoding step
- # for all SparseTensor components.
- # TODO(mrry): Consider optimizing this in future
- # if it turns out to be a bottleneck.
- def normalize(arg, *rest):
- if rest:
- return sparse.serialize_many_sparse_tensors((arg,) + rest)
- else:
- return sparse.serialize_many_sparse_tensors(arg)
-
- normalized_dataset = dataset.map(normalize)
-
- # NOTE(mrry): Our `map()` has lost information about the sparseness
- # of any SparseTensor components, so re-apply the structure of the
- # original dataset.
- restructured_dataset = _RestructuredDataset(
- normalized_dataset,
- dataset.output_types,
- dataset.output_shapes,
- dataset.output_classes,
- allow_unsafe_cast=True)
- return _UnbatchDataset(restructured_dataset)
-
- return _apply_fn
+ return batching.unbatch()
@deprecation.deprecated(
@@ -514,135 +171,8 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
-
- def __init__(self, input_dataset, batch_size, row_shape):
- """See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
- if not isinstance(input_dataset.output_types, dtypes.DType):
- raise TypeError("DenseToSparseDataset requires an input whose elements "
- "have a single component, whereas the input has %r." %
- input_dataset.output_types)
- self._input_dataset = input_dataset
- self._batch_size = batch_size
- self._row_shape = row_shape
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.dense_to_sparse_batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._batch_size,
- row_shape=convert.partial_shape_to_tensor(self._row_shape),
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return sparse_tensor.SparseTensor
-
- @property
- def output_shapes(self):
- return tensor_shape.vector(None).concatenate(self._row_shape)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
- """An internal helper for changing the structure and shape of a dataset."""
-
- def __init__(self,
- dataset,
- output_types,
- output_shapes=None,
- output_classes=None,
- allow_unsafe_cast=False):
- """Creates a new dataset with the given output types and shapes.
-
- The given `dataset` must have a structure that is convertible:
- * `dataset.output_types` must be the same as `output_types` module nesting.
- * Each shape in `dataset.output_shapes` must be compatible with each shape
- in `output_shapes` (if given).
-
- Note: This helper permits "unsafe casts" for shapes, equivalent to using
- `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
- Args:
- dataset: A `Dataset` object.
- output_types: A nested structure of `tf.DType` objects.
- output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
- If omitted, the shapes will be inherited from `dataset`.
- output_classes: (Optional.) A nested structure of class types.
- If omitted, the class types will be inherited from `dataset`.
- allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
- reported output types and shapes of the restructured dataset, e.g. to
- switch a sparse tensor represented as `tf.variant` to its user-visible
- type and shape.
-
- Raises:
- ValueError: If either `output_types` or `output_shapes` is not compatible
- with the structure of `dataset`.
- """
- super(_RestructuredDataset, self).__init__(dataset)
- self._input_dataset = dataset
-
- if not allow_unsafe_cast:
- # Validate that the types are compatible.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- flat_original_types = nest.flatten(dataset.output_types)
- flat_new_types = nest.flatten(output_types)
- if flat_original_types != flat_new_types:
- raise ValueError(
- "Dataset with output types %r cannot be restructured to have "
- "output types %r" % (dataset.output_types, output_types))
-
- self._output_types = output_types
-
- if output_shapes is None:
- # Inherit shapes from the original `dataset`.
- self._output_shapes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_shapes))
- else:
- if not allow_unsafe_cast:
- # Validate that the shapes are compatible.
- nest.assert_same_structure(output_types, output_shapes)
- flat_original_shapes = nest.flatten(dataset.output_shapes)
- flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
- for original_shape, new_shape in zip(flat_original_shapes,
- flat_new_shapes):
- if not original_shape.is_compatible_with(new_shape):
- raise ValueError(
- "Dataset with output shapes %r cannot be restructured to have "
- "incompatible output shapes %r" % (dataset.output_shapes,
- output_shapes))
- self._output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- if output_classes is None:
- # Inherit class types from the original `dataset`.
- self._output_classes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_classes))
- else:
- self._output_classes = output_classes
-
- def _as_variant_tensor(self):
- return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
-
+# TODO(b/116817045): Move this to `tf.data.experimental` when the `with_shape()`
+# function is available in the core.
def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
@@ -687,7 +217,8 @@ def assert_element_shape(expected_shapes):
def _apply_fn(dataset):
output_shapes = _merge_output_shapes(dataset.output_shapes,
expected_shapes)
- return _RestructuredDataset(
+ # pylint: disable=protected-access
+ return batching._RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
output_shapes=output_shapes,
@@ -696,49 +227,7 @@ def assert_element_shape(expected_shapes):
return _apply_fn
-class _MapAndBatchDataset(dataset_ops.MapDataset):
- """A `Dataset` that maps a function over a batch of elements."""
-
- def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
- drop_remainder):
- """See `Dataset.map()` for details."""
- super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
- self._batch_size_t = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
- self._num_parallel_calls_t = ops.convert_to_tensor(
- num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
- self._drop_remainder_t = ops.convert_to_tensor(
- drop_remainder, dtype=dtypes.bool, name="drop_remainder")
-
- self._batch_size = batch_size
- self._drop_remainder = drop_remainder
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.map_and_batch_dataset_v2(
- input_resource,
- self._map_func.captured_inputs,
- f=self._map_func,
- batch_size=self._batch_size_t,
- num_parallel_calls=self._num_parallel_calls_t,
- drop_remainder=self._drop_remainder_t,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_shapes(self):
- dim = self._batch_size if self._drop_remainder else None
- return nest.pack_sequence_as(self._output_shapes, [
- tensor_shape.vector(dim).concatenate(s)
- for s in nest.flatten(self._output_shapes)
- ])
-
- @property
- def output_types(self):
- return self._output_types
-
-
+@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch(...)`.")
def map_and_batch(map_func,
batch_size,
num_parallel_batches=None,
@@ -779,17 +268,5 @@ def map_and_batch(map_func,
ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
specified.
"""
-
- if num_parallel_batches is None and num_parallel_calls is None:
- num_parallel_calls = batch_size
- elif num_parallel_batches is not None and num_parallel_calls is None:
- num_parallel_calls = batch_size * num_parallel_batches
- elif num_parallel_batches is not None and num_parallel_calls is not None:
- raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
- "arguments are mutually exclusive.")
-
- def _apply_fn(dataset):
- return _MapAndBatchDataset(dataset, map_func, batch_size,
- num_parallel_calls, drop_remainder)
-
- return _apply_fn
+ return batching.map_and_batch(map_func, batch_size, num_parallel_batches,
+ drop_remainder, num_parallel_calls)
diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py
index 6ef65f9624..4ff5bf3e39 100644
--- a/tensorflow/contrib/data/python/ops/counter.py
+++ b/tensorflow/contrib/data/python/ops/counter.py
@@ -17,13 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import scan_ops
-
-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.Counter(...)`.")
def Counter(start=0, step=1, dtype=dtypes.int64):
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
@@ -46,8 +45,4 @@ def Counter(start=0, step=1, dtype=dtypes.int64):
Returns:
A `Dataset` of scalar `dtype` elements.
"""
- with ops.name_scope("counter"):
- start = ops.convert_to_tensor(start, dtype=dtype, name="start")
- step = ops.convert_to_tensor(step, dtype=dtype, name="step")
- return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
- scan_ops.scan(start, lambda state, _: (state + step, state)))
+ return counter.Counter(start, step, dtype)
diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py
index 490281e0d2..a21da4d3ec 100644
--- a/tensorflow/contrib/data/python/ops/enumerate_ops.py
+++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.enumerate_dataset(...)`.")
def enumerate_dataset(start=0):
"""A transformation that enumerate the elements of a dataset.
@@ -49,10 +50,4 @@ def enumerate_dataset(start=0):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
- return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
- dataset))
-
- return _apply_fn
+ return enumerate_ops.enumerate_dataset(start)
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index f962e623ee..0559a2e09c 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,10 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.ignore_errors()`.")
def ignore_errors():
"""Creates a `Dataset` from another `Dataset` and silently ignores any errors.
@@ -43,34 +44,4 @@ def ignore_errors():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _IgnoreErrorsDataset(dataset)
-
- return _apply_fn
-
-
-class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that silently ignores errors when computing its input."""
-
- def __init__(self, input_dataset):
- """See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return error_ops.ignore_errors()
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index a6713b017a..58ad9eea90 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -19,13 +19,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import get_single_element as experimental_get_single_element
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.get_single_element(...)`.")
def get_single_element(dataset):
"""Returns the single element in `dataset` as a nested structure of tensors.
@@ -61,18 +61,10 @@ def get_single_element(dataset):
InvalidArgumentError (at runtime): if `dataset` does not contain exactly
one element.
"""
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
-
- nested_ret = nest.pack_sequence_as(
- dataset.output_types, gen_dataset_ops.dataset_to_single_element(
- dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(dataset)))
- return sparse.deserialize_sparse_tensors(
- nested_ret, dataset.output_types, dataset.output_shapes,
- dataset.output_classes)
+ return experimental_get_single_element.get_single_element(dataset)
+@deprecation.deprecated(None, "Use `tf.data.Dataset.reduce(...)`.")
def reduce_dataset(dataset, reducer):
"""Returns the result of reducing the `dataset` using `reducer`.
@@ -90,11 +82,4 @@ def reduce_dataset(dataset, reducer):
if not isinstance(dataset, dataset_ops.Dataset):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- # The sentinel dataset is used in case the reduced dataset is empty.
- sentinel_dataset = dataset_ops.Dataset.from_tensors(
- reducer.finalize_func(reducer.init_func(np.int64(0))))
- reduced_dataset = dataset.apply(
- grouping.group_by_reducer(lambda x: np.int64(0), reducer))
-
- return get_single_element(
- reduced_dataset.concatenate(sentinel_dataset).take(1))
+ return dataset.reduce(reducer.init_func(np.int64(0)), reducer.reduce_func)
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 7cae33beb3..a99dc2f29a 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,20 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_reducer(...)`.")
def group_by_reducer(key_func, reducer):
"""A transformation that groups elements and performs a reduction.
@@ -52,14 +45,11 @@ def group_by_reducer(key_func, reducer):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByReducerDataset(dataset, key_func, reducer)
-
- return _apply_fn
+ return grouping.group_by_reducer(key_func, reducer)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_window(...)`.")
def group_by_window(key_func,
reduce_func,
window_size=None,
@@ -98,27 +88,12 @@ def group_by_window(key_func,
ValueError: if neither or both of {`window_size`, `window_size_func`} are
passed.
"""
- if (window_size is not None and window_size_func or
- not (window_size is not None or window_size_func)):
- raise ValueError("Must pass either window_size or window_size_func.")
-
- if window_size is not None:
-
- def constant_window_func(unused_key):
- return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
-
- window_size_func = constant_window_func
-
- assert window_size_func is not None
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
-
- return _apply_fn
+ return grouping.group_by_window(key_func, reduce_func, window_size,
+ window_size_func)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.bucket_by_sequence_length(...)`.")
def bucket_by_sequence_length(element_length_func,
bucket_boundaries,
bucket_batch_sizes,
@@ -163,342 +138,12 @@ def bucket_by_sequence_length(element_length_func,
Raises:
ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
"""
- with ops.name_scope("bucket_by_seq_length"):
- if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
- raise ValueError(
- "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
-
- batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
-
- def element_to_bucket_id(*args):
- """Return int64 id of the length bucket for this element."""
- seq_length = element_length_func(*args)
-
- boundaries = list(bucket_boundaries)
- buckets_min = [np.iinfo(np.int32).min] + boundaries
- buckets_max = boundaries + [np.iinfo(np.int32).max]
- conditions_c = math_ops.logical_and(
- math_ops.less_equal(buckets_min, seq_length),
- math_ops.less(seq_length, buckets_max))
- bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
-
- return bucket_id
-
- def window_size_fn(bucket_id):
- # The window size is set to the batch size for this bucket
- window_size = batch_sizes[bucket_id]
- return window_size
-
- def make_padded_shapes(shapes, none_filler=None):
- padded = []
- for shape in nest.flatten(shapes):
- shape = tensor_shape.TensorShape(shape)
- shape = [
- none_filler if d.value is None else d
- for d in shape
- ]
- padded.append(shape)
- return nest.pack_sequence_as(shapes, padded)
-
- def batching_fn(bucket_id, grouped_dataset):
- """Batch elements in dataset."""
- batch_size = window_size_fn(bucket_id)
- if no_padding:
- return grouped_dataset.batch(batch_size)
- none_filler = None
- if pad_to_bucket_boundary:
- err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length < max(bucket_boundaries).")
- check = check_ops.assert_less(
- bucket_id,
- constant_op.constant(len(bucket_batch_sizes) - 1,
- dtype=dtypes.int64),
- message=err_msg)
- with ops.control_dependencies([check]):
- boundaries = constant_op.constant(bucket_boundaries,
- dtype=dtypes.int64)
- bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary - 1
- shapes = make_padded_shapes(
- padded_shapes or grouped_dataset.output_shapes,
- none_filler=none_filler)
- return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
-
- def _apply_fn(dataset):
- return dataset.apply(
- group_by_window(element_to_bucket_id, batching_fn,
- window_size_func=window_size_fn))
-
- return _apply_fn
-
-
-def _map_x_dataset(map_func):
- """A transformation that maps `map_func` across its input.
-
- This transformation is similar to `tf.data.Dataset.map`, but in addition to
- supporting dense and sparse tensor inputs, it also supports dataset inputs.
-
- Args:
- map_func: A function mapping a nested structure of tensors and/or datasets
- (having shapes and types defined by `self.output_shapes` and
- `self.output_types`) to another nested structure of tensors and/or
- datasets.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _MapXDataset(dataset, map_func)
-
- return _apply_fn
-
-
-# TODO(b/115382007) Remove this once canned reducers move to core.
-def window_dataset(window_size):
- """A transformation that creates window datasets from the input dataset.
-
- The resulting datasets will contain `window_size` elements (or
- `N % window_size` for the last dataset if `window_size` does not divide the
- number of input elements `N` evenly).
-
- Args:
- window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
- consecutive elements of the input dataset to combine into a window.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- return dataset_ops.WindowDataset(
- dataset,
- size=window_size,
- shift=window_size,
- stride=1,
- drop_remainder=False)
-
- return _apply_fn
-
-
-class _GroupByReducerDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a reduction."""
-
- def __init__(self, input_dataset, key_func, reducer):
- """See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__(input_dataset)
+ return grouping.bucket_by_sequence_length(
+ element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes,
+ padding_values, pad_to_bucket_boundary, no_padding)
- self._input_dataset = input_dataset
- self._make_key_func(key_func, input_dataset)
- self._make_init_func(reducer.init_func)
- self._make_reduce_func(reducer.reduce_func, input_dataset)
- self._make_finalize_func(reducer.finalize_func)
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 tensor. "
- "Got type=%s and shape=%s"
- % (wrapped_func.output_types, wrapped_func.output_shapes))
- self._key_func = wrapped_func.function
-
- def _make_init_func(self, init_func):
- """Make wrapping Defun for init_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- init_func, "tf.contrib.data.group_by_reducer()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- self._init_func = wrapped_func.function
- self._state_classes = wrapped_func.output_classes
- self._state_shapes = wrapped_func.output_shapes
- self._state_types = wrapped_func.output_types
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
-
- # Iteratively rerun the reduce function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.group_by_reducer()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(wrapped_func.output_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, wrapped_func.output_classes))
-
- # Extract and validate type information from the returned values.
- for new_state_type, state_type in zip(
- nest.flatten(wrapped_func.output_types),
- nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, wrapped_func.output_types))
-
- # Extract shape information from the returned values.
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._reduce_func = wrapped_func.function
- self._reduce_func.add_to_graph(ops.get_default_graph())
-
- def _make_finalize_func(self, finalize_func):
- """Make wrapping Defun for finalize_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- finalize_func, "tf.contrib.data.group_by_reducer()",
- input_classes=self._state_classes, input_shapes=self._state_shapes,
- input_types=self._state_types)
- self._finalize_func = wrapped_func.function
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_reducer_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._init_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._finalize_func.captured_inputs,
- key_func=self._key_func,
- init_func=self._init_func,
- reduce_func=self._reduce_func,
- finalize_func=self._finalize_func,
- **dataset_ops.flat_structure(self))
-
-
-class _GroupByWindowDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a windowed reduction."""
-
- def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
- """See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__(input_dataset)
-
- self._input_dataset = input_dataset
-
- self._make_key_func(key_func, input_dataset)
- self._make_reduce_func(reduce_func, input_dataset)
- self._make_window_size_func(window_size_func)
-
- def _make_window_size_func(self, window_size_func):
- """Make wrapping Defun for window_size_func."""
- def window_size_func_wrapper(key):
- return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- window_size_func_wrapper, "tf.contrib.data.group_by_window()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`window_size_func` must return a single tf.int64 scalar tensor.")
- self._window_size_func = wrapped_func.function
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- def key_func_wrapper(*args):
- return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 scalar tensor.")
- self._key_func = wrapped_func.function
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
- nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.reduce_by_window()",
- input_classes=(ops.Tensor, nested_dataset),
- input_shapes=(tensor_shape.scalar(), nested_dataset),
- input_types=(dtypes.int64, nested_dataset),
- experimental_nested_dataset_support=True)
- if not isinstance(
- wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
- raise TypeError("`reduce_func` must return a `Dataset` object.")
- self._output_classes = wrapped_func.output_classes.output_classes
- self._output_types = wrapped_func.output_types.output_types
- self._output_shapes = wrapped_func.output_shapes.output_shapes
- self._reduce_func = wrapped_func.function
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._window_size_func.captured_inputs,
- key_func=self._key_func,
- reduce_func=self._reduce_func,
- window_size_func=self._window_size_func,
- **dataset_ops.flat_structure(self))
-
-
-class Reducer(object):
+class Reducer(grouping.Reducer):
"""A reducer is used for reducing a set of elements.
A reducer is represented as a tuple of the three functions:
@@ -507,58 +152,6 @@ class Reducer(object):
3) finalization function: state => result
"""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.Reducer(...)`.")
def __init__(self, init_func, reduce_func, finalize_func):
- self._init_func = init_func
- self._reduce_func = reduce_func
- self._finalize_func = finalize_func
-
- @property
- def init_func(self):
- return self._init_func
-
- @property
- def reduce_func(self):
- return self._reduce_func
-
- @property
- def finalize_func(self):
- return self._finalize_func
-
-
-class _MapXDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that maps a function over elements in its input."""
-
- def __init__(self, input_dataset, map_func):
- """See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- map_func,
- "tf.contrib.data.map_x_dataset()",
- input_dataset,
- experimental_nested_dataset_support=True)
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
- self._map_func = wrapped_func.function
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.map_dataset(
- input_t,
- self._map_func.captured_inputs,
- f=self._map_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+ super(Reducer, self).__init__(init_func, reduce_func, finalize_func)
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
deleted file mode 100644
index 9c06474a2f..0000000000
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ /dev/null
@@ -1,177 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Python wrappers for indexed datasets."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-
-
-class MaterializedIndexedDataset(object):
- """MaterializedIndexedDataset is highly experimental!
- """
-
- def __init__(self, materialized_resource, materializer, output_classes,
- output_types, output_shapes):
- self._materialized_resource = materialized_resource
- self._materializer = materializer
- self._output_classes = output_classes
- self._output_types = output_types
- self._output_shapes = output_shapes
-
- @property
- def initializer(self):
- if self._materializer is not None:
- return self._materializer
- raise ValueError("MaterializedDataset does not have a materializer")
-
- def get(self, index):
- """Get retrieves a value (or set of values) from the IndexedDataset.
-
- Args:
- index: A uint64 scalar or vector tensor with the indices to retrieve.
-
- Returns:
- A tensor containing the values corresponding to `index`.
- """
- # TODO(saeta): nest.pack_sequence_as(...)
- return ged_ops.experimental_indexed_dataset_get(
- self._materialized_resource,
- index,
- output_types=nest.flatten(
- sparse.as_dense_types(self._output_types, self._output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self._output_shapes, self._output_classes)))
-
-
-class IndexedDataset(dataset_ops.Dataset):
- """IndexedDataset is highly experimental!
- """
-
- def __init__(self):
- pass
-
- def materialize(self, shared_name=None, container=None):
- """Materialize creates a MaterializedIndexedDataset.
-
- IndexedDatasets can be combined through operations such as TBD. Therefore,
- they are only materialized when absolutely required.
-
- Args:
- shared_name: a string for the shared name to use for the resource.
- container: a string for the container to store the resource.
-
- Returns:
- A MaterializedIndexedDataset.
- """
- if container is None:
- container = ""
- if shared_name is None:
- shared_name = ""
- materialized_resource = (
- ged_ops.experimental_materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes,
- self.output_classes))))
-
- with ops.colocate_with(materialized_resource):
- materializer = ged_ops.experimental_indexed_dataset_materialize(
- self._as_variant_tensor(), materialized_resource)
- return MaterializedIndexedDataset(materialized_resource, materializer,
- self.output_classes, self.output_types,
- self.output_shapes)
-
- @abc.abstractproperty
- def output_types(self):
- """Returns the type of each component of an element of this IndexedDataset.
-
- Returns:
- A nested structure of `tf.DType` objects corresponding to each component
- of an element of this IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset.output_types")
-
- @abc.abstractproperty
- def output_classes(self):
- """Returns the class of each component of an element of this IndexedDataset.
-
- The expected values are `tf.Tensor` and `tf.SparseTensor`.
-
- Returns:
- A nested structure of Python `type` objects corresponding to each
- component of an element of this IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset.output_classes")
-
- @abc.abstractproperty
- def output_shapes(self):
- """Returns the shape of each component of an element of this IndexedDataset.
-
- Returns:
- A nested structure of `tf.TensorShape` objects corresponding to each
- component of an element of this IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset.output_shapes")
-
- @abc.abstractmethod
- def _as_variant_tensor(self):
- """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.
-
- Returns:
- A scalar `tf.Tensor` of `tf.variant` type, which represents this
- IndexedDataset.
- """
- raise NotImplementedError("IndexedDataset._as_variant_tensor")
-
-
-class IdentityIndexedDataset(IndexedDataset):
- """IdentityIndexedDataset is a trivial indexed dataset used for testing.
- """
-
- def __init__(self, size):
- super(IdentityIndexedDataset, self).__init__()
- # TODO(saeta): Verify _size is a scalar!
- self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
-
- @property
- def output_types(self):
- return dtypes.uint64
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- def _as_variant_tensor(self):
- return ged_ops.experimental_identity_indexed_dataset(self._size)
-
- def _inputs(self):
- return []
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index 1ee9db1aa8..f50da4d429 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,20 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import random_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.parallel_interleave(...)`.")
def parallel_interleave(map_func,
cycle_length,
block_length=1,
@@ -80,12 +72,9 @@ def parallel_interleave(map_func,
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset, map_func, cycle_length, block_length, sloppy,
- buffer_output_elements, prefetch_input_elements)
-
- return _apply_fn
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy, buffer_output_elements,
+ prefetch_input_elements)
@deprecation.deprecated(
@@ -139,63 +128,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset,
- map_func,
- cycle_length,
- block_length,
- sloppy=True,
- buffer_output_elements=None,
- prefetch_input_elements=None)
-
- return _apply_fn
-
-
-class _DirectedInterleaveDataset(dataset_ops.Dataset):
- """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
-
- def __init__(self, selector_input, data_inputs):
- self._selector_input = selector_input
- self._data_inputs = list(data_inputs)
-
- for data_input in data_inputs[1:]:
- if (data_input.output_types != data_inputs[0].output_types or
- data_input.output_classes != data_inputs[0].output_classes):
- raise TypeError("All datasets must have the same type and class.")
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- return (
- gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
- self._selector_input._as_variant_tensor(), [
- data_input._as_variant_tensor()
- for data_input in self._data_inputs
- ], **dataset_ops.flat_structure(self)))
- # pylint: enable=protected-access
-
- def _inputs(self):
- return [self._selector_input] + self._data_inputs
-
- @property
- def output_classes(self):
- return self._data_inputs[0].output_classes
-
- @property
- def output_shapes(self):
- ret = self._data_inputs[0].output_shapes
- for data_input in self._data_inputs[1:]:
- ret = nest.pack_sequence_as(ret, [
- ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
- nest.flatten(ret), nest.flatten(data_input.output_shapes))
- ])
- return ret
-
- @property
- def output_types(self):
- return self._data_inputs[0].output_types
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy=True)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.sample_from_datasets(...)`.")
def sample_from_datasets(datasets, weights=None, seed=None):
"""Samples elements at random from the datasets in `datasets`.
@@ -219,64 +157,11 @@ def sample_from_datasets(datasets, weights=None, seed=None):
ValueError: If the `weights` argument is specified and does not match the
length of the `datasets` element.
"""
- num_datasets = len(datasets)
- if not isinstance(weights, dataset_ops.Dataset):
- if weights is None:
- # Select inputs with uniform probability.
- logits = [[1.0] * num_datasets]
-
- else:
- # Use the given `weights` as the probability of choosing the respective
- # input.
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError(
- "`weights` must be a vector of length `len(datasets)`.")
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed
- # to weights.
- logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
-
- # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
- # is a `Dataset`, it is possible that evaluating it has a side effect the
- # user depends on.
- if len(datasets) == 1:
- return datasets[0]
-
- def select_dataset_constant_logits(seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- selector_input = dataset_ops.MapDataset(
- random_ops.RandomDataset(seed).batch(2),
- select_dataset_constant_logits,
- use_inter_op_parallelism=False)
-
- else:
- # Use each element of the given `weights` dataset as the probability of
- # choosing the respective input.
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
-
- def select_dataset_varying_logits(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- logits_and_seeds = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2)))
- selector_input = dataset_ops.MapDataset(
- logits_and_seeds,
- select_dataset_varying_logits,
- use_inter_op_parallelism=False)
-
- return _DirectedInterleaveDataset(selector_input, datasets)
+ return interleave_ops.sample_from_datasets(datasets, weights, seed)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.choose_from_datasets(...)`.")
def choose_from_datasets(datasets, choice_dataset):
"""Creates a dataset that deterministically chooses elements from `datasets`.
@@ -312,10 +197,4 @@ def choose_from_datasets(datasets, choice_dataset):
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
type.
"""
- if not (choice_dataset.output_types == dtypes.int64
- and choice_dataset.output_shapes.is_compatible_with(
- tensor_shape.scalar())
- and choice_dataset.output_classes == ops.Tensor):
- raise TypeError("`choice_dataset` must be a dataset of scalar "
- "`tf.int64` tensors.")
- return _DirectedInterleaveDataset(choice_dataset, datasets)
+ return interleave_ops.choose_from_datasets(datasets, choice_dataset)
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 18515e21ed..48c325c86f 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,15 +16,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training import session_run_hook
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.util import deprecation
+
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_saveable_from_iterator(...)`.")
def make_saveable_from_iterator(iterator):
"""Returns a SaveableObject for saving/restore iterator state using Saver.
@@ -60,27 +58,10 @@ def make_saveable_from_iterator(iterator):
Note: Not all iterators support checkpointing yet. Attempting to save the
state of an unsupported iterator will throw an error.
"""
- return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
-
-
-class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
- """SaveableObject for saving/restoring iterator state."""
+ return iterator_ops.make_saveable_from_iterator(iterator)
- def __init__(self, iterator_resource):
- serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
- specs = [
- saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
- iterator_resource.name + "-state")
- ]
- super(_Saveable, self).__init__(iterator_resource, specs,
- iterator_resource.name)
- def restore(self, restored_tensors, unused_restored_shapes):
- with ops.colocate_with(self.op):
- return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
-
-
-class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook):
"""Checkpoints input pipeline state every N steps or seconds.
This hook saves the state of the iterators in the `Graph` so that when
@@ -125,135 +106,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
collector when building the eval graph.
"""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.CheckpointInputPipelineHook(...)`.")
def __init__(self, estimator):
- """Initializes a `CheckpointInputPipelineHook`.
-
- Args:
- estimator: Estimator.
-
- Raises:
- ValueError: One of `save_steps` or `save_secs` should be set.
- ValueError: At most one of saver or scaffold should be set.
- """
- # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
- # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
- # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
- # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
- # to be different to avoid conflicts with the model checkpoint.
-
- # pylint: disable=protected-access
- checkpoint_prefix = "input"
- if estimator._config.num_worker_replicas > 1:
- # Distributed setting.
- suffix = "_{}_{}".format(estimator._config.task_type,
- estimator._config.task_id)
- checkpoint_prefix += suffix
- # pylint: enable=protected-access
-
- # We use a composition paradigm instead of inheriting from
- # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
- # to check whether a `CheckpointSaverHook` is already present in the list
- # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
- # would thwart this behavior. This hook checkpoints *only the iterators*
- # and not the graph variables.
- self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
- estimator.model_dir,
- save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
- save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
- checkpoint_basename=checkpoint_prefix + ".ckpt")
-
- # Name for the protocol buffer file that will contain the list of most
- # recent checkpoints stored as a `CheckpointState` protocol buffer.
- # This file, kept in the same directory as the checkpoint files, is
- # automatically managed by the `Saver` to keep track of recent checkpoints.
- # The default name used by the `Saver` for this file is "checkpoint". Here
- # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
- # `checkpoint_dir` is the same as the model checkpoint directory, there are
- # no conflicts during restore.
- self._latest_filename = "checkpoint_" + checkpoint_prefix
- self._first_run = True
-
- def begin(self):
- # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
- # collection if no `Saver` or `Scaffold` is provided.
- # pylint: disable=protected-access
- if (self._checkpoint_saver_hook._saver is None and
- self._checkpoint_saver_hook._scaffold is None):
- iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
- saveables = [_Saveable(i) for i in iterators]
- self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
- self._latest_filename)
- # pylint: enable=protected-access
- self._checkpoint_saver_hook.begin()
-
- def _restore_or_save_initial_ckpt(self, session):
- # Ideally this should be run in after_create_session but is not for the
- # following reason:
- # Currently there is no way of enforcing an order of running the
- # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
- # is run *after* this hook. That is troublesome because
- # 1. If a checkpoint exists and this hook restores it, the initializer hook
- # will override it.
- # 2. If no checkpoint exists, this hook will try to save an initialized
- # iterator which will result in an exception.
- #
- # As a temporary fix we enter the following implicit contract between this
- # hook and the _DatasetInitializerHook.
- # 1. The _DatasetInitializerHook initializes the iterator in the call to
- # after_create_session.
- # 2. This hook saves the iterator on the first call to `before_run()`, which
- # is guaranteed to happen after `after_create_session()` of all hooks
- # have been run.
-
- # Check if there is an existing checkpoint. If so, restore from it.
- # pylint: disable=protected-access
- latest_checkpoint_path = checkpoint_management.latest_checkpoint(
- self._checkpoint_saver_hook._checkpoint_dir,
- latest_filename=self._latest_filename)
- if latest_checkpoint_path:
- self._checkpoint_saver_hook._get_saver().restore(session,
- latest_checkpoint_path)
- else:
- # The checkpoint saved here is the state at step "global_step".
- # Note: We do not save the GraphDef or MetaGraphDef here.
- global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
- self._checkpoint_saver_hook._save(session, global_step)
- self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
- # pylint: enable=protected-access
-
- def before_run(self, run_context):
- if self._first_run:
- self._restore_or_save_initial_ckpt(run_context.session)
- self._first_run = False
- return self._checkpoint_saver_hook.before_run(run_context)
-
- def after_run(self, run_context, run_values):
- self._checkpoint_saver_hook.after_run(run_context, run_values)
-
- def end(self, session):
- self._checkpoint_saver_hook.end(session)
-
-
-class _CustomSaver(saver_lib.Saver):
- """`Saver` with a different default `latest_filename`.
-
- This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
- the model ckpt saved by the `CheckpointSaverHook`.
- """
-
- def __init__(self, var_list, latest_filename):
- super(_CustomSaver, self).__init__(var_list)
- self._latest_filename = latest_filename
-
- def save(self,
- sess,
- save_path,
- global_step=None,
- latest_filename=None,
- meta_graph_suffix="meta",
- write_meta_graph=True,
- write_state=True,
- strip_default_attrs=False):
- return super(_CustomSaver, self).save(
- sess, save_path, global_step, latest_filename or self._latest_filename,
- meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+ super(CheckpointInputPipelineHook, self).__init__(estimator)
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/contrib/data/python/ops/map_defun.py
deleted file mode 100644
index 3d0d0993c9..0000000000
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Experimental API for optimizing `tf.data` pipelines."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
-
-
-def map_defun(fn, elems, output_dtypes, output_shapes):
- """Map a function on the list of tensors unpacked from `elems` on dimension 0.
-
- Args:
- fn: A function (`function.Defun`) that takes a list of tensors and returns
- another list of tensors. The output list has the same types as
- output_dtypes. The elements of the output list have the same dimension 0
- as `elems`, and the remaining dimensions correspond to those of
- `fn_output_shapes`.
- elems: A list of tensors.
- output_dtypes: A list of dtypes corresponding to the output types of the
- function.
- output_shapes: A list of `TensorShape`s corresponding to the output
- shapes from each invocation of the function on slices of inputs.
-
- Raises:
- ValueError: if any of the inputs are malformed.
-
- Returns:
- A list of `Tensor` objects with the same types as `output_dtypes`.
- """
- if not isinstance(elems, list):
- raise ValueError("`elems` must be a list of tensors.")
- if not isinstance(output_dtypes, list):
- raise ValueError("`output_dtypes` must be a list of tensors.")
- if not isinstance(output_shapes, list):
- raise ValueError("`output_shapes` must be a list of tensors.")
-
- elems = [ops.convert_to_tensor(e) for e in elems]
- output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/contrib/data/python/ops/optimization.py
deleted file mode 100644
index 30348ede36..0000000000
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Experimental API for optimizing `tf.data` pipelines."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops
-
-# A constant that can be used to enable auto-tuning.
-AUTOTUNE = -1
-
-
-# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
-# account for indexing) and transformation sequence.
-def assert_next(transformations):
- """A transformation that asserts which transformations happen next.
-
- Args:
- transformations: A `tf.string` vector `tf.Tensor` identifying the
- transformations that are expected to happen next.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _AssertNextDataset(dataset, transformations)
-
- return _apply_fn
-
-
-def model():
- """A transformation that models performance.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _ModelDataset(dataset)
-
- return _apply_fn
-
-
-def optimize(optimizations=None):
- """A transformation that applies optimizations.
-
- Args:
- optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
- optimizations to use. If not specified, the default set of optimizations
- is applied.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _OptimizeDataset(dataset, optimizations)
-
- return _apply_fn
-
-
-class _AssertNextDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that asserts which transformations happen next."""
-
- def __init__(self, input_dataset, transformations):
- """See `assert_next()` for details."""
- super(_AssertNextDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if transformations is None:
- raise ValueError("At least one transformation should be specified")
- self._transformations = ops.convert_to_tensor(
- transformations, dtype=dtypes.string, name="transformations")
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_assert_next_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._transformations,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _ModelDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and models performance."""
-
- def __init__(self, input_dataset):
- """See `optimize()` for details."""
- super(_ModelDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.model_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _OptimizeDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and applies optimizations."""
-
- def __init__(self, input_dataset, optimizations):
- """See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if optimizations is None:
- optimizations = []
- self._optimizations = ops.convert_to_tensor(
- optimizations, dtype=dtypes.string, name="optimizations")
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.optimize_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._optimizations,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index cfbba701b0..3aeee9d8e4 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -17,92 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.util import deprecation
-class _ParseExampleDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that parses `example` dataset into a `dict` dataset."""
-
- def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if not all(types == dtypes.string
- for types in nest.flatten(input_dataset.output_types)):
- raise TypeError("Input dataset should be a dataset of vectors of strings")
- self._num_parallel_calls = num_parallel_calls
- # pylint: disable=protected-access
- self._features = parsing_ops._prepend_none_dimension(features)
- # sparse_keys and dense_keys come back sorted here.
- (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
- dense_shapes) = parsing_ops._features_to_raw_params(
- self._features, [
- parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
- parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
- ])
- # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
- (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
- dense_shape_as_shape) = parsing_ops._process_raw_parameters(
- None, dense_defaults, sparse_keys, sparse_types, dense_keys,
- dense_types, dense_shapes)
- # pylint: enable=protected-access
- self._sparse_keys = sparse_keys
- self._sparse_types = sparse_types
- self._dense_keys = dense_keys
- self._dense_defaults = dense_defaults_vec
- self._dense_shapes = dense_shapes
- self._dense_types = dense_types
- dense_output_shapes = [
- self._input_dataset.output_shapes.concatenate(shape)
- for shape in dense_shape_as_shape
- ]
- sparse_output_shapes = [
- self._input_dataset.output_shapes.concatenate([None])
- for _ in range(len(sparse_keys))
- ]
-
- self._output_shapes = dict(
- zip(self._dense_keys + self._sparse_keys,
- dense_output_shapes + sparse_output_shapes))
- self._output_types = dict(
- zip(self._dense_keys + self._sparse_keys,
- self._dense_types + self._sparse_types))
- self._output_classes = dict(
- zip(self._dense_keys + self._sparse_keys,
- [ops.Tensor for _ in range(len(self._dense_defaults))] +
- [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
- ]))
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.parse_example_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._num_parallel_calls,
- self._dense_defaults,
- self._sparse_keys,
- self._dense_keys,
- self._sparse_types,
- self._dense_shapes,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
-
-# TODO(b/111553342): add arguments names and example names as well.
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.parse_example_dataset(...)`.")
def parse_example_dataset(features, num_parallel_calls=1):
"""A transformation that parses `Example` protos into a `dict` of tensors.
@@ -130,21 +50,4 @@ def parse_example_dataset(features, num_parallel_calls=1):
Raises:
ValueError: if features argument is None.
"""
- if features is None:
- raise ValueError("Missing: features was %s." % features)
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
- if any([
- isinstance(feature, parsing_ops.SparseFeature)
- for _, feature in features.items()
- ]):
- # pylint: disable=protected-access
- # pylint: disable=g-long-lambda
- out_dataset = out_dataset.map(
- lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
- features, x), num_parallel_calls=num_parallel_calls)
- return out_dataset
-
- return _apply_fn
+ return parsing_ops.parse_example_dataset(features, num_parallel_calls)
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 46f82e453a..adfb390cd9 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -17,321 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import device as framework_device
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.ops import resource_variable_ops
-
-
-def function_buffering_resource(string_arg,
- target_device,
- f,
- buffer_size,
- output_types,
- container="",
- shared_name=None,
- name=None):
- """Creates a FunctionBufferingResource.
-
- A FunctionBufferingResource fills up a buffer by calling a function `f` on
- `target_device`. `f` should take in only a single string argument as input.
-
- Args:
- string_arg: The single string argument to the function.
- target_device: The device to run `f` on.
- f: The function to be executed.
- buffer_size: Size of the buffer to be populated.
- output_types: The output types generated by the function.
- container: (Optional) string. Defaults to "".
- shared_name: (Optional) string.
- name: (Optional) string to name the op.
-
- Returns:
- Handle to a FunctionBufferingResource.
- """
- if shared_name is None:
- shared_name = ""
- return ged_ops.experimental_function_buffering_resource(
- string_arg=string_arg,
- target_device=target_device,
- shared_name=shared_name,
- f=f,
- buffer_size=buffer_size,
- container=container,
- name=name,
- output_types=output_types)
-
-
-def function_buffering_resource_get_next(function_buffer_resource,
- output_types,
- name=None):
- return ged_ops.experimental_function_buffering_resource_get_next(
- function_buffer_resource=function_buffer_resource,
- output_types=output_types,
- name=name)
-
-
-def function_buffering_resource_reset(function_buffer_resource, name=None):
- return ged_ops.experimental_function_buffering_resource_reset(
- function_buffer_resource=function_buffer_resource, name=name)
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
- device: A fully specified device string where we want to prefetch to
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- device,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- iterator_device = ged_ops.experimental_iterator_get_device(
- self._input_iterator._iterator_resource)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- target_device=iterator_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)))
-
- if not self._one_shot:
- reset_op = function_buffering_resource_reset(self._buffering_resource)
- with ops.control_dependencies([reset_op]):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
- self._buffering_resource,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- nest.flatten(ret), nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
-
- return ret
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
- device: A fully specified device string where we want to prefetch to
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- device,
- buffer_size):
- with ops.device("/device:CPU:0"):
- super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
- self._resource)
-
- self._device = device
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self.output_types, self.output_shapes, self.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- _prefetch_fn.add_to_graph(None)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- output_types=self._flat_output_types,
- target_device=ged_ops.experimental_iterator_get_device(
- self._resource),
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=iterator_ops._generate_shared_name(
- "function_buffer_resource"))
-
- def _next_internal(self):
- """Returns a nested structure of `tf.Tensor`s containing the next element.
- """
- # This runs in sync mode as iterators use an error status to communicate
- # that there is no more data to iterate over.
- # TODO(b/77291417): Fix
- with context.execution_mode(context.SYNC):
- with ops.device(self._device):
- ret = ged_ops.experimental_function_buffering_resource_get_next(
- function_buffer_resource=self._buffering_resource,
- output_types=self._flat_output_types)
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` whose iterator prefetches elements to another device."""
-
- def __init__(self, input_dataset, device, buffer_size):
- super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._device = device
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- # The static analysis cannot tell that the eager iterator's superclass has
- # a `next()` method.
- # pylint: disable=non-iterator-returned
- def __iter__(self):
- """Creates an `Iterator` for enumerating the elements of this dataset.
-
- The returned iterator implements the Python iterator protocol and therefore
- can only be used in eager mode.
-
- Returns:
- An `Iterator` over the elements of this dataset.
-
- Raises:
- RuntimeError: If eager execution is enabled.
- """
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- raise RuntimeError("dataset.__iter__() is only supported when eager "
- "execution is enabled.")
- # pylint: enable=non-iterator-returned
-
- def make_one_shot_iterator(self):
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
- device=self._device,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- device=self._device,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
- # transformation methods is called.
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_device()` must be the last "
- "transformation in a dataset pipeline.")
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.prefetch_to_device(...)`.")
def prefetch_to_device(device, buffer_size=None):
"""A transformation that prefetches dataset values to the given `device`.
@@ -347,12 +38,10 @@ def prefetch_to_device(device, buffer_size=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, device, buffer_size)
-
- return _apply_fn
+ return prefetching_ops.prefetch_to_device(device, buffer_size)
+@deprecation.deprecated(None, "Use `tf.data.experimental.copy_to_device(...)`.")
def copy_to_device(target_device, source_device="/cpu:0"):
"""A transformation that copies dataset elements to the given `target_device`.
@@ -364,165 +53,4 @@ def copy_to_device(target_device, source_device="/cpu:0"):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _CopyToDeviceDataset(
- dataset, target_device=target_device, source_device=source_device)
-
- return _apply_fn
-
-
-# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
-# all inputs to the Op are in host memory, thereby avoiding some unnecessary
-# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that copies elements to another device."""
-
- def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
- """Constructs a _CopyToDeviceDataset.
-
- Args:
- input_dataset: `Dataset` to be copied
- target_device: The name of the device to which elements would be copied.
- source_device: Device where input_dataset would be placed.
- """
- super(_CopyToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._target_device = target_device
- spec = framework_device.DeviceSpec().from_string(self._target_device)
- self._is_gpu_target = (spec.device_type == "GPU")
- self._source_device_string = source_device
- self._source_device = ops.convert_to_tensor(source_device)
-
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._input_dataset.output_shapes,
- self._input_dataset.output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes))
-
- @function.Defun()
- def _init_func():
- """Creates an iterator for the input dataset.
-
- Returns:
- A `string` tensor that encapsulates the iterator created.
- """
- # pylint: disable=protected-access
- ds_variant = self._input_dataset._as_variant_tensor()
- resource = gen_dataset_ops.anonymous_iterator(
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies(
- [gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return gen_dataset_ops.iterator_to_string_handle(resource)
-
- @function.Defun()
- def _remote_init_func():
- return functional_ops.remote_call(
- target=self._source_device,
- args=_init_func.captured_inputs,
- Tout=[dtypes.string],
- f=_init_func)
-
- self._init_func = _remote_init_func
- self._init_captured_args = _remote_init_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _next_func(string_handle):
- """Calls get_next for created iterator.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- The elements generated from `input_dataset`
- """
- with ops.device(self._source_device_string):
- iterator = iterator_ops.Iterator.from_string_handle(
- string_handle, self.output_types, self.output_shapes,
- self.output_classes)
- ret = iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- @function.Defun(dtypes.string)
- def _remote_next_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _next_func.captured_inputs,
- Tout=self._flat_output_types,
- f=_next_func)
-
- self._next_func = _remote_next_func
- self._next_captured_args = _remote_next_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _finalize_func(string_handle):
- """Destroys the iterator resource created.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- Tensor constant 0
- """
- iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
- string_handle,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies([
- resource_variable_ops.destroy_resource_op(
- iterator_resource, ignore_lookup_error=True)]):
- return array_ops.constant(0, dtypes.int64)
-
- @function.Defun(dtypes.string)
- def _remote_finalize_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _finalize_func.captured_inputs,
- Tout=[dtypes.int64],
- f=_finalize_func)
-
- self._finalize_func = _remote_finalize_func
- self._finalize_captured_args = _remote_finalize_func.captured_inputs
-
- g = ops.get_default_graph()
- _remote_init_func.add_to_graph(g)
- _remote_next_func.add_to_graph(g)
- _remote_finalize_func.add_to_graph(g)
- # pylint: enable=protected-scope
-
- # The one_shot_iterator implementation needs a 0 arg _make_dataset function
- # that thereby captures all the inputs required to create the dataset. Since
- # there are strings that are inputs to the GeneratorDataset which can't be
- # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
- # GPU
- def make_one_shot_iterator(self):
- if self._is_gpu_target:
- raise ValueError("Cannot create a one shot iterator when using "
- "`tf.contrib.data.copy_to_device()` on GPU. Please use "
- "`Dataset.make_initializable_iterator()` instead.")
- else:
- return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
-
- def _as_variant_tensor(self):
- with ops.device(self._target_device):
- return gen_dataset_ops.generator_dataset(
- self._init_captured_args,
- self._next_captured_args,
- self._finalize_captured_args,
- init_func=self._init_func,
- next_func=self._next_func,
- finalize_func=self._finalize_func,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+ return prefetching_ops.copy_to_device(target_device, source_device)
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index 344a0763c8..2c95125636 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -17,36 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.util import deprecation
-class RandomDataset(dataset_ops.DatasetSource):
+class RandomDataset(random_ops.RandomDataset):
"""A `Dataset` of pseudorandom values."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.RandomDataset(...)`.")
def __init__(self, seed=None):
- """A `Dataset` of pseudorandom values."""
- super(RandomDataset, self).__init__()
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.random_dataset(
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- @property
- def output_types(self):
- return dtypes.int64
+ super(RandomDataset, self).__init__(seed)
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index 360971e200..4601376dff 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -17,295 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-import csv
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import parsing_ops
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_experimental_dataset_ops
-from tensorflow.python.platform import gfile
from tensorflow.python.util import deprecation
-_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
- dtypes.int64, dtypes.string)
-
-
-def _is_valid_int32(str_val):
- try:
- # Checks equality to prevent int32 overflow
- return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
- str_val)
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_int64(str_val):
- try:
- dtypes.int64.as_numpy_dtype(str_val)
- return True
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_float(str_val, float_dtype):
- try:
- return float_dtype.as_numpy_dtype(str_val) < np.inf
- except ValueError:
- return False
-
-
-def _infer_type(str_val, na_value, prev_type):
- """Given a string, infers its tensor type.
-
- Infers the type of a value by picking the least 'permissive' type possible,
- while still allowing the previous type inference for this column to be valid.
-
- Args:
- str_val: String value to infer the type of.
- na_value: Additional string to recognize as a NA/NaN CSV value.
- prev_type: Type previously inferred based on values of this column that
- we've seen up till now.
- Returns:
- Inferred dtype.
- """
- if str_val in ("", na_value):
- # If the field is null, it gives no extra information about its type
- return prev_type
-
- type_list = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
- ] # list of types to try, ordered from least permissive to most
-
- type_functions = [
- _is_valid_int32,
- _is_valid_int64,
- lambda str_val: _is_valid_float(str_val, dtypes.float32),
- lambda str_val: _is_valid_float(str_val, dtypes.float64),
- lambda str_val: True,
- ] # Corresponding list of validation functions
-
- for i in range(len(type_list)):
- validation_fn = type_functions[i]
- if validation_fn(str_val) and (prev_type is None or
- prev_type in type_list[:i + 1]):
- return type_list[i]
-
-
-def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
- """Generator that yields rows of CSV file(s) in order."""
- for fn in filenames:
- with file_io.FileIO(fn, "r") as f:
- rdr = csv.reader(
- f,
- delimiter=field_delim,
- quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
- if header:
- next(rdr) # Skip header lines
-
- for csv_row in rdr:
- if len(csv_row) != num_cols:
- raise ValueError(
- "Problem inferring types: CSV row has different number of fields "
- "than expected.")
- yield csv_row
-
-
-def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
- na_value, header, num_rows_for_inference,
- select_columns):
- """Infers column types from the first N valid CSV records of files."""
- if select_columns is None:
- select_columns = range(num_cols)
- inferred_types = [None] * len(select_columns)
-
- for i, csv_row in enumerate(
- _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
- if num_rows_for_inference is not None and i >= num_rows_for_inference:
- break
-
- for j, col_index in enumerate(select_columns):
- inferred_types[j] = _infer_type(csv_row[col_index], na_value,
- inferred_types[j])
-
- # Replace None's with a default type
- inferred_types = [t or dtypes.string for t in inferred_types]
- # Default to 0 or '' for null values
- return [
- constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
- for t in inferred_types
- ]
-
-
-def _infer_column_names(filenames, field_delim, use_quote_delim):
- """Infers column names from first rows of files."""
- csv_kwargs = {
- "delimiter": field_delim,
- "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
- }
- with file_io.FileIO(filenames[0], "r") as f:
- try:
- column_names = next(csv.reader(f, **csv_kwargs))
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
-
- for name in filenames[1:]:
- with file_io.FileIO(name, "r") as f:
- try:
- if next(csv.reader(f, **csv_kwargs)) != column_names:
- raise ValueError(
- "Files have different column names in the header row.")
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
- return column_names
-
-
-def _get_sorted_col_indices(select_columns, column_names):
- """Transforms select_columns argument into sorted column indices."""
- names_to_indices = {n: i for i, n in enumerate(column_names)}
- num_cols = len(column_names)
- for i, v in enumerate(select_columns):
- if isinstance(v, int):
- if v < 0 or v >= num_cols:
- raise ValueError(
- "Column index %d specified in select_columns out of valid range." %
- v)
- continue
- if v not in names_to_indices:
- raise ValueError(
- "Value '%s' specified in select_columns not a valid column index or "
- "name." % v)
- select_columns[i] = names_to_indices[v]
-
- # Sort and ensure there are no duplicates
- result = sorted(set(select_columns))
- if len(result) != len(select_columns):
- raise ValueError("select_columns contains duplicate columns")
- return result
-
-
-def _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
- """Optionally shuffle and repeat dataset, as requested."""
- if num_epochs != 1 and shuffle:
- # Use shuffle_and_repeat for perf
- return dataset.apply(
- shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
- shuffle_seed))
- elif shuffle:
- return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- elif num_epochs != 1:
- return dataset.repeat(num_epochs)
- return dataset
-
-
-def make_tf_record_dataset(file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=optimization.AUTOTUNE,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
- """Reads and optionally parses TFRecord files into a dataset.
-
- Provides common functionality such as batching, optional parsing, shuffling,
- and performant defaults.
-
- Args:
- file_pattern: List of files or patterns of TFRecord file paths.
- See `tf.gfile.Glob` for pattern rules.
- batch_size: An int representing the number of records to combine
- in a single batch.
- parser_fn: (Optional.) A function accepting string input to parse
- and process the record contents. This function must map records
- to components of a fixed shape, so they may be batched. By
- default, uses the record contents unmodified.
- num_epochs: (Optional.) An int specifying the number of times this
- dataset is repeated. If None (the default), cycles through the
- dataset forever.
- shuffle: (Optional.) A bool that indicates whether the input
- should be shuffled. Defaults to `True`.
- shuffle_buffer_size: (Optional.) Buffer size to use for
- shuffling. A large buffer size ensures better shuffling, but
- increases memory usage and startup time.
- shuffle_seed: (Optional.) Randomization seed to use for shuffling.
- prefetch_buffer_size: (Optional.) An int specifying the number of
- feature batches to prefetch for performance improvement.
- Defaults to auto-tune. Set to 0 to disable prefetching.
- num_parallel_reads: (Optional.) Number of threads used to read
- records from files. By default or if set to a value >1, the
- results will be interleaved.
- num_parallel_parser_calls: (Optional.) Number of parallel
- records to parse in parallel. Defaults to an automatic selection.
- drop_final_batch: (Optional.) Whether the last batch should be
- dropped in case its size is smaller than `batch_size`; the
- default behavior is not to drop the smaller batch.
-
- Returns:
- A dataset, where each element matches the output of `parser_fn`
- except it will have an additional leading `batch-size` dimension,
- or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
- unspecified.
- """
- files = dataset_ops.Dataset.list_files(
- file_pattern, shuffle=shuffle, seed=shuffle_seed)
-
- if num_parallel_reads is None:
- # Note: We considered auto-tuning this value, but there is a concern
- # that this affects the mixing of records from different files, which
- # could affect training convergence/accuracy, so we are defaulting to
- # a constant for now.
- num_parallel_reads = 24
- dataset = core_readers.TFRecordDataset(
- files, num_parallel_reads=num_parallel_reads)
-
- if shuffle_buffer_size is None:
- # TODO(josh11b): Auto-tune this value when not specified
- shuffle_buffer_size = 10000
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- drop_final_batch = drop_final_batch or num_epochs is None
-
- if parser_fn is None:
- dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
- else:
- # TODO(josh11b): if num_parallel_parser_calls is None, use some function
- # of num cores instead of map_and_batch's default behavior of one batch.
- dataset = dataset.apply(batching.map_and_batch(
- parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
- drop_remainder=drop_final_batch))
-
- if prefetch_buffer_size == 0:
- return dataset
- else:
- return dataset.prefetch(buffer_size=prefetch_buffer_size)
-
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.make_csv_dataset(...)`.")
def make_csv_dataset(
file_pattern,
batch_size,
@@ -387,7 +112,6 @@ def make_csv_dataset(
prefetch_buffer_size: An int specifying the number of feature
batches to prefetch for performance improvement. Recommended value is the
number of batches consumed per training step. Defaults to auto-tune.
-
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -411,106 +135,18 @@ def make_csv_dataset(
Raises:
ValueError: If any of the arguments is malformed.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Clean arguments; figure out column names and defaults
+ return readers.make_csv_dataset(
+ file_pattern, batch_size, column_names, column_defaults, label_name,
+ select_columns, field_delim, use_quote_delim, na_value, header,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference,
+ compression_type)
- if column_names is None:
- if not header:
- raise ValueError("Cannot infer column names without a header line.")
- # If column names are not provided, infer from the header lines
- column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
- if len(column_names) != len(set(column_names)):
- raise ValueError("Cannot have duplicate column names.")
- if select_columns is not None:
- select_columns = _get_sorted_col_indices(select_columns, column_names)
-
- if column_defaults is not None:
- column_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in column_defaults
- ]
- else:
- # If column defaults are not provided, infer from records at graph
- # construction time
- column_defaults = _infer_column_defaults(
- filenames, len(column_names), field_delim, use_quote_delim, na_value,
- header, num_rows_for_inference, select_columns)
-
- if select_columns is not None and len(column_defaults) != len(select_columns):
- raise ValueError(
- "If specified, column_defaults and select_columns must have same "
- "length."
- )
- if select_columns is not None and len(column_names) > len(select_columns):
- # Pick the relevant subset of column names
- column_names = [column_names[i] for i in select_columns]
-
- if label_name is not None and label_name not in column_names:
- raise ValueError("`label_name` provided must be one of the columns.")
-
- def filename_to_dataset(filename):
- return CsvDataset(
- filename,
- record_defaults=column_defaults,
- field_delim=field_delim,
- use_quote_delim=use_quote_delim,
- na_value=na_value,
- select_cols=select_columns,
- header=header,
- compression_type=compression_type,
- )
-
- def map_fn(*columns):
- """Organizes columns into a features dictionary.
-
- Args:
- *columns: list of `Tensor`s corresponding to one csv record.
- Returns:
- An OrderedDict of feature names to values for that particular record. If
- label_name is provided, extracts the label feature to be returned as the
- second element of the tuple.
- """
- features = collections.OrderedDict(zip(column_names, columns))
- if label_name is not None:
- label = features.pop(label_name)
- return features, label
- return features
-
- # Read files sequentially (if num_parallel_reads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
-
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # Apply batch before map for perf, because map has high overhead relative
- # to the size of the computation in each map.
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(batch_size=batch_size,
- drop_remainder=num_epochs is None)
- dataset = dataset_ops.MapDataset(
- dataset, map_fn, use_inter_op_parallelism=False)
- dataset = dataset.prefetch(prefetch_buffer_size)
-
- return dataset
-
-
-_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-
-
-class CsvDataset(dataset_ops.DatasetSource):
+class CsvDataset(readers.CsvDataset):
"""A Dataset comprising lines from one or more CSV files."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.CsvDataset(...)`.")
def __init__(self,
filenames,
record_defaults,
@@ -521,140 +157,13 @@ class CsvDataset(dataset_ops.DatasetSource):
use_quote_delim=True,
na_value="",
select_cols=None):
- """Creates a `CsvDataset` by reading and decoding CSV files.
-
- The elements of this dataset correspond to records from the file(s).
- RFC 4180 format is expected for CSV files
- (https://tools.ietf.org/html/rfc4180)
- Note that we allow leading and trailing spaces with int or float field.
-
-
- For example, suppose we have a file 'my_file0.csv' with four CSV columns of
- different data types:
- ```
- abcdefg,4.28E10,5.55E6,12
- hijklmn,-5.3E14,,2
- ```
-
- We can construct a CsvDataset from it as follows:
- ```python
- dataset = tf.contrib.data.CsvDataset(
- "my_file*.csv",
- [tf.float32, # Required field, use dtype or empty tensor
- tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
- tf.int32, # Required field, use dtype or empty tensor
- ],
- select_cols=[1,2,3] # Only parse last three columns
- )
- ```
-
- The expected output of its iterations is:
- ```python
- next_element = dataset.make_one_shot_iterator().get_next()
- with tf.Session() as sess:
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
-
- >> (4.28e10, 5.55e6, 12)
- >> (-5.3e14, 0.0, 2)
- ```
-
- Args:
- filenames: A `tf.string` tensor containing one or more filenames.
- record_defaults: A list of default values for the CSV fields. Each item in
- the list is either a valid CSV `DType` (float32, float64, int32, int64,
- string), or a `Tensor` object with one of the above types. One per
- column of CSV data, with either a scalar `Tensor` default value for the
- column if it is optional, or `DType` or empty `Tensor` if required. If
- both this and `select_columns` are specified, these must have the same
- lengths, and `column_defaults` is assumed to be sorted in order of
- increasing column index.
- compression_type: (Optional.) A `tf.string` scalar evaluating to one of
- `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
- compression.
- buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
- to buffer while reading files. Defaults to 4MB.
- header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
- have header line(s) that should be skipped when parsing. Defaults to
- `False`.
- field_delim: (Optional.) A `tf.string` scalar containing the delimiter
- character that separates fields in a record. Defaults to `","`.
- use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
- double quotation marks as regular characters inside of string fields
- (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
- na_value: (Optional.) A `tf.string` scalar indicating a value that will
- be treated as NA/NaN.
- select_cols: (Optional.) A sorted list of column indices to select from
- the input data. If specified, only this subset of columns will be
- parsed. Defaults to parsing all columns.
- """
- super(CsvDataset, self).__init__()
- self._filenames = ops.convert_to_tensor(
- filenames, dtype=dtypes.string, name="filenames")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
- record_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in record_defaults
- ]
- self._record_defaults = ops.convert_n_to_tensor(
- record_defaults, name="record_defaults")
- self._buffer_size = convert.optional_param_to_tensor(
- "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
- self._header = ops.convert_to_tensor(
- header, dtype=dtypes.bool, name="header")
- self._field_delim = ops.convert_to_tensor(
- field_delim, dtype=dtypes.string, name="field_delim")
- self._use_quote_delim = ops.convert_to_tensor(
- use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
- self._na_value = ops.convert_to_tensor(
- na_value, dtype=dtypes.string, name="na_value")
- self._select_cols = convert.optional_param_to_tensor(
- "select_cols",
- select_cols,
- argument_default=[],
- argument_dtype=dtypes.int64,
- )
- self._output_shapes = tuple(
- tensor_shape.scalar() for _ in range(len(record_defaults)))
- self._output_types = tuple(d.dtype for d in self._record_defaults)
- self._output_classes = tuple(
- ops.Tensor for _ in range(len(record_defaults)))
-
- def _as_variant_tensor(self):
- # Constructs graph node for the dataset op.
- return gen_experimental_dataset_ops.experimental_csv_dataset(
- filenames=self._filenames,
- record_defaults=self._record_defaults,
- buffer_size=self._buffer_size,
- header=self._header,
- output_shapes=self._output_shapes,
- field_delim=self._field_delim,
- use_quote_delim=self._use_quote_delim,
- na_value=self._na_value,
- select_cols=self._select_cols,
- compression_type=self._compression_type,
- )
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_classes(self):
- return self._output_classes
+ super(CsvDataset, self).__init__(
+ filenames, record_defaults, compression_type, buffer_size, header,
+ field_delim, use_quote_delim, na_value, select_cols)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`.")
def make_batched_features_dataset(file_pattern,
batch_size,
features,
@@ -759,57 +268,15 @@ def make_batched_features_dataset(file_pattern,
Raises:
ValueError: If `label_key` is not one of the `features` keys.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Read `Example` records from files as tensor objects.
- if reader_args is None:
- reader_args = []
+ return readers.make_batched_features_dataset(
+ file_pattern, batch_size, features, reader, label_key, reader_args,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, reader_num_threads, parser_num_threads,
+ sloppy_ordering, drop_final_batch)
- # Read files sequentially (if reader_num_threads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- lambda filename: reader(filename, *reader_args),
- cycle_length=reader_num_threads,
- sloppy=sloppy_ordering))
- # Extract values if the `Example` tensors are stored as key-value tuples.
- if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset_ops.MapDataset(
- dataset, lambda _, v: v, use_inter_op_parallelism=False)
-
- # Apply dataset repeat and shuffle transformations.
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(
- batch_size, drop_remainder=drop_final_batch or num_epochs is None)
-
- # Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.apply(
- parsing_ops.parse_example_dataset(
- features, num_parallel_calls=parser_num_threads))
-
- if label_key:
- if label_key not in features:
- raise ValueError(
- "The `label_key` provided (%r) must be one of the `features` keys." %
- label_key)
- dataset = dataset.map(lambda x: (x, x.pop(label_key)))
-
- dataset = dataset.prefetch(prefetch_buffer_size)
- return dataset
-
-
-@deprecation.deprecated(None,
- "Use `tf.contrib.data.make_batched_features_dataset`")
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`")
def read_batch_features(file_pattern,
batch_size,
features,
@@ -879,7 +346,7 @@ def read_batch_features(file_pattern,
Returns:
A dict from keys in features to `Tensor` or `SparseTensor` objects.
"""
- dataset = make_batched_features_dataset(
+ dataset = readers.make_batched_features_dataset(
file_pattern,
batch_size,
features,
@@ -893,96 +360,13 @@ def read_batch_features(file_pattern,
return outputs
-def _get_file_names(file_pattern, shuffle):
- """Parse list of file names from pattern, optionally shuffled.
-
- Args:
- file_pattern: File glob pattern, or list of glob patterns.
- shuffle: Whether to shuffle the order of file names.
-
- Returns:
- List of file names matching `file_pattern`.
-
- Raises:
- ValueError: If `file_pattern` is empty, or pattern matches no files.
- """
- if isinstance(file_pattern, list):
- if not file_pattern:
- raise ValueError("File pattern is empty.")
- file_names = []
- for entry in file_pattern:
- file_names.extend(gfile.Glob(entry))
- else:
- file_names = list(gfile.Glob(file_pattern))
-
- if not file_names:
- raise ValueError("No files match %s." % file_pattern)
-
- # Sort files so it will be deterministic for unit tests.
- if not shuffle:
- file_names = sorted(file_names)
- return file_names
-
-
-class SqlDataset(dataset_ops.DatasetSource):
+class SqlDataset(readers.SqlDataset):
"""A `Dataset` consisting of the results from a SQL query."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.SqlDataset(...)`.")
def __init__(self, driver_name, data_source_name, query, output_types):
- """Creates a `SqlDataset`.
-
- `SqlDataset` allows a user to read data from the result set of a SQL query.
- For example:
-
- ```python
- dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
- "SELECT name, age FROM people",
- (tf.string, tf.int32))
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
- # Prints the rows of the result set of the above query.
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
- ```
-
- Args:
- driver_name: A 0-D `tf.string` tensor containing the database type.
- Currently, the only supported value is 'sqlite'.
- data_source_name: A 0-D `tf.string` tensor containing a connection string
- to connect to the database.
- query: A 0-D `tf.string` tensor containing the SQL query to execute.
- output_types: A tuple of `tf.DType` objects representing the types of the
- columns returned by `query`.
- """
- super(SqlDataset, self).__init__()
- self._driver_name = ops.convert_to_tensor(
- driver_name, dtype=dtypes.string, name="driver_name")
- self._data_source_name = ops.convert_to_tensor(
- data_source_name, dtype=dtypes.string, name="data_source_name")
- self._query = ops.convert_to_tensor(
- query, dtype=dtypes.string, name="query")
- self._output_types = output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.sql_dataset(self._driver_name,
- self._data_source_name, self._query,
- nest.flatten(self.output_types),
- nest.flatten(self.output_shapes))
-
- @property
- def output_classes(self):
- return nest.map_structure(lambda _: ops.Tensor, self._output_types)
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
- self._output_types)
-
- @property
- def output_types(self):
- return self._output_types
+ super(SqlDataset, self).__init__(
+ driver_name, data_source_name, query, output_types)
class LMDBDataset(dataset_ops.DatasetSource):
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index 75642f143e..29d77528d9 100644
--- a/tensorflow/contrib/data/python/ops/resampling.py
+++ b/tensorflow/contrib/data/python/ops/resampling.py
@@ -17,22 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.rejection_resample(...)`.")
def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
"""A transformation that resamples a dataset to achieve a target distribution.
@@ -52,243 +42,5 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
- class_values_ds = dataset.map(class_func)
-
- # Get initial distribution.
- if initial_dist is not None:
- initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
- acceptance_dist, prob_of_original = (
- _calculate_acceptance_probs_with_mixing(initial_dist_t,
- target_dist_t))
- initial_dist_ds = dataset_ops.Dataset.from_tensors(
- initial_dist_t).repeat()
- acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
- acceptance_dist).repeat()
- prob_of_original_ds = dataset_ops.Dataset.from_tensors(
- prob_of_original).repeat()
- else:
- initial_dist_ds = _estimate_initial_dist_ds(
- target_dist_t, class_values_ds)
- acceptance_and_original_prob_ds = initial_dist_ds.map(
- lambda initial: _calculate_acceptance_probs_with_mixing(
- initial, target_dist_t))
- acceptance_dist_ds = acceptance_and_original_prob_ds.map(
- lambda accept_prob, _: accept_prob)
- prob_of_original_ds = acceptance_and_original_prob_ds.map(
- lambda _, prob_original: prob_original)
- filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
- class_values_ds, seed)
- # Prefetch filtered dataset for speed.
- filtered_ds = filtered_ds.prefetch(3)
-
- prob_original_static = _get_prob_original_static(
- initial_dist_t, target_dist_t) if initial_dist is not None else None
- if prob_original_static == 1:
- return dataset_ops.Dataset.zip((class_values_ds, dataset))
- elif prob_original_static == 0:
- return filtered_ds
- else:
- return interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
- weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
- seed=seed)
-
- return _apply_fn
-
-
-def _get_prob_original_static(initial_dist_t, target_dist_t):
- """Returns the static probability of sampling from the original.
-
- `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
- an Op that it isn't defined for. We have some custom logic to avoid this.
-
- Args:
- initial_dist_t: A tensor of the initial distribution.
- target_dist_t: A tensor of the target distribution.
-
- Returns:
- The probability of sampling from the original distribution as a constant,
- if it is a constant, or `None`.
- """
- init_static = tensor_util.constant_value(initial_dist_t)
- target_static = tensor_util.constant_value(target_dist_t)
-
- if init_static is None or target_static is None:
- return None
- else:
- return np.min(target_static / init_static)
-
-
-def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
- seed):
- """Filters a dataset based on per-class acceptance probabilities.
-
- Args:
- dataset: The dataset to be filtered.
- acceptance_dist_ds: A dataset of acceptance probabilities.
- initial_dist_ds: A dataset of the initial probability distribution, given or
- estimated.
- class_values_ds: A dataset of the corresponding classes.
- seed: (Optional.) Python integer seed for the resampler.
-
- Returns:
- A dataset of (class value, data) after filtering.
- """
- def maybe_warn_on_large_rejection(accept_dist, initial_dist):
- proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
- return control_flow_ops.cond(
- math_ops.less(proportion_rejected, .5),
- lambda: accept_dist,
- lambda: logging_ops.Print( # pylint: disable=g-long-lambda
- accept_dist, [proportion_rejected, initial_dist, accept_dist],
- message="Proportion of examples rejected by sampler is high: ",
- summarize=100,
- first_n=10))
-
- acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
- initial_dist_ds))
- .map(maybe_warn_on_large_rejection))
-
- def _gather_and_copy(class_val, acceptance_prob, data):
- return class_val, array_ops.gather(acceptance_prob, class_val), data
-
- current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
- (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
- filtered_ds = (
- current_probabilities_and_class_and_data_ds
- .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
- return filtered_ds.map(lambda class_value, _, data: (class_value, data))
-
-
-def _estimate_initial_dist_ds(
- target_dist_t, class_values_ds, dist_estimation_batch_size=32,
- smoothing_constant=10):
- num_classes = (target_dist_t.shape[0].value or
- array_ops.shape(target_dist_t)[0])
- initial_examples_per_class_seen = array_ops.fill(
- [num_classes], np.int64(smoothing_constant))
-
- def update_estimate_and_tile(num_examples_per_class_seen, c):
- updated_examples_per_class_seen, dist = _estimate_data_distribution(
- c, num_examples_per_class_seen)
- tiled_dist = array_ops.tile(
- array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
- return updated_examples_per_class_seen, tiled_dist
-
- initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
- .apply(scan_ops.scan(initial_examples_per_class_seen,
- update_estimate_and_tile))
- .apply(batching.unbatch()))
-
- return initial_dist_ds
-
-
-def _get_target_to_initial_ratio(initial_probs, target_probs):
- # Add tiny to initial_probs to avoid divide by zero.
- denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
- return target_probs / denom
-
-
-def _estimate_data_distribution(c, num_examples_per_class_seen):
- """Estimate data distribution as labels are seen.
-
- Args:
- c: The class labels. Type `int32`, shape `[batch_size]`.
- num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
- containing counts.
-
- Returns:
- num_examples_per_lass_seen: Updated counts. Type `int64`, shape
- `[num_classes]`.
- dist: The updated distribution. Type `float32`, shape `[num_classes]`.
- """
- num_classes = num_examples_per_class_seen.get_shape()[0].value
- # Update the class-count based on what labels are seen in batch.
- num_examples_per_class_seen = math_ops.add(
- num_examples_per_class_seen, math_ops.reduce_sum(
- array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
- init_prob_estimate = math_ops.truediv(
- num_examples_per_class_seen,
- math_ops.reduce_sum(num_examples_per_class_seen))
- dist = math_ops.cast(init_prob_estimate, dtypes.float32)
- return num_examples_per_class_seen, dist
-
-
-def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
- """Calculates the acceptance probabilities and mixing ratio.
-
- In this case, we assume that we can *either* sample from the original data
- distribution with probability `m`, or sample from a reshaped distribution
- that comes from rejection sampling on the original distribution. This
- rejection sampling is done on a per-class basis, with `a_i` representing the
- probability of accepting data from class `i`.
-
- This method is based on solving the following analysis for the reshaped
- distribution:
-
- Let F be the probability of a rejection (on any example).
- Let p_i be the proportion of examples in the data in class i (init_probs)
- Let a_i is the rate the rejection sampler should *accept* class i
- Let t_i is the target proportion in the minibatches for class i (target_probs)
-
- ```
- F = sum_i(p_i * (1-a_i))
- = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
- ```
-
- An example with class `i` will be accepted if `k` rejections occur, then an
- example with class `i` is seen by the rejector, and it is accepted. This can
- be written as follows:
-
- ```
- t_i = sum_k=0^inf(F^k * p_i * a_i)
- = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1
- = p_i * a_i / sum_j(p_j * a_j) using F from above
- ```
-
- Note that the following constraints hold:
- ```
- 0 <= p_i <= 1, sum_i(p_i) = 1
- 0 <= a_i <= 1
- 0 <= t_i <= 1, sum_i(t_i) = 1
- ```
-
- A solution for a_i in terms of the other variables is the following:
- ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
-
- If we try to minimize the amount of data rejected, we get the following:
-
- M_max = max_i [ t_i / p_i ]
- M_min = min_i [ t_i / p_i ]
-
- The desired probability of accepting data if it comes from class `i`:
-
- a_i = (t_i/p_i - m) / (M_max - m)
-
- The desired probability of pulling a data element from the original dataset,
- rather than the filtered one:
-
- m = M_min
-
- Args:
- initial_probs: A Tensor of the initial probability distribution, given or
- estimated.
- target_probs: A Tensor of the corresponding classes.
-
- Returns:
- (A 1D Tensor with the per-class acceptance probabilities, the desired
- probability of pull from the original distribution.)
- """
- ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
- max_ratio = math_ops.reduce_max(ratio_l)
- min_ratio = math_ops.reduce_min(ratio_l)
-
- # Target prob to sample from original distribution.
- m = min_ratio
-
- # TODO(joelshor): Simplify fraction, if possible.
- a_i = (ratio_l - m) / (max_ratio - m)
- return a_i, m
+ return resampling.rejection_resample(class_func, target_dist, initial_dist,
+ seed)
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index c52582cd35..0ca9fddb23 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -17,137 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ScanDataset(dataset_ops.UnaryDataset):
- """A dataset that scans a function across its input."""
-
- def __init__(self, input_dataset, initial_state, scan_func):
- """See `scan()` for details."""
- super(_ScanDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- with ops.name_scope("initial_state"):
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- self._initial_state = nest.pack_sequence_as(initial_state, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(initial_state))
- ])
-
- # Compute initial values for the state classes, shapes and types based on
- # the initial state. The shapes may be refined by running `tf_scan_func` one
- # or more times below.
- self._state_classes = sparse.get_classes(self._initial_state)
- self._state_shapes = nest.pack_sequence_as(
- self._initial_state,
- [t.get_shape() for t in nest.flatten(self._initial_state)])
- self._state_types = nest.pack_sequence_as(
- self._initial_state,
- [t.dtype for t in nest.flatten(self._initial_state)])
-
- # Will be populated by calling `tf_scan_func`.
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- # Iteratively rerun the scan function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- scan_func, "tf.contrib.data.scan()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
- if not (
- isinstance(wrapped_func.output_types, collections.Sequence) and
- len(wrapped_func.output_types) == 2):
- raise TypeError("The scan function must return a pair comprising the "
- "new state and the output value.")
-
- new_state_classes, self._output_classes = wrapped_func.output_classes
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(new_state_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, new_state_classes))
-
- # Extract and validate type information from the returned values.
- new_state_types, self._output_types = wrapped_func.output_types
- for new_state_type, state_type in zip(
- nest.flatten(new_state_types), nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, new_state_types))
-
- # Extract shape information from the returned values.
- new_state_shapes, self._output_shapes = wrapped_func.output_shapes
-
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(new_state_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._scan_func = wrapped_func.function
- self._scan_func.add_to_graph(ops.get_default_graph())
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.scan_dataset(
- input_t,
- nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
- self._scan_func.captured_inputs,
- f=self._scan_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.scan(...)`.")
def scan(initial_state, scan_func):
"""A transformation that scans a function across an input dataset.
@@ -168,7 +42,4 @@ def scan(initial_state, scan_func):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _ScanDataset(dataset, initial_state, scan_func)
-
- return _apply_fn
+ return scan_ops.scan(initial_state, scan_func)
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 985d1d87d0..329b34fdfe 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -17,54 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that fuses `shuffle` and `repeat`."""
-
- def __init__(self, input_dataset, buffer_size, count=None, seed=None):
- super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._buffer_size = ops.convert_to_tensor(
- buffer_size, dtype=dtypes.int64, name="buffer_size")
- if count is None:
- self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
- else:
- self._count = ops.convert_to_tensor(
- count, dtype=dtypes.int64, name="count")
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.shuffle_and_repeat_dataset(
- input_resource,
- buffer_size=self._buffer_size,
- count=self._count,
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.shuffle_and_repeat(...)`.")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
"""Shuffles and repeats a Dataset returning a new permutation for each epoch.
@@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset): # pylint: disable=missing-docstring
- return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
-
- return _apply_fn
+ return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py
deleted file mode 100644
index bc47c5989d..0000000000
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Experimental API for gathering statistics from `tf.data` pipelines."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class StatsAggregator(object):
- """A stateful resource that aggregates statistics from one or more iterators.
-
- To record statistics, use one of the custom transformation functions defined
- in this module when defining your `tf.data.Dataset`. All statistics will be
- aggregated by the `StatsAggregator` that is associated with a particular
- iterator (see below). For example, to record the latency of producing each
- element by iterating over a dataset:
-
- ```python
- dataset = ...
- dataset = dataset.apply(stats_ops.latency_stats("total_bytes"))
- ```
-
- To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
- the following pattern:
-
- ```python
- stats_aggregator = stats_ops.StatsAggregator()
- dataset = ...
-
- # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
- dataset = dataset.apply(
- tf.contrib.data.set_stats_aggregator(stats_aggregator))
- iterator = dataset.make_one_shot_iterator()
- ```
-
- To get a protocol buffer summary of the currently aggregated statistics,
- use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
- is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
- so that the summaries will be included with any existing summaries.
-
- ```python
- stats_aggregator = stats_ops.StatsAggregator()
- # ...
- stats_summary = stats_aggregator.get_summary()
- tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
- ```
-
- Note: This interface is experimental and expected to change. In particular,
- we expect to add other implementations of `StatsAggregator` that provide
- different ways of exporting statistics, and add more types of statistics.
- """
-
- def __init__(self):
- """Creates a `StatsAggregator`."""
- self._resource = gen_dataset_ops.stats_aggregator_handle()
-
- # TODO(b/116314787): Update this/add support for V2 summary API.
- def get_summary(self):
- """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
-
- The returned tensor will contain a serialized `tf.summary.Summary` protocol
- buffer, which can be used with the standard TensorBoard logging facilities.
-
- Returns:
- A scalar string `tf.Tensor` that summarizes the aggregated statistics.
- """
- return gen_dataset_ops.stats_aggregator_summary(self._resource)
-
-
-class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
-
- def __init__(self, input_dataset, stats_aggregator):
- super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._stats_aggregator = stats_aggregator
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.set_stats_aggregator_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._stats_aggregator._resource, # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-def set_stats_aggregator(stats_aggregator):
- """Set the given `stats_aggregator` for aggregating the input dataset stats.
-
- Args:
- stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _SetStatsAggregatorDataset(dataset, stats_aggregator)
-
- return _apply_fn
-
-
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def bytes_produced_stats(tag):
- """Records the number of bytes produced by each element of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will
- be associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
- tag)
-
- return _apply_fn
-
-
-def latency_stats(tag):
- """Records the latency of producing each element of the input dataset.
-
- To consume the statistics, associate a `StatsAggregator` with the output
- dataset.
-
- Args:
- tag: String. All statistics recorded by the returned transformation will
- be associated with the given `tag`.
-
- Returns:
- A `Dataset` transformation function, which can be passed to
- `tf.data.Dataset.apply`.
- """
-
- def _apply_fn(dataset):
- return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)
-
- return _apply_fn
-
-
-class _StatsDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and also records statistics."""
-
- def __init__(self, input_dataset, op_function, tag):
- super(_StatsDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._op_function = op_function
- self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
-
- def _as_variant_tensor(self):
- return self._op_function(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._tag,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index f73c3fd9cb..20cceb4647 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -17,88 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.ops import resource_variable_ops
-
-_uid_counter = 0
-_uid_lock = threading.Lock()
-
-
-def _generate_shared_name(prefix):
- with _uid_lock:
- global _uid_counter
- uid = _uid_counter
- _uid_counter += 1
- return "{}{}".format(prefix, uid)
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-class PrivateThreadPool(object):
- """A stateful resource that represents a private thread pool."""
-
- def __init__(self, num_threads, display_name=None,
- max_intra_op_parallelism=1):
- """Creates a `PrivateThreadPool` with the given number of threads."""
- if context.executing_eagerly():
- shared_name = _generate_shared_name("privatethreadpool")
- self._resource = ged_ops.experimental_thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name,
- shared_name=shared_name)
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device=context.context().device_name)
- else:
- self._resource = ged_ops.experimental_thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name)
-
-
-class _ThreadPoolDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and sets a custom threadpool."""
-
- def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._thread_pool = thread_pool
-
- def _as_variant_tensor(self):
- return ged_ops.experimental_thread_pool_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._thread_pool._resource, # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def override_threadpool(dataset, thread_pool):
- """Returns a new dataset that uses the given thread pool for its operations.
-
- Args:
- dataset: A `tf.data.Dataset` object.
- thread_pool: A `PrivateThreadPool` object.
-
- Returns:
- A dataset containing the same values as `dataset`, but which uses
- `thread_pool` to compute any of its parallel operations (such as
- `tf.data.Dataset.map`).
- """
- return _ThreadPoolDataset(dataset, thread_pool)
+# pylint: disable=unused-import
+from tensorflow.python.data.experimental.ops.threadpool import override_threadpool
+from tensorflow.python.data.experimental.ops.threadpool import PrivateThreadPool
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index ed363a7090..909d06c677 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,11 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.data.experimental.ops import unique as experimental_unique
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.unique()`.")
def unique():
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
@@ -39,39 +39,4 @@ def unique():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _UniqueDataset(dataset)
-
- return _apply_fn
-
-
-class _UniqueDataset(dataset_ops.UnaryDataset):
- """A `Dataset` contains the unique elements from its input."""
-
- def __init__(self, input_dataset):
- """See `unique()` for details."""
- super(_UniqueDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
- dtypes.string):
- raise TypeError(
- "`tf.contrib.data.unique()` only supports inputs with a single "
- "`tf.int32`, `tf.int64`, or `tf.string` component.")
-
- def _as_variant_tensor(self):
- return gen_experimental_dataset_ops.experimental_unique_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return experimental_unique.unique()
diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py
index c455fdcba6..42fb69bf07 100644
--- a/tensorflow/contrib/data/python/ops/writers.py
+++ b/tensorflow/contrib/data/python/ops/writers.py
@@ -17,42 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.util import deprecation
-class TFRecordWriter(object):
+class TFRecordWriter(writers.TFRecordWriter):
"""Writes data to a TFRecord file."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.TFRecordWriter(...)`.")
def __init__(self, filename, compression_type=None):
- self._filename = ops.convert_to_tensor(
- filename, dtypes.string, name="filename")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
-
- def write(self, dataset):
- """Returns a `tf.Operation` to write a dataset to a file.
-
- Args:
- dataset: a `tf.data.Dataset` whose elements are to be written to a file
-
- Returns:
- A `tf.Operation` that, when run, writes contents of `dataset` to a file.
- """
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- if (dataset.output_types != dtypes.string or
- dataset.output_shapes != tensor_shape.scalar()):
- raise TypeError(
- "`dataset` must produce scalar `DT_STRING` tensors whereas it "
- "produces shape {0} and types {1}".format(dataset.output_shapes,
- dataset.output_types))
- return gen_dataset_ops.dataset_to_tf_record(
- dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
+ super(TFRecordWriter, self).__init__(filename, compression_type)
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index cfb9d42a6f..76d5b59ce1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -412,6 +412,24 @@ cuda_py_test(
)
cuda_py_test(
+ name = "moving_averages_test",
+ srcs = ["moving_averages_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
name = "optimizer_v2_test",
srcs = ["optimizer_v2_test.py"],
additional_deps = [
@@ -728,6 +746,7 @@ cuda_py_test(
additional_deps = [
":keras_test_lib",
],
+ shard_count = 16,
tags = [
"multi_and_single_gpu",
"no_pip",
@@ -736,18 +755,27 @@ cuda_py_test(
],
)
-cuda_py_test(
- name = "metrics_v1_test",
+py_library(
+ name = "metrics_v1_test_lib",
+ testonly = 1,
srcs = ["metrics_v1_test.py"],
- additional_deps = [
+ deps = [
":combinations",
- "@absl_py//absl/testing:parameterized",
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:math_ops",
"//tensorflow/python:metrics",
"//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:test",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "metrics_v1_test",
+ srcs = ["metrics_v1_test.py"],
+ additional_deps = [
+ ":metrics_v1_test_lib",
],
tags = [
"multi_and_single_gpu",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 33ffbf6abe..6796a23d46 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 82ca041cc2..63a163e76c 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -329,10 +329,10 @@ one_device_strategy = NamedDistribution(
required_gpus=None)
tpu_strategy = NamedDistribution(
"TPU", lambda: tpu_lib.TPUStrategy(
- TPUClusterResolver(""), steps_per_run=5),
+ TPUClusterResolver(""), steps_per_run=2),
required_tpu=True)
tpu_strategy_one_step = NamedDistribution(
- "TPU", lambda: tpu_lib.TPUStrategy(
+ "TPUOneStep", lambda: tpu_lib.TPUStrategy(
TPUClusterResolver(""), steps_per_run=1),
required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
@@ -349,26 +349,26 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
required_gpus=2)
-adam_optimizer_v1_fn = NamedObject(
- "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+adam_optimizer_v1_fn = NamedObject("AdamV1",
+ lambda: adam.AdamOptimizer(0.001, epsilon=1))
rmsprop_optimizer_v1_fn = NamedObject(
"RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
- adagrad_optimizer_v1_fn]
-adam_optimizer_v2_fn = NamedObject(
- "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+
gradient_descent_optimizer_v2_fn = NamedObject(
"GradientDescentV2",
lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
adagrad_optimizer_v2_fn = NamedObject(
"AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
- adagrad_optimizer_v2_fn]
+adam_optimizer_v2_fn = NamedObject(
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+
+optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn]
graph_and_eager_modes = ["graph", "eager"]
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 3aab2c521f..3511b7761f 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -189,6 +189,14 @@ def get_dataset(distribution):
return dataset
+def get_predict_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
strategies = [combinations.default_strategy,
combinations.one_device_strategy,
combinations.mirrored_strategy_with_gpu_and_cpu,
@@ -347,56 +355,27 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.DeleteRecursively(self._config.model_dir)
-class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
+class TestDistributionStrategyWithNumpyArrays(test.TestCase,
+ parameterized.TestCase):
- def test_validating_dataset_input_tensors_with_shape_mismatch(self):
+ @combinations.generate(strategy_combinations())
+ def test_creating_var_with_numpy_arrays(self, distribution):
with self.cached_session():
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
- '/device:CPU:0'])
- a = constant_op.constant([1, 2], shape=(1, 2))
- b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
- x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
- y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
- with strategy.scope():
- # Removed device and input tensor shape details from the error message
- # since the order of the device and the corresponding input tensor shape
- # is not deterministic over different runs.
- with self.assertRaisesRegexp(ValueError,
- 'Input tensor shapes do not match for '
- 'distributed tensor inputs '
- 'DistributedValues:.+'):
- distributed_training_utils.validate_distributed_dataset_inputs(
- strategy, x, y)
+ x = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ var_x = distributed_training_utils.get_var_for_numpy(distribution, x)
+ val = self.evaluate(var_x.value())
+ # Verify that the numpy value is copied to the variable.
+ self.assertAllEqual(x, val)
- def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
- with self.cached_session():
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
- '/device:CPU:0'])
- a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
- b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
- x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
- y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
- with strategy.scope():
- # Removed device and input tensor dtype details from the error message
- # since the order of the device and the corresponding input tensor dtype
- # is not deterministic over different runs.
- with self.assertRaisesRegexp(ValueError,
- 'Input tensor dtypes do not match for '
- 'distributed tensor inputs '
- 'DistributedValues:.+'):
- distributed_training_utils.validate_distributed_dataset_inputs(
- strategy, x, y)
-
- def test_calling_model_with_numpy_arrays(self):
+ @combinations.generate(strategy_combinations())
+ def test_calling_model_with_numpy_arrays(self, distribution):
with self.cached_session():
model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae', keras.metrics.CategoricalAccuracy()]
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)
@@ -420,6 +399,52 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.predict(inputs, batch_size=8)
@combinations.generate(strategy_combinations())
+ def test_calling_model_with_nested_numpy_arrays(self, distribution):
+ with self.cached_session():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ input_b_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ inputs = [input_a_np, input_b_np]
+
+ output_d_np = np.asarray(np.random.random((64, 4)), dtype=np.float32)
+ output_e_np = np.asarray(np.random.random((64, 4)), dtype=np.float32)
+ targets = [output_d_np, output_e_np]
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0)
+
+ # TODO(anjalisridhar): We need tests for when the batch size and steps are
+ # smaller and results in a 0 batch_size and steps value.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
+
+class TestDistributionStrategyWithDatasets(test.TestCase,
+ parameterized.TestCase):
+
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
model = get_model()
@@ -436,7 +461,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
validation_data=dataset, validation_steps=2)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
- model.predict(dataset, steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
# TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work
# as clone_model's input_tensors argument only seems to accept list and not
@@ -496,10 +521,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
- # Test with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
@combinations.generate(strategy_and_optimizer_combinations())
def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
@@ -513,87 +535,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
-
- def test_unsupported_features(self):
- with self.cached_session():
- model = get_model()
-
- optimizer = gradient_descent.GradientDescentOptimizer(0.001)
- loss = 'mse'
- metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
-
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
-
- dataset = get_dataset(strategy)
-
- # Test with validation split
- with self.assertRaisesRegexp(
- ValueError, '`validation_split` argument is not '
- 'supported when input `x` is a dataset or a '
- 'dataset iterator.+'):
- model.fit(dataset,
- epochs=1, steps_per_epoch=2, verbose=0,
- validation_split=0.5, validation_steps=2)
-
- # Test with sample weight.
- sample_weight = np.random.random((10,))
- with self.assertRaisesRegexp(
- NotImplementedError, '`sample_weight` is currently not supported '
- 'when using DistributionStrategy.'):
- model.fit(
- dataset,
- epochs=1,
- steps_per_epoch=2,
- verbose=0,
- sample_weight=sample_weight)
-
- # Test with not specifying the `steps` argument.
- with self.assertRaisesRegexp(
- ValueError, 'you should specify the `steps_per_epoch` argument'):
- model.fit(dataset, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.evaluate(dataset, verbose=0)
-
- with self.assertRaisesRegexp(ValueError,
- 'you should specify the `steps` argument'):
- model.predict(dataset, verbose=0)
-
- def test_calling_with_unsupported_predefined_callbacks(self):
- with self.cached_session():
- model = get_model()
-
- optimizer = gradient_descent.GradientDescentOptimizer(0.001)
- loss = 'mse'
- metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
-
- dataset = get_dataset(strategy)
-
- def schedule(_):
- return 0.001
- with self.assertRaisesRegexp(ValueError,
- 'LearningRateScheduler callback is not '
- 'supported with DistributionStrategy.'):
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
-
- with self.assertRaisesRegexp(ValueError,
- 'ReduceLROnPlateau callback is not '
- 'supported with DistributionStrategy.'):
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- callbacks=[keras.callbacks.ReduceLROnPlateau()])
- with self.assertRaisesRegexp(ValueError,
- 'histogram_freq in the TensorBoard callback '
- 'is not supported when using '
- 'DistributionStrategy.'):
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)])
+ model.predict(get_predict_dataset(distribution), steps=2)
def test_dataset_input_shape_validation(self):
with self.cached_session():
@@ -679,7 +621,128 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
self.assertNotEqual(np.mean(predict_output), 0)
-class LossMaskingWithDistributionStrategyTest(test.TestCase):
+class TestDistributionStrategyErrorCases(test.TestCase, parameterized.TestCase):
+
+ def test_validating_dataset_input_tensors_with_shape_mismatch(self):
+ with self.cached_session():
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
+ '/device:CPU:0'])
+ a = constant_op.constant([1, 2], shape=(1, 2))
+ b = constant_op.constant([[1, 2], [1, 2]], shape=(2, 2))
+ x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
+ y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
+ with strategy.scope():
+ # Removed device and input tensor shape details from the error message
+ # since the order of the device and the corresponding input tensor shape
+ # is not deterministic over different runs.
+ with self.assertRaisesRegexp(ValueError,
+ 'Input tensor shapes do not match for '
+ 'distributed tensor inputs '
+ 'DistributedValues:.+'):
+ distributed_training_utils.validate_distributed_dataset_inputs(
+ strategy, x, y)
+
+ def test_validating_dataset_input_tensors_with_dtype_mismatch(self):
+ with self.cached_session():
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
+ '/device:CPU:0'])
+ a = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.int32)
+ b = constant_op.constant([1, 2], shape=(1, 2), dtype=dtypes.float64)
+ x = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': b})
+ y = values.DistributedValues({'/device:CPU:0': a, '/device:GPU:0': a})
+ with strategy.scope():
+ # Removed device and input tensor dtype details from the error message
+ # since the order of the device and the corresponding input tensor dtype
+ # is not deterministic over different runs.
+ with self.assertRaisesRegexp(ValueError,
+ 'Input tensor dtypes do not match for '
+ 'distributed tensor inputs '
+ 'DistributedValues:.+'):
+ distributed_training_utils.validate_distributed_dataset_inputs(
+ strategy, x, y)
+
+ def test_unsupported_features(self):
+ with self.cached_session():
+ model = get_model()
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ dataset = get_dataset(strategy)
+
+ # Test with validation split
+ with self.assertRaisesRegexp(
+ ValueError, '`validation_split` argument is not '
+ 'supported when input `x` is a dataset or a '
+ 'dataset iterator.+'):
+ model.fit(dataset,
+ epochs=1, steps_per_epoch=2, verbose=0,
+ validation_split=0.5, validation_steps=2)
+
+ # Test with sample weight.
+ sample_weight = np.random.random((10,))
+ with self.assertRaisesRegexp(
+ NotImplementedError, '`sample_weight` is currently not supported '
+ 'when using DistributionStrategy.'):
+ model.fit(
+ dataset,
+ epochs=1,
+ steps_per_epoch=2,
+ verbose=0,
+ sample_weight=sample_weight)
+
+ # Test with not specifying the `steps` argument.
+ with self.assertRaisesRegexp(
+ ValueError, 'you should specify the `steps_per_epoch` argument'):
+ model.fit(dataset, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.evaluate(dataset, verbose=0)
+
+ with self.assertRaisesRegexp(ValueError,
+ 'you should specify the `steps` argument'):
+ model.predict(dataset, verbose=0)
+
+ def test_calling_with_unsupported_predefined_callbacks(self):
+ with self.cached_session():
+ model = get_model()
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ dataset = get_dataset(strategy)
+
+ def schedule(_):
+ return 0.001
+ with self.assertRaisesRegexp(ValueError,
+ 'LearningRateScheduler callback is not '
+ 'supported with DistributionStrategy.'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
+
+ with self.assertRaisesRegexp(ValueError,
+ 'ReduceLROnPlateau callback is not '
+ 'supported with DistributionStrategy.'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.ReduceLROnPlateau()])
+ with self.assertRaisesRegexp(ValueError,
+ 'histogram_freq in the TensorBoard callback '
+ 'is not supported when using '
+ 'DistributionStrategy.'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
+ callbacks=[keras.callbacks.TensorBoard(histogram_freq=10)])
+
+
+class TestDistributionStrategyWithLossMasking(test.TestCase):
# TODO(priyag): Enable all strategies for this test. Currently it does not
# work for TPU due to some invalid datatype.
@@ -706,7 +769,7 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
self.assertEqual(hist.history['loss'][0], 0)
-class NormalizationLayerWithDistributionStrategyTest(
+class TestDistributionStrategyWithNormalizationLayer(
test.TestCase, parameterized.TestCase):
@combinations.generate(strategy_combinations())
@@ -726,16 +789,20 @@ class NormalizationLayerWithDistributionStrategyTest(
dataset = dataset.repeat(100)
dataset = batch_wrapper(dataset, 32, distribution)
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(x)
+ predict_dataset = predict_dataset.repeat(100)
+ predict_dataset = batch_wrapper(predict_dataset, 32, distribution)
+
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
- out = model.predict(dataset, steps=2)
+ out = model.predict(predict_dataset, steps=2)
out -= keras.backend.eval(norm.beta)
out /= keras.backend.eval(norm.gamma)
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
-class CorrectnessWithDistributionStrategyTest(test.TestCase,
- parameterized.TestCase):
+class TestDistributionStrategyCorrectness(test.TestCase,
+ parameterized.TestCase):
@combinations.generate(strategy_combinations())
def test_metric_correctness(self, distribution):
@@ -811,8 +878,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
predict_batch_size = 4
if with_distribution:
predict_batch_size //= with_distribution.num_towers
- predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
predict_dataset = batch_wrapper(predict_dataset,
predict_batch_size, distribution)
predict_result = model.predict(predict_dataset, steps=1)
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index 8163494c8e..ae4189eb1c 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.contrib.distribute.python import combinations
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
@@ -35,7 +36,8 @@ def _labeled_dataset_fn():
# 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
# 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
return dataset_ops.Dataset.range(1000).map(
- lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
+ lambda x: {"labels": x % 5, "predictions": x % 3}).batch(
+ 4, drop_remainder=True)
def _boolean_dataset_fn():
@@ -47,7 +49,8 @@ def _boolean_dataset_fn():
# F, T -> FP; T, F -> FN; F, F -> TN
return dataset_ops.Dataset.from_tensor_slices({
"labels": [True, False, True, False],
- "predictions": [True, True, False, False]}).repeat().batch(3)
+ "predictions": [True, True, False, False]}).repeat().batch(
+ 3, drop_remainder=True)
def _threshold_dataset_fn():
@@ -59,7 +62,8 @@ def _threshold_dataset_fn():
# False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
return dataset_ops.Dataset.from_tensor_slices({
"labels": [True, False, True, False],
- "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
+ "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(
+ 3, drop_remainder=True)
def _regression_dataset_fn():
@@ -79,6 +83,12 @@ def all_combinations():
mode=["graph"])
+def tpu_combinations():
+ return combinations.combine(distribution=[combinations.tpu_strategy_one_step,
+ combinations.tpu_strategy],
+ mode=["graph"])
+
+
# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
# metrics.precision_at_k
class MetricsV1Test(test.TestCase, parameterized.TestCase):
@@ -87,42 +97,50 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
dataset_fn).make_one_shot_iterator()
- value, update = distribution.call_for_each_tower(
- metric_fn, iterator.get_next())
- update = distribution.group(update)
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ def step_fn(ctx, inputs):
+ value, update = distribution.call_for_each_tower(
+ metric_fn, inputs)
+ ctx.set_non_tensor_output(name="value", output=value)
+ return distribution.group(update)
+
+ ctx = distribution.run_steps_on_dataset(
+ step_fn, iterator, iterations=distribution.steps_per_run)
+ update = ctx.run_op
+ value = ctx.non_tensor_outputs["value"]
+ # In each run, we run multiple steps, and each steps consumes as many
+ # batches as number of towers.
+ batches_per_update = (
+ distribution.num_towers * distribution.steps_per_run)
+ else:
+ value, update = distribution.call_for_each_tower(
+ metric_fn, iterator.get_next())
+ update = distribution.group(update)
+ # TODO(josh11b): Once we switch to using a global batch size for input,
+ # replace "distribution.num_towers" with "1".
+ batches_per_update = distribution.num_towers
+
+ self.evaluate(distribution.initialize())
self.evaluate(variables.local_variables_initializer())
- # TODO(josh11b): Once we switch to using a global batch size for input,
- # replace "distribution.num_towers" with "1".
- batches_per_update = distribution.num_towers
-
- # Update variables using the first `num_towers` batches.
- self.evaluate(update)
- self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
- 0.001, msg="After first update")
-
- # Update variables using the second `num_towers` batches.
- self.evaluate(update)
- self.assertAllClose(expected_fn(2 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After second update")
-
- if batches_per_update == 1: # Consume 4 input batches
- self.evaluate(update)
- self.assertAllClose(expected_fn(3 * batches_per_update),
- self.evaluate(value),
- 0.001,
- msg="After third update")
+
+ batches_consumed = 0
+ for i in range(4):
self.evaluate(update)
- self.assertAllClose(expected_fn(4 * batches_per_update),
+ batches_consumed += batches_per_update
+ self.assertAllClose(expected_fn(batches_consumed),
self.evaluate(value),
0.001,
- msg="After fourth update")
+ msg="After update #" + str(i+1))
+ if batches_consumed >= 4: # Consume 4 input batches in total.
+ break
- @combinations.generate(all_combinations())
+ self.evaluate(distribution.finalize())
+
+ @combinations.generate(all_combinations() + tpu_combinations())
def testMean(self, distribution):
def _dataset_fn():
- return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
+ return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(
+ 4, drop_remainder=True)
def _expected_fn(num_batches):
# Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
@@ -130,7 +148,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testAccuracy(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -143,6 +161,8 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+ # TODO(priyag, jhseu): Enable TPU for this test once scatter_add is added
+ # for TPUMirroredVariable.
@combinations.generate(all_combinations())
def testMeanPerClassAccuracy(self, distribution):
def _metric_fn(x):
@@ -161,6 +181,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
+ # NOTE(priyag): This metric doesn't work on TPUs yet.
@combinations.generate(all_combinations())
def testMeanIOU(self, distribution):
def _metric_fn(x):
@@ -179,7 +200,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testMeanTensor(self, distribution):
def _dataset_fn():
dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
@@ -198,7 +219,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testAUCROC(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -212,7 +233,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testAUCPR(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -226,7 +247,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalseNegatives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -239,7 +260,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalseNegativesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -252,7 +273,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTrueNegatives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -265,7 +286,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTrueNegativesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -278,7 +299,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalsePositives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -291,7 +312,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testFalsePositivesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -304,7 +325,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTruePositives(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -317,7 +338,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testTruePositivesAtThresholds(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -330,7 +351,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testPrecision(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -343,7 +364,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testPrecisionAtThreshold(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -356,7 +377,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testRecall(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -369,7 +390,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testRecallAtThreshold(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -382,7 +403,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testMeanSquaredError(self, distribution):
def _metric_fn(x):
labels = x["labels"]
@@ -395,7 +416,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
self._test_metric(
distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
- @combinations.generate(all_combinations())
+ @combinations.generate(all_combinations() + tpu_combinations())
def testRootMeanSquaredError(self, distribution):
def _metric_fn(x):
labels = x["labels"]
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index ba147e7824..60e134055f 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -179,11 +179,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def get_expected_variables(optimizer_fn, num_parameter_devices):
variables_map = {
"GradientDescent": ["dense/kernel", "dense/bias"],
- "Adam": [
- "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
- "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
- "dense/bias/Adam_1"
- ],
"Adagrad": [
"dense/kernel/Adagrad", "dense/kernel",
"dense/bias/Adagrad", "dense/bias"
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 4d7516063c..a32424b316 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -318,12 +318,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
[TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
The distribution strategy inherits these concepts as well and in addition to
that we also clarify several more concepts:
- * **In-graph replication**: the `client` creates a single `tf.Graph` that
+
+ * **In-graph replication**: the `client` creates a single `tf.Graph` that
specifies tasks for devices on all workers. The `client` then creates a
client session which will talk to the `master` service of a `worker`. Then
the `master` will partition the graph and distribute the work to all
participating workers.
- * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+ * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
physical machine. We will have multiple `worker`s with different `task`
index. They all do similar things except for one worker checkpointing model
variables, writing summaries, etc. in addition to its ordinary work.
@@ -627,9 +628,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
@@ -638,10 +641,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
assert isinstance(colocate_with, list)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): In eager mode, use one thread per device.
updates = {}
for d in colocate_with:
@@ -649,7 +654,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates[d] = fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index f51e543624..eeac528329 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -826,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with dist.scope():
ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
- update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+ update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py
new file mode 100644
index 0000000000..119352ad91
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/moving_averages_test.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training.moving_averages when using a DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import moving_averages
+
+
+all_combinations = combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu],
+ mode=["graph"])
+
+
+class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_combinations)
+ def testTowerModeWithoutZeroDebias(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+ return var, assign
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(distribution.unwrap(assign))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testTowerMode(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([0.0, 0.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ return var, assign.op
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign_op = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(distribution.unwrap(assign_op))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ self.assertAllClose(average_val, var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTowerWithoutZeroDebias(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0])
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(assign)
+ average_val = [1.0, 2.0]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+ # Also try assign.op.
+ sess.run(assign.op)
+ orig_weight = 0.25 * 0.25
+ val_weight = 1.0 - orig_weight
+ self.assertAllClose(
+ [10.0 * orig_weight + average_val[0] * val_weight,
+ 11.0 * orig_weight + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTower(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([0.0, 0.0])
+ val = array_ops.placeholder(dtypes.float32)
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(var, val, decay)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(assign, feed_dict={val: [1.0, 2.0]})
+ self.assertAllClose([1.0, 2.0], var.eval())
+
+ # Also try assign.op.
+ sess.run(assign.op, feed_dict={val: [10.0, 0.0]})
+ self.assertAllClose(
+ [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0),
+ (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
+ var.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 23b220f64b..f525919048 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
else:
assert False
- def _update(self, var, fn, *args, **kwargs):
- with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
del colocate_with
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 1125d027f6..6ddd91507b 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
if isinstance(var, values.AggregatingVariable):
var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
- return fn(var, *self._select_single_value(args),
- **self._select_single_value(kwargs))
+ result = fn(var, *self._select_single_value(args),
+ **self._select_single_value(kwargs))
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
# TODO(yuefengz): does it need to call _select_single_value?
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(
colocate_with.device), distribute_lib.UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 12789e0bc9..353d11a583 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -395,7 +395,8 @@ class ParameterServerStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
index 8d949943b7..d48aa9c89b 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import warnings
-from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest as data_nest
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 5d498fb629..fd280f5754 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 1b555482d3..1d9e299b38 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -132,7 +132,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""
# TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the
# master node fetched from the cluster resolver.
- super(TPUStrategy, self).__init__('/device:CPU:0')
+ super(TPUStrategy, self).__init__("/device:CPU:0")
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
@@ -152,6 +152,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
+ self._require_static_shapes = True
+
def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes,
iterations):
"""Create an enqueue op for a single host identified using host_id.
@@ -297,6 +299,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# For outputs that have already been aggregated, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
+ # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value.
if aggregation is not variables_lib.VariableAggregation.NONE:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
@@ -398,11 +401,16 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
- def _update(self, var, fn, *args, **kwargs):
- # TODO(jhseu): Consider supporting grouped==False.
+ def _update(self, var, options, fn, *args, **kwargs):
assert isinstance(var, values.TPUMirroredVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
+
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
- return fn(var, *args, **kwargs)
+ if should_group:
+ return fn(var, *args, **kwargs)
+ else:
+ return [fn(var, *args, **kwargs)]
# Otherwise, we revert to MirroredStrategy behavior and update each variable
# directly.
@@ -414,23 +422,25 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, should_group)
- # Make a single control dependency to keep the variables mirrored. If one
- # assignment is fetched, then run all assignments.
- sorted_keys = sorted(updates.keys())
- update_tuple = control_flow_ops.tuple([updates[d] for d in sorted_keys])
- for i, d in enumerate(sorted_keys):
- updates[d] = update_tuple[i]
- return values.regroup(updates, values.Mirrored)
+ # TODO(josh11b): Need to implement _update_non_slot()!
def read_var(self, var):
assert isinstance(var, values.TPUMirroredVariable)
return var.read_value()
- def _unwrap(self, value):
- if isinstance(value, list):
- return value
- return [value]
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ return [val.get(device=d) for d in sorted(val.devices)]
+ elif isinstance(val, list):
+ # TODO(josh11b): We need to remove this case; per device values should
+ # be represented using a PerDevice wrapper instead of a list with
+ # one entry per device.
+ return val
+ return [val]
+
@property
def num_towers(self):
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index c18faeb67d..0dd78ba185 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -366,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
strategy = distribution_strategy_context.get_distribution_strategy()
- updates = strategy.update(self, f, *args, **kwargs)
- grouped = strategy.group(updates)
- if isinstance(updates, DistributedValues) and updates.is_tensor_like:
- # Make sure we run all updates. Without this, something like
- # session.run(mirrored_var.assign*(...)) may only update one tower.
- index = {}
- for d in updates.devices:
- with ops.device(d), ops.control_dependencies([grouped]):
- index[d] = array_ops.identity(updates.get(d))
- return Mirrored(index)
- else:
- return grouped
+ return strategy.update(self, f, *args, **kwargs)
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -582,6 +571,10 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
ValueError("Device %s not found in %s (current device %s)" %
(device, self._index.keys(), device_util.current())), e)
+ @property
+ def device(self):
+ return self._get().device
+
# The arguments to update() are automatically unwrapped so the update()
# function would normally see regular variables, not MirroredVariables.
# However, the update function can still operate on wrapped MirroredVariables
@@ -1049,6 +1042,29 @@ def select_device_mirrored(device, structured):
return nest.map_structure(_get_mirrored, structured)
+def update_regroup(strategy, updates, should_group):
+ """Regroup for an update, with dependencies to ensure all updates execute."""
+ regrouped = regroup(updates, Mirrored)
+ if not should_group:
+ return nest.map_structure(strategy.unwrap, regrouped)
+ grouped_flat = []
+ for u in nest.flatten(regrouped):
+ if isinstance(u, DistributedValues):
+ g = strategy.group(u)
+ if u.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(strategy.update(...)) may only update one tower.
+ index = {}
+ for d in u.devices:
+ with ops.device(d), ops.control_dependencies([g]):
+ index[d] = array_ops.identity(u.get(d))
+ g = Mirrored(index)
+ else:
+ g = u
+ grouped_flat.append(g)
+ return nest.pack_sequence_as(regrouped, grouped_flat)
+
+
class PerDeviceDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index ae3e134333..121d2fbb3f 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -641,7 +641,7 @@ class MirroredVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- with self.test_session() as sess:
+ with self.cached_session(config=self.config) as sess:
v, devices, mirrored = _make_mirrored()
# Overwrite the initial values.
@@ -744,7 +744,7 @@ class MirroredVariableTest(test.TestCase):
if context.num_gpus() < 1 or context.executing_eagerly():
self.skipTest("A GPU is not available for this test or it's eager mode.")
- with self.test_session(
+ with self.session(
graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
["/device:GPU:0"]).scope():
with ops.device("/device:GPU:0"):
@@ -827,7 +827,7 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- with self.test_session() as sess:
+ with self.cached_session(config=self.config) as sess:
v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
@@ -850,7 +850,7 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- with self.test_session() as sess:
+ with self.cached_session(config=self.config) as sess:
v, tower_local = _make_tower_local(
variable_scope.VariableAggregation.MEAN)
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 3ff7da4f89..60f6b90edc 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -299,7 +299,7 @@ cuda_py_test(
cuda_py_test(
name = "mvn_diag_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/mvn_diag_test.py"],
additional_deps = [
":distributions_py",
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 135095a979..3aed121233 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -54,7 +54,7 @@ class Iterator(iterator_ops.EagerIterator):
"""
if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access
raise TypeError(
- "`tf.contrib.data.prefetch_to_device()` is not compatible with "
+ "`tf.data.experimental.prefetch_to_device()` is not compatible with "
"`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
"over the dataset instead.")
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index a753d77580..6a508fc6ba 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -24,11 +24,11 @@ import time
import numpy as np
from tensorflow.contrib import lookup
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.data import Dataset
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
index 8fae622e12..446e340118 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -65,7 +65,7 @@
"\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n",
" \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
"\u003c/td\u003e\u003ctd\u003e\n",
- "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
}
],
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
index 551c76b0df..f3bb978875 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
@@ -51,7 +51,9 @@ def random_batch(batch_size):
class ResNet50GraphTest(tf.test.TestCase):
def testApply(self):
- batch_size = 64
+ # Use small batches for tests because the OSS version runs
+ # in constrained GPU environment with 1-2GB of memory.
+ batch_size = 8
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None))
model = resnet50.ResNet50(data_format())
@@ -63,7 +65,7 @@ class ResNet50GraphTest(tf.test.TestCase):
sess.run(init)
np_images, _ = random_batch(batch_size)
out = sess.run(predictions, feed_dict={images: np_images})
- self.assertAllEqual([64, 1000], out.shape)
+ self.assertAllEqual([batch_size, 1000], out.shape)
def testTrainWithSummary(self):
with tf.Graph().as_default():
@@ -87,7 +89,9 @@ class ResNet50GraphTest(tf.test.TestCase):
init = tf.global_variables_initializer()
self.assertEqual(321, len(tf.global_variables()))
- batch_size = 32
+ # Use small batches for tests because the OSS version runs
+ # in constrained GPU environment with 1-2GB of memory.
+ batch_size = 2
with tf.Session() as sess:
sess.run(init)
sess.run(tf.contrib.summary.summary_writer_initializer_op())
diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
index 34a9984b0e..d85188de03 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
@@ -169,11 +169,11 @@ class ImageNetInput(object):
# Read the data from disk in parallel
dataset = dataset.apply(
- tf.contrib.data.parallel_interleave(
+ tf.data.experimental.parallel_interleave(
fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True))
if self.cache:
dataset = dataset.cache().apply(
- tf.contrib.data.shuffle_and_repeat(1024 * 16))
+ tf.data.experimental.shuffle_and_repeat(1024 * 16))
else:
dataset = dataset.shuffle(1024)
@@ -188,9 +188,11 @@ class ImageNetInput(object):
# batch size. As long as this validation is done with consistent batch size,
# exactly the same images will be used.
dataset = dataset.apply(
- tf.contrib.data.map_and_batch(
- self.dataset_parser, batch_size=batch_size,
- num_parallel_batches=self.num_cores, drop_remainder=True))
+ tf.data.experimental.map_and_batch(
+ self.dataset_parser,
+ batch_size=batch_size,
+ num_parallel_batches=self.num_cores,
+ drop_remainder=True))
# Transpose for performance on TPU
if self.transpose_input:
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 6a921e1997..4f4cc3af6f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -50,6 +50,9 @@ class RevNetTest(tf.test.TestCase):
# Reconstruction could cause numerical error, use double precision for tests
config.dtype = tf.float64
config.fused = False # Fused batch norm does not support tf.float64
+ # Reduce the batch size for tests because the OSS version runs
+ # in constrained GPU environment with 1-2GB of memory.
+ config.batch_size = 2
shape = (config.batch_size,) + config.input_shape
self.model = revnet.RevNet(config=config)
self.x = tf.random_normal(shape=shape, dtype=tf.float64)
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index ba6fe9701d..7aa4b598b8 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -47,8 +47,9 @@ def run_sync_and_async(f):
@functools.wraps(f)
def decorator(self, *args, **kwargs):
- with context.execution_mode(context.ASYNC):
- f(self, *args, **kwargs)
+ # TODO(b/117110239): Re-enable.
+ # with context.execution_mode(context.ASYNC):
+ # f(self, *args, **kwargs)
with context.execution_mode(context.SYNC):
f(self, *args, **kwargs)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 5faf0aacfe..6ca7aaf989 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -151,7 +151,7 @@ def make_input_layer_with_layer_annotations(original_input_layer):
# spec and looking at the keys.
spec = feature_column_lib.make_parse_example_spec(feature_columns)
for key in spec.keys():
- tensor = ops.convert_to_tensor(features[key])
+ tensor = ops.convert_to_tensor_or_indexed_slices(features[key])
ops.add_to_collection(
LayerAnnotationsCollectionNames.keys(
LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py
index ce75899214..6e793c8302 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py
@@ -233,6 +233,22 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
self, features, mode, logits, labels=None, optimizer=None,
train_op_fn=None):
"""See `_Head`."""
+ return self._create_estimator_spec(
+ features=features, mode=mode, logits=logits, labels=labels,
+ optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=False)
+
+ def _create_tpu_estimator_spec(
+ self, features, mode, logits, labels=None, optimizer=None,
+ train_op_fn=None):
+ """See `_Head`."""
+ return self._create_estimator_spec(
+ features=features, mode=mode, logits=logits, labels=labels,
+ optimizer=optimizer, train_op_fn=train_op_fn, use_tpu=True)
+
+ def _create_estimator_spec(
+ self, features, mode, logits, labels=None, optimizer=None,
+ train_op_fn=None, use_tpu=False):
+ """Returns `EstimatorSpec` or `TPUEstimatorSpec`."""
if isinstance(logits, dict):
logits_dict = logits
else:
@@ -255,14 +271,15 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
spec = self._merge_train(
all_estimator_spec=all_estimator_spec,
optimizer=optimizer,
- train_op_fn=train_op_fn)
+ train_op_fn=train_op_fn,
+ use_tpu=use_tpu)
with ops.name_scope(''):
summary.scalar(metric_keys.MetricKeys.LOSS, spec.loss)
return spec
if mode == model_fn.ModeKeys.PREDICT:
- return self._merge_predict(all_estimator_spec)
+ return self._merge_predict(all_estimator_spec, use_tpu=use_tpu)
if mode == model_fn.ModeKeys.EVAL:
- return self._merge_eval(all_estimator_spec)
+ return self._merge_eval(all_estimator_spec, use_tpu=use_tpu)
raise ValueError('mode={} unrecognized'.format(mode))
def _split_logits(self, logits):
@@ -284,28 +301,28 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
begin_idx += head.logits_dimension
return logits_dict
- def _merge_train(self, all_estimator_spec, optimizer, train_op_fn):
- """Merges list of `EstimatorSpec` for training.
+ def _merge_train(
+ self, all_estimator_spec, optimizer, train_op_fn, use_tpu=False):
+ """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for training.
Args:
- all_estimator_spec: list of `EstimatorSpec` for the individual heads.
+ all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the
+ individual heads.
optimizer: `Optimizer` instance to create train op. See
`create_estimator_spec` documentation for more details.
train_op_fn: Function to create train op. Used if `optimizer` is `None`.
+ use_tpu: If `True`, returns `TPUEstimatorSpec`.
Returns:
- `EstimatorSpec` that merges all heads for TRAIN.
+ `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for TRAIN.
Raises:
ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN
mode.
"""
losses = []
- metrics = {}
for spec in all_estimator_spec:
losses.append(spec.loss)
- # Metric keys already contain head.name.
- metrics.update(spec.eval_metric_ops or {})
loss = _merge_losses(losses, self._head_weights)
if optimizer is not None:
if train_op_fn is not None:
@@ -317,20 +334,23 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
else:
raise ValueError('train_op_fn and optimizer cannot both be None.')
- return model_fn.EstimatorSpec(
+ spec_type = (
+ model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access
+ return spec_type(
mode=model_fn.ModeKeys.TRAIN,
loss=loss,
- train_op=train_op,
- eval_metric_ops=metrics)
+ train_op=train_op)
- def _merge_predict(self, all_estimator_spec):
- """Merges list of `EstimatorSpec` for prediction.
+ def _merge_predict(self, all_estimator_spec, use_tpu=False):
+ """Merges list of `EstimatorSpec` or `TPUEstimatorSpec` for prediction.
Args:
- all_estimator_spec: list of `EstimatorSpec` for the individual heads.
+ all_estimator_spec: list of `EstimatorSpec` or `TPUEstimatorSpec` for the
+ individual heads.
+ use_tpu: If `True`, returns `TPUEstimatorSpec`.
Returns:
- `EstimatorSpec` that merges all heads for PREDICT.
+ `EstimatorSpec` or `TPUEstimatorSpec` that merges all heads for PREDICT.
"""
predictions = {}
export_outputs = {
@@ -357,20 +377,29 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access
export_outputs[head_lib._PREDICT_SERVING_KEY] = ( # pylint:disable=protected-access
export_output_lib.PredictOutput(merged_predict_outputs))
- return model_fn.EstimatorSpec(
+ spec_type = (
+ model_fn._TPUEstimatorSpec if use_tpu else model_fn.EstimatorSpec) # pylint:disable=protected-access
+ return spec_type(
mode=model_fn.ModeKeys.PREDICT,
predictions=predictions,
export_outputs=export_outputs)
- def _merge_eval(self, all_estimator_spec):
+ def _merge_eval(self, all_estimator_spec, use_tpu=False):
"""Merges list of `EstimatorSpec` for eval.
Args:
all_estimator_spec: list of `EstimatorSpec` for the individual heads.
+ use_tpu: If `True`, will raise `NotImplementedError`, because TPU is not
+ yet supported for eval.
Returns:
`EstimatorSpec` that merges all heads for EVAL.
+ Raises:
+ NotImplementedError: If `use_tpu` is `True`.
"""
+ if use_tpu:
+ raise NotImplementedError(
+ 'TPU evaluation is not implemented for multi_head.')
predictions = {}
metrics = {}
losses = []
diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
index 2b4d5f5261..a602f87b4a 100644
--- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py
@@ -106,7 +106,7 @@ class MultiHeadTest(test.TestCase):
multi_head = multi_head_lib.multi_head([head1, head2])
self.assertEqual('head1_head2', multi_head.name)
- def test_predict_two_heads_logits_dict(self):
+ def _test_predict_two_heads_logits_dict(self, use_tpu):
"""Tests predict with logits as dict."""
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
head2 = head_lib.multi_label_head(n_classes=3, name='head2')
@@ -121,10 +121,16 @@ class MultiHeadTest(test.TestCase):
'head2': _sigmoid(logits['head2']),
}
- spec = multi_head.create_estimator_spec(
- features={'x': np.array(((42,),), dtype=np.int32)},
- mode=model_fn.ModeKeys.PREDICT,
- logits=logits)
+ if use_tpu:
+ spec = multi_head._create_tpu_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits).as_estimator_spec()
+ else:
+ spec = multi_head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.PREDICT,
+ logits=logits)
self.assertItemsEqual(
(_DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/classification',
@@ -175,6 +181,12 @@ class MultiHeadTest(test.TestCase):
sess.run(
spec.export_outputs['head2/predict'].outputs['probabilities']))
+ def test_predict_two_heads_logits_dict(self):
+ self._test_predict_two_heads_logits_dict(use_tpu=False)
+
+ def test_predict_two_heads_logits_dict_tpu(self):
+ self._test_predict_two_heads_logits_dict(use_tpu=True)
+
def test_predict_two_heads_logits_tensor(self):
"""Tests predict with logits as Tensor."""
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
@@ -350,6 +362,31 @@ class MultiHeadTest(test.TestCase):
rtol=tol,
atol=tol)
+ def test_eval_tpu(self):
+ head1 = head_lib.multi_label_head(n_classes=2, name='head1')
+ head2 = head_lib.multi_label_head(n_classes=3, name='head2')
+ multi_head = multi_head_lib.multi_head(
+ [head1, head2], head_weights=[1., 2.])
+
+ logits = {
+ 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
+ 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
+ dtype=np.float32),
+ }
+ labels = {
+ 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
+ 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
+ }
+
+ with self.assertRaisesRegexp(
+ NotImplementedError,
+ r'TPU evaluation is not implemented for multi_head\.'):
+ multi_head._create_tpu_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.EVAL,
+ logits=logits,
+ labels=labels)
+
def test_train_create_loss_one_head(self):
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
multi_head = multi_head_lib.multi_head([head1])
@@ -587,7 +624,7 @@ class MultiHeadTest(test.TestCase):
six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
train_result)
- def test_train_two_heads_with_weights(self):
+ def _test_train_two_heads_with_weights(self, use_tpu):
head1 = head_lib.multi_label_head(n_classes=2, name='head1')
head2 = head_lib.multi_label_head(n_classes=3, name='head2')
multi_head = multi_head_lib.multi_head(
@@ -619,12 +656,20 @@ class MultiHeadTest(test.TestCase):
[constant_op.constant(expected_train_result),
string_ops.as_string(loss, precision=3)])
- spec = multi_head.create_estimator_spec(
- features={'x': np.array(((42,),), dtype=np.int32)},
- mode=model_fn.ModeKeys.TRAIN,
- logits=logits,
- labels=labels,
- train_op_fn=_train_op_fn)
+ if use_tpu:
+ spec = multi_head._create_tpu_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn).as_estimator_spec()
+ else:
+ spec = multi_head.create_estimator_spec(
+ features={'x': np.array(((42,),), dtype=np.int32)},
+ mode=model_fn.ModeKeys.TRAIN,
+ logits=logits,
+ labels=labels,
+ train_op_fn=_train_op_fn)
self.assertIsNotNone(spec.loss)
self.assertEqual({}, spec.eval_metric_ops)
@@ -649,6 +694,12 @@ class MultiHeadTest(test.TestCase):
metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
}, summary_str, tol)
+ def test_train_two_heads_with_weights(self):
+ self._test_train_two_heads_with_weights(use_tpu=False)
+
+ def test_train_two_heads_with_weights_tpu(self):
+ self._test_train_two_heads_with_weights(use_tpu=True)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 1aebed348d..89506ee661 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -25,12 +25,12 @@ import tempfile
import numpy as np
import six
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.contrib.estimator.python.estimator import rnn
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import parsing_utils
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 510f292508..e344d7a23b 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -154,8 +154,6 @@ tf_py_test(
],
tags = [
"no_pip", # b/38283730
- "noasan", # b/116875897
- "nomsan",
"notsan", # Flaky: b/30756419
],
)
@@ -179,11 +177,7 @@ tf_py_test(
"//tensorflow/python:random_seed",
"//tensorflow/python:variables",
],
- tags = [
- "noasan", # b/116875897
- "nomsan",
- "notsan", # b/62863147
- ],
+ tags = ["notsan"], # b/62863147
)
py_library(
@@ -282,7 +276,6 @@ tf_py_test(
"manual",
"noasan", # times out b/63678675
"nomsan",
- "notsan", # b/116875897
],
)
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
index e076631bc1..d365ad1117 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
@@ -154,10 +154,10 @@ class GmmAlgorithm(object):
def _create_variables(self):
"""Initializes GMM algorithm."""
init_value = array_ops.constant([], dtype=dtypes.float32)
- self._means = variables.Variable(init_value,
- name=self.CLUSTERS_VARIABLE,
- validate_shape=False)
- self._covs = variables.Variable(
+ self._means = variables.VariableV1(init_value,
+ name=self.CLUSTERS_VARIABLE,
+ validate_shape=False)
+ self._covs = variables.VariableV1(
init_value, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False)
# Mixture weights, representing the probability that a randomly
# selected unobservable data (in EM terms) was generated by component k.
@@ -165,9 +165,9 @@ class GmmAlgorithm(object):
array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
name=self.CLUSTERS_WEIGHT,
validate_shape=False)
- self._cluster_centers_initialized = variables.Variable(False,
- dtype=dtypes.bool,
- name='initialized')
+ self._cluster_centers_initialized = variables.VariableV1(False,
+ dtype=dtypes.bool,
+ name='initialized')
def _initialize_variables(self, data, initial_means=None):
"""Initializes variables.
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 9bdbd05015..75d577f429 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -420,13 +420,13 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
class SweepHookTest(test.TestCase):
def test_sweeps(self):
- is_row_sweep_var = variables.Variable(True)
- is_sweep_done_var = variables.Variable(False)
- init_done = variables.Variable(False)
- row_prep_done = variables.Variable(False)
- col_prep_done = variables.Variable(False)
- row_train_done = variables.Variable(False)
- col_train_done = variables.Variable(False)
+ is_row_sweep_var = variables.VariableV1(True)
+ is_sweep_done_var = variables.VariableV1(False)
+ init_done = variables.VariableV1(False)
+ row_prep_done = variables.VariableV1(False)
+ col_prep_done = variables.VariableV1(False)
+ row_train_done = variables.VariableV1(False)
+ col_train_done = variables.VariableV1(False)
init_op = state_ops.assign(init_done, True)
row_prep_op = state_ops.assign(row_prep_done, True)
@@ -486,7 +486,7 @@ class StopAtSweepHookTest(test.TestCase):
def test_stop(self):
hook = wals_lib._StopAtSweepHook(last_sweep=10)
- completed_sweeps = variables.Variable(
+ completed_sweeps = variables.VariableV1(
8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS)
train_op = state_ops.assign_add(completed_sweeps, 1)
hook.begin()
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 490da9b33b..57a5bfbf43 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -145,6 +145,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
],
tags = [
+ "manual", # TODO(b/117128481): re-enable after fixing OSS build
"no_pip",
"requires-gpu-sm70",
],
@@ -169,6 +170,7 @@ cuda_py_test(
],
main = "python/ops/fused_conv2d_bias_activation_benchmark.py",
tags = [
+ "manual", # TODO(b/117128481): re-enable after fixing OSS build
"requires-gpu-sm70",
],
)
diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD
new file mode 100644
index 0000000000..9393b702d1
--- /dev/null
+++ b/tensorflow/contrib/ignite/BUILD
@@ -0,0 +1,139 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "if_not_windows",
+ "if_windows",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+py_library(
+ name = "ignite",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/ignite_dataset_ops.cc",
+ "kernels/ignite_client.h",
+ "kernels/ignite_byte_swapper.h",
+ "kernels/ignite_plain_client.h",
+ "kernels/ignite_ssl_wrapper.h",
+ "kernels/ignite_ssl_wrapper.cc",
+ "kernels/ignite_binary_object_parser.h",
+ "kernels/ignite_binary_object_parser.cc",
+ "kernels/ignite_dataset.h",
+ "kernels/ignite_dataset.cc",
+ "kernels/ignite_dataset_iterator.h",
+ "kernels/ignite_dataset_iterator.cc",
+ ] + if_not_windows([
+ "kernels/ignite_plain_client_unix.cc",
+ ]) + if_windows([
+ "kernels/ignite_plain_client_windows.cc",
+ ]),
+ copts = if_windows([
+ "-DWIN32_LEAN_AND_MEAN",
+ ]),
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@boringssl//:ssl",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/ignite_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ignite_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "ignite_op_loader",
+ srcs = ["python/ops/ignite_op_loader.py"],
+ dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+# The Apache Ignite servers have to setup before the test and tear down
+# after the test manually. The docker engine has to be installed.
+#
+# To setup Apache Ignite servers:
+# $ bash ./python/tests/start_ignite.sh
+#
+# To tear down Apache Ignite servers:
+# $ bash ./python/tests/stop_ignite.sh
+tf_py_test(
+ name = "ignite_dataset_test",
+ srcs = ["python/tests/ignite_dataset_test.py"],
+ additional_deps = [
+ ":ignite",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md
new file mode 100644
index 0000000000..55c89d2799
--- /dev/null
+++ b/tensorflow/contrib/ignite/README.md
@@ -0,0 +1,167 @@
+# Ignite Dataset
+
+- [Overview](#overview)
+- [Features](#features)
+ * [Distributed In-Memory Datasource](#distributed-in-memory-datasource)
+ * [Structured Objects](#structured-objects)
+ * [Distributed Training](#distributed-training)
+ * [SSL Connection](#ssl-connection)
+ * [Windows Support](#windows-support)
+- [Try it out](#try-it-out)
+- [Limitations](#limitations)
+
+## Overview
+
+[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for
+transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow.
+
+## Features
+
+Ignite Dataset provides features that you can use in a wide range of cases. The most important and interesting features are described below.
+
+### Distributed In-Memory Datasource
+[Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid the limitations of a hard drive and to store and operate on as much data as you need in a distributed cluster. You can utilize
+these benefits of Apache Ignite by using Ignite Dataset. Moreover, Ignite Dataset can be used for the following use-cases:
+- If you have a **gigabyte** of data you can keep it on a single machine on a hard drive, but you will face hard drive speed limitations. At the same time, you can store your data in Apache Ignite on the same machine and use it as a datasource for TensorFlow and thus avoid these limitations.
+- If you have a **terabyte** of data you probably still can keep it on a single machine on a hard drive, but you will face hard drive speed limitations again. At the same time, you can store your data in an Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow and thus avoid these limitations.
+- If you have a **petabyte** of data you can't keep it on a single machine. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow.
+
+Note that Apache Ignite is not just a step of ETL pipeline between a database or a data warehouse and TensorFlow. Apache Ignite is a high-grade database itself. By choosing Apache Ignite and TensorFlow you are getting everything you need to work with operational or historical data and, at the same time, an ability to use this data for neural network training and inference.
+
+```bash
+$ apache-ignite-fabric/bin/ignite.sh
+$ apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://localhost:10800/"
+
+jdbc:ignite:thin://localhost/> CREATE TABLE KITTEN_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR);
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (1, 'WARM KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (2, 'SOFT KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL OF FUR');
+```
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> for _ in range(3):
+>>> print(sess.run(next_obj))
+
+{'key': 1, 'val': {'NAME': b'WARM KITTY'}}
+{'key': 2, 'val': {'NAME': b'SOFT KITTY'}}
+{'key': 3, 'val': {'NAME': b'LITTLE BALL OF FUR'}}
+```
+
+### Structured Objects
+[Apache Ignite](https://ignite.apache.org/) allows to store any type of objects. These objects can have any hierarchy. Ignite Dataset provides an ability to work with such objects.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+{
+ 'key': 'kitten.png',
+ 'val': {
+ 'metadata': {
+ 'file_name': b'kitten.png',
+ 'label': b'little ball of fur',
+ width: 800,
+ height: 600
+ },
+ 'pixels': [0, 0, 0, 0, ..., 0]
+ }
+}
+```
+ Neural network training and other computations require transformations that can be done as part of [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) pipeline if you use Ignite Dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels'])
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+[0, 0, 0, 0, ..., 0]
+```
+
+### Distributed Training
+
+TensorFlow is a machine learning framework that [natively supports](https://www.tensorflow.org/deploy/distributed) distributed neural network training, inference and other computations. The main idea behind the distributed neural network training is the ability to calculate gradients of loss functions (squares of the errors) on every partition of data (in terms of horizontal partitioning) and then sum them to get loss function gradient of the whole dataset.
+
+<a href="https://www.codecogs.com/eqnedit.php?latex=\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" target="_blank"><img src="https://latex.codecogs.com/gif.latex?\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" title="\nabla[\sum_1^n(y - \hat{y})^2] = \nabla[\sum_1^{n_1}(y - \hat{y})^2] + \nabla[\sum_{n_1}^{n_2}(y - \hat{y})^2] + ... + \nabla[\sum_{n_{k-1}}^n(y - \hat{y})^2]" /></a>
+
+Using this ability we can calculate gradients on the nodes the data is stored on, reduce them and then finally update model parameters. It allows to avoid data transfers between nodes and thus to avoid network bottlenecks.
+
+Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition.
+
+Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting corresponding environment variables for the worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently works with a single dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset("IMAGES")
+>>>
+>>> # Compute gradients locally on every worker node.
+>>> gradients = []
+>>> for i in range(5):
+>>> with tf.device("/job:WORKER/task:%d" % i):
+>>> device_iterator = dataset.make_one_shot_iterator()
+>>> device_next_obj = device_iterator.get_next()
+>>> gradient = compute_gradient(device_next_obj)
+>>> gradients.append(gradient)
+>>>
+>>> # Aggregate them on master node.
+>>> result_gradient = tf.reduce_sum(gradients)
+>>>
+>>> with tf.Session("grpc://localhost:10000") as sess:
+>>> print(sess.run(result_gradient))
+```
+
+High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well.
+
+### SSL Connection
+
+Apache Ignite allows protecting data transfer channels with [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentication. Ignite Dataset supports SSL connections both with and without authentication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite")
+>>> ...
+```
+
+### Windows Support
+
+Ignite Dataset is fully compatible with Windows. You can use it as part of TensorFlow on your Windows workstation as well as on Linux/MacOS systems.
+
+## Try it out
+
+The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and preloaded [MNIST](http://yann.lecun.com/exdb/mnist/) data, and then interact with it using Ignite Dataset. Such a container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine:
+
+```
+docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist
+```
+
+After that you will be able to work with it in the following way:
+
+![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist")
+
+## Limitations
+
+Presently, Ignite Dataset works under the assumption that all objects in the cache have the same structure (homogeneous objects) and that the cache contains at least one object. Another limitation concerns structured objects: Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure.
diff --git a/tensorflow/contrib/ignite/__init__.py b/tensorflow/contrib/ignite/__init__.py
new file mode 100644
index 0000000000..f42947696f
--- /dev/null
+++ b/tensorflow/contrib/ignite/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""IgniteDataset that allows to get data from Apache Ignite.
+
+Apache Ignite is a memory-centric distributed database, caching, and
+processing platform for transactional, analytical, and streaming workloads,
+delivering in-memory speeds at petabyte scale. This contrib package
+contains an integration between Apache Ignite and TensorFlow. The
+integration is based on tf.data from TensorFlow side and Binary Client
+Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+datasource for neural network training, inference and all other
+computations supported by TensorFlow. Ignite Dataset is based on Apache
+Ignite Binary Client Protocol:
+https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+
+@@IgniteDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.ignite.python.ops.ignite_dataset_ops import IgniteDataset
+from tensorflow.python.util.all_util import remove_undocumented
+
+# Public API of this module; every other name is stripped below.
+_allowed_symbols = [
+    "IgniteDataset",
+]
+
+# Remove module members that are not listed in `_allowed_symbols`.
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
new file mode 100644
index 0000000000..2c8a7d44b0
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+// `false` declares the wire data as little-endian (presumably the Ignite
+// binary protocol byte order — confirm), so bytes are swapped only when the
+// host is big-endian.
+BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {}
+
+// Parses one binary object starting at *ptr, appending one Tensor per leaf
+// value to `out_tensors` and the corresponding protocol type code to
+// `types`. *ptr is advanced past all consumed bytes. Returns an Unknown
+// error for unsupported type codes.
+Status BinaryObjectParser::Parse(uint8_t** ptr,
+                                 std::vector<Tensor>* out_tensors,
+                                 std::vector<int32_t>* types) const {
+  uint8_t object_type_id = ParseByte(ptr);
+
+  // Skip non-leaf nodes.
+  if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ)
+    types->push_back(object_type_id);
+
+  switch (object_type_id) {
+    // Scalar types: each produces a rank-0 Tensor.
+    case BYTE: {
+      out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({}));
+      out_tensors->back().scalar<uint8>()() = ParseByte(ptr);
+      break;
+    }
+    case SHORT: {
+      out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({}));
+      out_tensors->back().scalar<int16>()() = ParseShort(ptr);
+      break;
+    }
+    case USHORT: {
+      out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({}));
+      out_tensors->back().scalar<uint16>()() = ParseUnsignedShort(ptr);
+      break;
+    }
+    case INT: {
+      out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({}));
+      out_tensors->back().scalar<int32>()() = ParseInt(ptr);
+      break;
+    }
+    case LONG: {
+      out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
+      out_tensors->back().scalar<int64>()() = ParseLong(ptr);
+      break;
+    }
+    case FLOAT: {
+      out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({}));
+      out_tensors->back().scalar<float>()() = ParseFloat(ptr);
+      break;
+    }
+    case DOUBLE: {
+      out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({}));
+      out_tensors->back().scalar<double>()() = ParseDouble(ptr);
+      break;
+    }
+    case BOOL: {
+      out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({}));
+      out_tensors->back().scalar<bool>()() = ParseBool(ptr);
+      break;
+    }
+    case STRING: {
+      out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({}));
+      out_tensors->back().scalar<string>()() = ParseString(ptr);
+      break;
+    }
+    case DATE: {
+      // Dates are surfaced as their raw int64 wire value.
+      out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
+      out_tensors->back().scalar<int64>()() = ParseLong(ptr);
+      break;
+    }
+    // Array types: each produces a rank-1 Tensor of the parsed length.
+    case BYTE_ARR: {
+      int32_t length = ParseInt(ptr);
+      uint8_t* arr = ParseByteArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_UINT8,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<uint8>().data());
+      break;
+    }
+    case SHORT_ARR: {
+      int32_t length = ParseInt(ptr);
+      int16_t* arr = ParseShortArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_INT16,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<int16>().data());
+      break;
+    }
+    case USHORT_ARR: {
+      int32_t length = ParseInt(ptr);
+      uint16_t* arr = ParseUnsignedShortArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_UINT16,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<uint16>().data());
+      break;
+    }
+    case INT_ARR: {
+      int32_t length = ParseInt(ptr);
+      int32_t* arr = ParseIntArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_INT32,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<int32>().data());
+      break;
+    }
+    case LONG_ARR: {
+      int32_t length = ParseInt(ptr);
+      int64_t* arr = ParseLongArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_INT64,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
+      break;
+    }
+    case FLOAT_ARR: {
+      int32_t length = ParseInt(ptr);
+      float* arr = ParseFloatArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_FLOAT,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<float>().data());
+      break;
+    }
+    case DOUBLE_ARR: {
+      int32_t length = ParseInt(ptr);
+      double* arr = ParseDoubleArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<double>().data());
+      break;
+    }
+    case BOOL_ARR: {
+      int32_t length = ParseInt(ptr);
+      bool* arr = ParseBoolArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_BOOL,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<bool>().data());
+      break;
+    }
+    case STRING_ARR: {
+      int32_t length = ParseInt(ptr);
+      out_tensors->emplace_back(cpu_allocator(), DT_STRING,
+                                TensorShape({length}));
+      for (int32_t i = 0; i < length; i++)
+        out_tensors->back().vec<string>()(i) = ParseString(ptr);
+      break;
+    }
+    case DATE_ARR: {
+      int32_t length = ParseInt(ptr);
+      int64_t* arr = ParseLongArr(ptr, length);
+      out_tensors->emplace_back(cpu_allocator(), DT_INT64,
+                                TensorShape({length}));
+      std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
+      break;
+    }
+    case WRAPPED_OBJ: {
+      // Layout: [payload length][payload][offset]. The length and trailing
+      // offset are parsed only to advance the pointer.
+      int32_t byte_arr_size = ParseInt(ptr);
+      TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
+      int32_t offset = ParseInt(ptr);
+
+      break;
+    }
+    case COMPLEX_OBJ: {
+      // Header fields are parsed to advance the pointer; only `length` and
+      // `schema_offset` are used, to locate the field data and to skip the
+      // trailing schema section.
+      uint8_t version = ParseByte(ptr);
+      int16_t flags = ParseShort(ptr);
+      int32_t type_id = ParseInt(ptr);
+      int32_t hash_code = ParseInt(ptr);
+      int32_t length = ParseInt(ptr);
+      int32_t schema_id = ParseInt(ptr);
+      int32_t schema_offset = ParseInt(ptr);
+
+      // 24 is the size of the header just read.
+      uint8_t* end = *ptr + schema_offset - 24;
+      while (*ptr < end) {
+        TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
+      }
+
+      // Skip the schema section at the end of the object.
+      *ptr += (length - schema_offset);
+
+      break;
+    }
+    default: {
+      return errors::Unknown("Unknown binary type (type id ",
+                             (int)object_type_id, ")");
+    }
+  }
+
+  return Status::OK();
+}
+
+// Reads one unsigned byte and advances *ptr by 1.
+uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const {
+  uint8_t res = **ptr;
+  *ptr += 1;
+
+  return res;
+}
+
+// Reads an int16 and advances *ptr by 2.
+// NOTE(review): all multi-byte readers below byte-swap *in place*, mutating
+// the caller's buffer, and the reinterpret_casts may form unaligned
+// pointers — this presumably relies on the buffer being private to one
+// parse pass and the platform tolerating unaligned access; confirm.
+int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const {
+  int16_t* res = *reinterpret_cast<int16_t**>(ptr);
+  byte_swapper_.SwapIfRequiredInt16(res);
+  *ptr += 2;
+
+  return *res;
+}
+
+// Reads a uint16 and advances *ptr by 2.
+uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const {
+  uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
+  byte_swapper_.SwapIfRequiredUnsignedInt16(res);
+  *ptr += 2;
+
+  return *res;
+}
+
+// Reads an int32 and advances *ptr by 4.
+int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const {
+  int32_t* res = *reinterpret_cast<int32_t**>(ptr);
+  byte_swapper_.SwapIfRequiredInt32(res);
+  *ptr += 4;
+
+  return *res;
+}
+
+// Reads an int64 and advances *ptr by 8.
+int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const {
+  int64_t* res = *reinterpret_cast<int64_t**>(ptr);
+  byte_swapper_.SwapIfRequiredInt64(res);
+  *ptr += 8;
+
+  return *res;
+}
+
+// Reads a float and advances *ptr by 4.
+float BinaryObjectParser::ParseFloat(uint8_t** ptr) const {
+  float* res = *reinterpret_cast<float**>(ptr);
+  byte_swapper_.SwapIfRequiredFloat(res);
+  *ptr += 4;
+
+  return *res;
+}
+
+// Reads a double and advances *ptr by 8.
+double BinaryObjectParser::ParseDouble(uint8_t** ptr) const {
+  double* res = *reinterpret_cast<double**>(ptr);
+  byte_swapper_.SwapIfRequiredDouble(res);
+  *ptr += 8;
+
+  return *res;
+}
+
+// Reads a one-byte boolean and advances *ptr by 1.
+bool BinaryObjectParser::ParseBool(uint8_t** ptr) const {
+  bool res = **reinterpret_cast<bool**>(ptr);
+  *ptr += 1;
+
+  return res;
+}
+
+// Reads a length-prefixed string (4-byte length, then the raw bytes) and
+// advances *ptr past it. The bytes are copied, not referenced.
+string BinaryObjectParser::ParseString(uint8_t** ptr) const {
+  int32_t length = ParseInt(ptr);
+  string res(*reinterpret_cast<char**>(ptr), length);
+  *ptr += length;
+
+  return res;
+}
+
+// The Parse*Arr helpers below return a pointer directly into the input
+// buffer (zero-copy) after byte-swapping the elements in place, then
+// advance *ptr past the array. The returned pointer is only valid as long
+// as the underlying buffer is.
+uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const {
+  uint8_t* res = *reinterpret_cast<uint8_t**>(ptr);
+  *ptr += length;
+
+  return res;
+}
+
+// Reads `length` int16 elements (2 bytes each).
+int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const {
+  int16_t* res = *reinterpret_cast<int16_t**>(ptr);
+  byte_swapper_.SwapIfRequiredInt16Arr(res, length);
+  *ptr += length * 2;
+
+  return res;
+}
+
+// Reads `length` uint16 elements (2 bytes each).
+uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr,
+                                                    int length) const {
+  uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
+  byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length);
+  *ptr += length * 2;
+
+  return res;
+}
+
+// Reads `length` int32 elements (4 bytes each).
+int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const {
+  int32_t* res = *reinterpret_cast<int32_t**>(ptr);
+  byte_swapper_.SwapIfRequiredInt32Arr(res, length);
+  *ptr += length * 4;
+
+  return res;
+}
+
+// Reads `length` int64 elements (8 bytes each).
+int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const {
+  int64_t* res = *reinterpret_cast<int64_t**>(ptr);
+  byte_swapper_.SwapIfRequiredInt64Arr(res, length);
+  *ptr += length * 8;
+
+  return res;
+}
+
+// Reads `length` float elements (4 bytes each).
+float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const {
+  float* res = *reinterpret_cast<float**>(ptr);
+  byte_swapper_.SwapIfRequiredFloatArr(res, length);
+  *ptr += length * 4;
+
+  return res;
+}
+
+// Reads `length` double elements (8 bytes each).
+double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const {
+  double* res = *reinterpret_cast<double**>(ptr);
+  byte_swapper_.SwapIfRequiredDoubleArr(res, length);
+  *ptr += length * 8;
+
+  return res;
+}
+
+// Reads `length` one-byte booleans; no swapping needed.
+bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const {
+  bool* res = *reinterpret_cast<bool**>(ptr);
+  *ptr += length;
+
+  return res;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
new file mode 100644
index 0000000000..eb1f856643
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
+
+#include <vector>
+#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Parser for values serialized in the Apache Ignite binary object format.
+// Parse() appends one Tensor per leaf value to `out_tensors` and records
+// the corresponding protocol type code in `types`, advancing *ptr past the
+// consumed bytes.
+class BinaryObjectParser {
+ public:
+  BinaryObjectParser();
+  Status Parse(uint8_t** ptr, std::vector<Tensor>* out_tensors,
+               std::vector<int32_t>* types) const;
+
+ private:
+  // Scalar readers: read one value, advance *ptr past it.
+  uint8_t ParseByte(uint8_t** ptr) const;
+  int16_t ParseShort(uint8_t** ptr) const;
+  uint16_t ParseUnsignedShort(uint8_t** ptr) const;
+  int32_t ParseInt(uint8_t** ptr) const;
+  int64_t ParseLong(uint8_t** ptr) const;
+  float ParseFloat(uint8_t** ptr) const;
+  double ParseDouble(uint8_t** ptr) const;
+  bool ParseBool(uint8_t** ptr) const;
+  string ParseString(uint8_t** ptr) const;
+  // Array readers: return a pointer into the buffer (zero-copy) after an
+  // in-place byte swap, then advance *ptr past the array.
+  uint8_t* ParseByteArr(uint8_t** ptr, int length) const;
+  int16_t* ParseShortArr(uint8_t** ptr, int length) const;
+  uint16_t* ParseUnsignedShortArr(uint8_t** ptr, int length) const;
+  int32_t* ParseIntArr(uint8_t** ptr, int length) const;
+  int64_t* ParseLongArr(uint8_t** ptr, int length) const;
+  float* ParseFloatArr(uint8_t** ptr, int length) const;
+  double* ParseDoubleArr(uint8_t** ptr, int length) const;
+  bool* ParseBoolArr(uint8_t** ptr, int length) const;
+
+  // Converts between wire byte order and host byte order.
+  const ByteSwapper byte_swapper_;
+};
+
+// Type codes of the Apache Ignite binary client protocol (the subset
+// supported by BinaryObjectParser). The numeric values are fixed by the
+// protocol and must not be changed.
+enum ObjectType {
+  BYTE = 1,
+  SHORT = 2,
+  INT = 3,
+  LONG = 4,
+  FLOAT = 5,
+  DOUBLE = 6,
+  USHORT = 7,  // protocol "char"
+  BOOL = 8,
+  STRING = 9,
+  DATE = 11,
+  BYTE_ARR = 12,
+  SHORT_ARR = 13,
+  INT_ARR = 14,
+  LONG_ARR = 15,
+  FLOAT_ARR = 16,
+  DOUBLE_ARR = 17,
+  USHORT_ARR = 18,
+  BOOL_ARR = 19,
+  STRING_ARR = 20,
+  DATE_ARR = 22,
+  WRAPPED_OBJ = 27,   // non-leaf: wraps a nested serialized object
+  COMPLEX_OBJ = 103   // non-leaf: object with header, fields and schema
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h
new file mode 100644
index 0000000000..46df3e39dc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
+
+#include <stdint.h>
+#include "tensorflow/core/platform/byte_order.h"
+
+namespace tensorflow {
+
+// Converts fixed-size values between the byte order of the wire data and
+// the host byte order. `big_endian` declares the order of the data; values
+// are swapped only when it differs from the host order
+// (port::kLittleEndian).
+class ByteSwapper {
+ public:
+  ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; }
+
+  // Each SwapIfRequired* helper reverses the bytes of *x in place when a
+  // swap is needed and is a no-op otherwise.
+  inline void SwapIfRequiredInt16(int16_t *x) const {
+    if (swap_) {
+      Swap16(x);
+    }
+  }
+
+  inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const {
+    if (swap_) {
+      Swap16(reinterpret_cast<int16_t *>(x));
+    }
+  }
+
+  inline void SwapIfRequiredInt32(int32_t *x) const {
+    if (swap_) {
+      Swap32(x);
+    }
+  }
+
+  // Floats/doubles are swapped through same-width integer views.
+  inline void SwapIfRequiredFloat(float *x) const {
+    if (swap_) {
+      Swap32(reinterpret_cast<int32_t *>(x));
+    }
+  }
+
+  inline void SwapIfRequiredInt64(int64_t *x) const {
+    if (swap_) {
+      Swap64(x);
+    }
+  }
+
+  inline void SwapIfRequiredDouble(double *x) const {
+    if (swap_) {
+      Swap64(reinterpret_cast<int64_t *>(x));
+    }
+  }
+
+  // Array variants swap each of the `length` elements in place.
+  inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const {
+    if (swap_) {
+      for (int32_t i = 0; i < length; i++) Swap16(&x[i]);
+    }
+  }
+
+  inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x,
+                                             int32_t length) const {
+    if (swap_) {
+      for (int32_t i = 0; i < length; i++)
+        Swap16(reinterpret_cast<int16_t *>(&x[i]));
+    }
+  }
+
+  inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const {
+    if (swap_) {
+      for (int32_t i = 0; i < length; i++) Swap32(&x[i]);
+    }
+  }
+
+  inline void SwapIfRequiredFloatArr(float *x, int32_t length) const {
+    if (swap_) {
+      for (int32_t i = 0; i < length; i++)
+        Swap32(reinterpret_cast<int32_t *>(&x[i]));
+    }
+  }
+
+  inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const {
+    if (swap_) {
+      for (int32_t i = 0; i < length; i++) Swap64(&x[i]);
+    }
+  }
+
+  inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const {
+    if (swap_) {
+      for (int32_t i = 0; i < length; i++)
+        Swap64(reinterpret_cast<int64_t *>(&x[i]));
+    }
+  }
+
+ private:
+  // Unconditional in-place byte reversals.
+  inline void Swap16(int16_t *x) const {
+    *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF);
+  }
+
+  inline void Swap32(int32_t *x) const {
+    *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) |
+         (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF);
+  }
+
+  inline void Swap64(int64_t *x) const {
+    *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) |
+         (((*x >> 16) & 0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) |
+         (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) |
+         (((*x >> 48) & 0xFF) << 8) | ((*x >> 56) & 0xFF);
+  }
+
+  // True when wire byte order differs from host byte order.
+  bool swap_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/ignite_client.h
new file mode 100644
index 0000000000..459b50b48f
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_client.h
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Abstract transport for a connection to an Ignite node. Concrete
+// subclasses supply the raw byte I/O (ReadData/WriteData); this base class
+// layers endian-aware helpers for the fixed-width integers of the Ignite
+// binary client protocol on top. `big_endian` declares the wire byte
+// order.
+class Client {
+ public:
+  Client(bool big_endian) : byte_swapper_(ByteSwapper(big_endian)) {}
+  virtual Status Connect() = 0;
+  virtual Status Disconnect() = 0;
+  virtual bool IsConnected() = 0;
+  virtual int GetSocketDescriptor() = 0;
+  // Reads/writes exactly `length` bytes (implementations define blocking
+  // and error behavior).
+  virtual Status ReadData(uint8_t *buf, const int32_t length) = 0;
+  virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0;
+
+  inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); }
+
+  // Read helpers: fetch the raw bytes, then swap to host order in place.
+  inline Status ReadShort(int16_t *data) {
+    TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2));
+    byte_swapper_.SwapIfRequiredInt16(data);
+
+    return Status::OK();
+  }
+
+  inline Status ReadInt(int32_t *data) {
+    TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4));
+    byte_swapper_.SwapIfRequiredInt32(data);
+
+    return Status::OK();
+  }
+
+  inline Status ReadLong(int64_t *data) {
+    TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8));
+    byte_swapper_.SwapIfRequiredInt64(data);
+
+    return Status::OK();
+  }
+
+  inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); }
+
+  // Write helpers: swap a local copy to wire order, then send it.
+  inline Status WriteShort(const int16_t data) {
+    int16_t tmp = data;
+    byte_swapper_.SwapIfRequiredInt16(&tmp);
+    return WriteData((uint8_t *)&tmp, 2);
+  }
+
+  inline Status WriteInt(const int32_t data) {
+    int32_t tmp = data;
+    byte_swapper_.SwapIfRequiredInt32(&tmp);
+    return WriteData((uint8_t *)&tmp, 4);
+  }
+
+  inline Status WriteLong(const int64_t data) {
+    int64_t tmp = data;
+    byte_swapper_.SwapIfRequiredInt64(&tmp);
+    return WriteData((uint8_t *)&tmp, 8);
+  }
+
+ private:
+  const ByteSwapper byte_swapper_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
new file mode 100644
index 0000000000..c4a7d3c513
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Constructs the dataset with the connection, cache and output-schema
+// parameters captured by value. Note that `password` and `cert_password`
+// are not included in the log line below.
+IgniteDataset::IgniteDataset(OpKernelContext* ctx, string cache_name,
+                             string host, int32 port, bool local, int32 part,
+                             int32 page_size, string username, string password,
+                             string certfile, string keyfile,
+                             string cert_password, std::vector<int32> schema,
+                             std::vector<int32> permutation,
+                             DataTypeVector dtypes,
+                             std::vector<PartialTensorShape> shapes)
+    : DatasetBase(DatasetContext(ctx)),
+      cache_name_(std::move(cache_name)),
+      host_(std::move(host)),
+      port_(port),
+      local_(local),
+      part_(part),
+      page_size_(page_size),
+      username_(std::move(username)),
+      password_(std::move(password)),
+      certfile_(std::move(certfile)),
+      keyfile_(std::move(keyfile)),
+      cert_password_(std::move(cert_password)),
+      schema_(std::move(schema)),
+      permutation_(std::move(permutation)),
+      dtypes_(dtypes),
+      shapes_(shapes) {
+  LOG(INFO) << "Ignite Dataset created [cache_name='" << cache_name_
+            << "', host='" << host_ << "', port=" << port_
+            << ", local=" << local_ << ", part=" << part_
+            << ", page_size=" << page_size_ << ", username='" << username_
+            << "', certfile='" << certfile_ << "', keyfile='"
+            << keyfile_ + "']";
+}
+
+IgniteDataset::~IgniteDataset() { LOG(INFO) << "Ignite Dataset destroyed"; }
+
+// Creates the iterator that connects to the cluster and streams cache
+// pages. All data members are const, so the std::move calls in the
+// original code could never move and silently degraded to copies; the
+// members are now passed directly, performing the same copies without the
+// misleading moves.
+std::unique_ptr<IteratorBase> IgniteDataset::MakeIteratorInternal(
+    const string& prefix) const {
+  return std::unique_ptr<IteratorBase>(new IgniteDatasetIterator(
+      {this, strings::StrCat(prefix, "::Ignite")}, host_, port_, cache_name_,
+      local_, part_, page_size_, username_, password_, certfile_, keyfile_,
+      cert_password_, schema_, permutation_));
+}
+
+// Output dtypes, as supplied to the constructor.
+const DataTypeVector& IgniteDataset::output_dtypes() const { return dtypes_; }
+
+// Output shapes, as supplied to the constructor.
+const std::vector<PartialTensorShape>& IgniteDataset::output_shapes() const {
+  return shapes_;
+}
+
+string IgniteDataset::DebugString() const { return "IgniteDatasetOp::Dataset"; }
+
+// Graph serialization is deliberately unsupported for this dataset.
+Status IgniteDataset::AsGraphDefInternal(SerializationContext* ctx,
+                                         DatasetGraphDefBuilder* b,
+                                         Node** output) const {
+  return errors::Unimplemented(
+      "IgniteDataset does not support 'AsGraphDefInternal'");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
new file mode 100644
index 0000000000..66bfdf2e2a
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+
+// tf.data dataset backed by an Apache Ignite cache. Holds the connection
+// parameters (host/port, optional SSL files, credentials), the cache query
+// options (local/part/page_size) and the output schema; the iterator
+// created by MakeIteratorInternal performs the actual network I/O.
+class IgniteDataset : public DatasetBase {
+ public:
+  IgniteDataset(OpKernelContext* ctx, string cache_name, string host,
+                int32 port, bool local, int32 part, int32 page_size,
+                string username, string password, string certfile,
+                string keyfile, string cert_password, std::vector<int32> schema,
+                std::vector<int32> permutation, DataTypeVector dtypes,
+                std::vector<PartialTensorShape> shapes);
+  ~IgniteDataset();
+  std::unique_ptr<IteratorBase> MakeIteratorInternal(
+      const string& prefix) const override;
+  const DataTypeVector& output_dtypes() const override;
+  const std::vector<PartialTensorShape>& output_shapes() const override;
+  string DebugString() const override;
+
+ protected:
+  // Unimplemented: this dataset cannot be serialized to a GraphDef.
+  Status AsGraphDefInternal(SerializationContext* ctx,
+                            DatasetGraphDefBuilder* b,
+                            Node** output) const override;
+
+ private:
+  const string cache_name_;
+  const string host_;
+  const int32 port_;
+  const bool local_;
+  const int32 part_;
+  const int32 page_size_;
+  const string username_;
+  const string password_;
+  const string certfile_;
+  const string keyfile_;
+  const string cert_password_;
+  const std::vector<int32> schema_;
+  const std::vector<int32> permutation_;
+  const DataTypeVector dtypes_;
+  const std::vector<PartialTensorShape> shapes_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
new file mode 100644
index 0000000000..5da9127aa6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h"
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// Captures connection/decoding parameters and constructs the transport: a
+// PlainClient, wrapped in SslWrapper when a certificate file is supplied.
+// No network I/O happens here; the connection is opened lazily by
+// EstablishConnection().
+IgniteDatasetIterator::IgniteDatasetIterator(
+    const Params& params, string host, int32 port, string cache_name,
+    bool local, int32 part, int32 page_size, string username, string password,
+    string certfile, string keyfile, string cert_password,
+    std::vector<int32> schema, std::vector<int32> permutation)
+    : DatasetIterator<IgniteDataset>(params),
+      cache_name_(std::move(cache_name)),
+      local_(local),
+      part_(part),
+      page_size_(page_size),
+      username_(std::move(username)),
+      password_(std::move(password)),
+      schema_(std::move(schema)),
+      permutation_(std::move(permutation)),
+      // remainder_ == -1 marks "no scan query started yet";
+      // cursor_id_ == -1 marks "no open server-side cursor".
+      remainder_(-1),
+      cursor_id_(-1),
+      last_page_(false),
+      valid_state_(true) {
+  Client* p_client = new PlainClient(std::move(host), port, false);
+
+  // Only wrap the plain TCP client in SSL when a certificate was provided.
+  if (certfile.empty())
+    client_ = std::unique_ptr<Client>(p_client);
+  else
+    client_ = std::unique_ptr<Client>(
+        new SslWrapper(std::unique_ptr<Client>(p_client), std::move(certfile),
+                       std::move(keyfile), std::move(cert_password), false));
+
+  LOG(INFO) << "Ignite Dataset Iterator created";
+}
+
+// Best-effort close of the server-side cursor and socket; a failure is only
+// logged since a destructor cannot return a Status.
+IgniteDatasetIterator::~IgniteDatasetIterator() {
+  Status status = CloseConnection();
+  if (!status.ok()) LOG(ERROR) << status.ToString();
+
+  LOG(INFO) << "Ignite Dataset Iterator destroyed";
+}
+
+// Thread-safe wrapper around GetNextInternalWithValidState. After the first
+// failure the iterator is permanently marked invalid and every later call
+// returns an error without touching the connection.
+Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx,
+                                              std::vector<Tensor>* out_tensors,
+                                              bool* end_of_sequence) {
+  mutex_lock lock(mutex_);
+
+  if (!valid_state_) return errors::Unknown("Iterator is invalid");
+
+  Status status =
+      GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence);
+  if (!status.ok()) valid_state_ = false;
+
+  return status;
+}
+
+// Checkpointing is unsupported: the iterator's state includes a live socket
+// and a server-side query cursor that cannot be serialized.
+Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) {
+  return errors::Unimplemented(
+      "Iterator for IgniteDataset does not support 'SaveInternal'");
+}
+
+// Checkpointing is unsupported: the iterator's state includes a live socket
+// and a server-side query cursor that cannot be serialized.
+Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx,
+                                              IteratorStateReader* reader) {
+  // Fixed a stray ')' inside the message so it matches SaveInternal's form.
+  return errors::Unimplemented(
+      "Iterator for IgniteDataset does not support 'RestoreInternal'");
+}
+
+// Produces one row: parses the next (key, value) pair out of the current
+// page, starting the scan query or fetching the next page first when
+// needed. Sets end_of_sequence once the last page is fully consumed.
+Status IgniteDatasetIterator::GetNextInternalWithValidState(
+    IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+    bool* end_of_sequence) {
+  if (remainder_ == 0 && last_page_) {
+    cursor_id_ = -1;
+    *end_of_sequence = true;
+    return Status::OK();
+  }
+
+  TF_RETURN_IF_ERROR(EstablishConnection());
+
+  // remainder_ == -1: no query started yet; remainder_ == 0: the current
+  // page is exhausted but more pages remain.
+  if (remainder_ == -1) {
+    TF_RETURN_IF_ERROR(ScanQuery());
+  } else if (remainder_ == 0) {
+    TF_RETURN_IF_ERROR(LoadNextPage());
+  }
+
+  uint8_t* initial_ptr = ptr_;
+  std::vector<Tensor> tensors;
+  std::vector<int32_t> types;
+
+  TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types));  // Parse key
+  TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types));  // Parse val
+
+  // Account for the bytes the parser consumed from the page buffer.
+  remainder_ -= (ptr_ - initial_ptr);
+
+  TF_RETURN_IF_ERROR(CheckTypes(types));
+
+  // Emit tensors in caller order; permutation_ maps output position to
+  // parsed-tensor position (the inverse is built in IgniteDatasetOp).
+  for (size_t i = 0; i < tensors.size(); i++)
+    out_tensors->push_back(tensors[permutation_[i]]);
+
+  *end_of_sequence = false;
+  return Status::OK();
+  // Note: the original ended with an unreachable duplicate
+  // "*end_of_sequence = true; return Status::OK();" after the if/else in
+  // which both branches returned — removed as dead code.
+}
+
+// Ensures the socket is connected and the protocol handshake has been
+// performed. On handshake failure the socket is torn down again so a later
+// call can retry from scratch.
+Status IgniteDatasetIterator::EstablishConnection() {
+  if (client_->IsConnected()) return Status::OK();
+
+  TF_RETURN_IF_ERROR(client_->Connect());
+
+  Status handshake_status = Handshake();
+  if (!handshake_status.ok()) {
+    Status disconnect_status = client_->Disconnect();
+    if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
+    return handshake_status;
+  }
+
+  return Status::OK();
+}
+
+// Closes the server-side query cursor (if one is still open and unfinished)
+// and then disconnects the socket. Cursors that reached their last page are
+// released by the server, so only the socket is closed in that case.
+Status IgniteDatasetIterator::CloseConnection() {
+  if (cursor_id_ != -1 && !last_page_) {
+    TF_RETURN_IF_ERROR(EstablishConnection());
+
+    // Request: close the resource identified by cursor_id_.
+    TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength));
+    TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode));
+    TF_RETURN_IF_ERROR(client_->WriteLong(0));           // Request ID
+    TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_));  // Resource ID
+
+    int32_t res_len;
+    TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+    if (res_len < kMinResLength)
+      return errors::Unknown("Close Resource Response is corrupted");
+
+    int64_t req_id;
+    TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+    int32_t status;
+    TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+    if (status != 0) {
+      // Non-zero status: the server may attach an error message string.
+      uint8_t err_msg_header;
+      TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+      if (err_msg_header == kStringVal) {
+        int32_t err_msg_length;
+        TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+        uint8_t* err_msg_c = new uint8_t[err_msg_length];
+        auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+        TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+        string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+        return errors::Unknown("Close Resource Error [status=", status,
+                               ", message=", err_msg, "]");
+      }
+      return errors::Unknown("Close Resource Error [status=", status, "]");
+    }
+
+    cursor_id_ = -1;
+
+    return client_->Disconnect();
+  } else {
+    LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed";
+  }
+
+  return client_->IsConnected() ? client_->Disconnect() : Status::OK();
+}
+
+// Performs the client/server handshake, optionally sending username and
+// password. On rejection, reads the server version and the optional error
+// message and folds them into the returned Status.
+Status IgniteDatasetIterator::Handshake() {
+  int32_t msg_len = kHandshakeReqDefaultLength;
+
+  if (username_.empty())
+    msg_len += 1;  // Just the null marker byte.
+  else
+    msg_len += 5 + username_.length();  // 1 byte header, 4 bytes length.
+
+  if (password_.empty())
+    msg_len += 1;  // Just the null marker byte.
+  else
+    msg_len += 5 + password_.length();  // 1 byte header, 4 bytes length.
+
+  TF_RETURN_IF_ERROR(client_->WriteInt(msg_len));
+  TF_RETURN_IF_ERROR(client_->WriteByte(1));
+  TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion));
+  TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion));
+  TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion));
+  TF_RETURN_IF_ERROR(client_->WriteByte(2));
+
+  // Username and password are encoded identically: a null marker when
+  // absent, otherwise a string header, its length, and the raw bytes.
+  // (The original duplicated this block verbatim for each credential.)
+  auto write_credential = [this](const string& value) -> Status {
+    if (value.empty()) return client_->WriteByte(kNullVal);
+    TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
+    TF_RETURN_IF_ERROR(client_->WriteInt(value.length()));
+    return client_->WriteData(
+        reinterpret_cast<const uint8_t*>(value.c_str()), value.length());
+  };
+  TF_RETURN_IF_ERROR(write_credential(username_));
+  TF_RETURN_IF_ERROR(write_credential(password_));
+
+  int32_t handshake_res_len;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len));
+  uint8_t handshake_res;
+  TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res));
+
+  if (handshake_res != 1) {
+    int16_t serv_ver_major;
+    TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major));
+    int16_t serv_ver_minor;
+    TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor));
+    int16_t serv_ver_patch;
+    TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch));
+    uint8_t header;
+    TF_RETURN_IF_ERROR(client_->ReadByte(&header));
+
+    if (header == kStringVal) {
+      int32_t length;
+      TF_RETURN_IF_ERROR(client_->ReadInt(&length));
+
+      uint8_t* err_msg_c = new uint8_t[length];
+      auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+      TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length));
+      string err_msg(reinterpret_cast<char*>(err_msg_c), length);
+
+      return errors::Unknown("Handshake Error [result=", handshake_res,
+                             ", version=", serv_ver_major, ".", serv_ver_minor,
+                             ".", serv_ver_patch, ", message='", err_msg, "']");
+    }
+    // kNullVal (no message) and any unexpected header are reported the same
+    // way; the original had two byte-identical branches here.
+    return errors::Unknown("Handshake Error [result=", handshake_res,
+                           ", version=", serv_ver_major, ".", serv_ver_minor,
+                           ".", serv_ver_patch, "]");
+  }
+
+  return Status::OK();
+}
+
+// Starts a scan query over the configured cache/partition, validates the
+// response header, stores the cursor id, and receives the first page.
+Status IgniteDatasetIterator::ScanQuery() {
+  TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength));
+  TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode));
+  TF_RETURN_IF_ERROR(client_->WriteLong(0));  // Request ID
+  TF_RETURN_IF_ERROR(
+      client_->WriteInt(JavaHashCode(cache_name_)));  // Cache name
+  TF_RETURN_IF_ERROR(client_->WriteByte(0));          // Flags
+  TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));   // Filter object
+  TF_RETURN_IF_ERROR(client_->WriteInt(page_size_));  // Cursor page size
+  TF_RETURN_IF_ERROR(client_->WriteInt(part_));       // Partition to query
+  TF_RETURN_IF_ERROR(client_->WriteByte(local_));     // Local flag
+
+  // Time the blocking wait for the response header. Both stamps are uint64
+  // (the original mixed uint64 and int64_t; LoadNextPage uses uint64).
+  uint64 wait_start = Env::Default()->NowMicros();
+  int32_t res_len;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+  uint64 wait_stop = Env::Default()->NowMicros();
+
+  LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms";
+
+  if (res_len < kMinResLength)
+    return errors::Unknown("Scan Query Response is corrupted");
+
+  int64_t req_id;
+  TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+
+  int32_t status;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+
+  if (status != 0) {
+    // Non-zero status: the server may attach an error message string.
+    uint8_t err_msg_header;
+    TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+
+    if (err_msg_header == kStringVal) {
+      int32_t err_msg_length;
+      TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+      uint8_t* err_msg_c = new uint8_t[err_msg_length];
+      auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+      TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+      string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+      return errors::Unknown("Scan Query Error [status=", status,
+                             ", message=", err_msg, "]");
+    }
+    return errors::Unknown("Scan Query Error [status=", status, "]");
+  }
+
+  TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_));
+
+  // Row count is part of the response header; parsing below is driven by
+  // the remaining byte count, so the value itself is not used.
+  int32_t row_cnt;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
+
+  int32_t page_size = res_len - kScanQueryResHeaderLength;
+
+  return ReceivePage(page_size);
+}
+
+// Requests the next page of the open cursor, validates the response header,
+// and receives the page body.
+Status IgniteDatasetIterator::LoadNextPage() {
+  TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength));
+  TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode));
+  TF_RETURN_IF_ERROR(client_->WriteLong(0));           // Request ID
+  TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_));  // Cursor ID
+
+  // Time the blocking wait for the response header.
+  uint64 wait_start = Env::Default()->NowMicros();
+  int32_t res_len;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+  uint64 wait_stop = Env::Default()->NowMicros();
+
+  LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000
+            << " ms";
+
+  if (res_len < kMinResLength)
+    return errors::Unknown("Load Next Page Response is corrupted");
+
+  int64_t req_id;
+  TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+
+  int32_t status;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+
+  if (status != 0) {
+    // Non-zero status: the server may attach an error message string.
+    uint8_t err_msg_header;
+    TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+
+    if (err_msg_header == kStringVal) {
+      int32_t err_msg_length;
+      TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+      uint8_t* err_msg_c = new uint8_t[err_msg_length];
+      auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+      TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+      string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+      return errors::Unknown("Load Next Page Error [status=", status,
+                             ", message=", err_msg, "]");
+    }
+    return errors::Unknown("Load Next Page Error [status=", status, "]");
+  }
+
+  // Row count is consumed to advance the stream; parsing is driven by the
+  // remaining byte count, so the value itself is not used.
+  int32_t row_cnt;
+  TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
+
+  int32_t page_size = res_len - kLoadNextPageResHeaderLength;
+
+  return ReceivePage(page_size);
+}
+
+// Receives `page_size` bytes of row data into a freshly allocated buffer
+// and then reads the trailing flag byte that signals whether more pages
+// follow.
+// NOTE(review): page_ is a std::unique_ptr<uint8_t> owning an array
+// allocated with new[], so its deleter runs `delete` instead of `delete[]`
+// (undefined behavior). Fixing this requires changing the member type to
+// std::unique_ptr<uint8_t[]> in the header as well — flagged, not fixed
+// here, to keep the two files consistent.
+Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
+  remainder_ = page_size;
+  page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
+  ptr_ = page_.get();
+
+  uint64 start = Env::Default()->NowMicros();
+  TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_));
+  uint64 stop = Env::Default()->NowMicros();
+
+  double size_in_mb = 1.0 * remainder_ / 1024 / 1024;
+  double time_in_s = 1.0 * (stop - start) / 1000 / 1000;
+  LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+            << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+  uint8_t last_page_b;
+  TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b));
+
+  // The byte appears to be a has-more-pages flag; last_page_ is its
+  // negation (TODO confirm against the server protocol).
+  last_page_ = !last_page_b;
+
+  return Status::OK();
+}
+
+// Verifies that the parsed column types match the configured schema, taking
+// the output permutation into account.
+Status IgniteDatasetIterator::CheckTypes(const std::vector<int32_t>& types) {
+  if (types.size() != schema_.size())
+    return errors::Unknown("Object has unexpected schema");
+
+  for (size_t idx = 0; idx < schema_.size(); ++idx) {
+    if (types[permutation_[idx]] != schema_[idx])
+      return errors::Unknown("Object has unexpected schema");
+  }
+
+  return Status::OK();
+}
+
+// Hashes `str` the way java.lang.String::hashCode does (h = 31*h + c over
+// the characters); ScanQuery sends this hash as the cache identifier.
+int32_t IgniteDatasetIterator::JavaHashCode(string str) const {
+  int32_t hash = 0;
+  for (size_t i = 0; i < str.size(); ++i) {
+    hash = hash * 31 + str[i];
+  }
+  return hash;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
new file mode 100644
index 0000000000..c499e2c9cc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// Iterator that streams rows from an Apache Ignite cache by running a scan
+// query and paging through its cursor. GetNextInternal is guarded by a
+// mutex; the iterator becomes permanently invalid after the first error.
+class IgniteDatasetIterator : public DatasetIterator<IgniteDataset> {
+ public:
+  IgniteDatasetIterator(const Params& params, string host, int32 port,
+                        string cache_name, bool local, int32 part,
+                        int32 page_size, string username, string password,
+                        string certfile, string keyfile, string cert_password,
+                        std::vector<int32> schema,
+                        std::vector<int32> permutation);
+  ~IgniteDatasetIterator();
+  Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+                         bool* end_of_sequence) override;
+
+ protected:
+  // Save/Restore are unimplemented: iterator state includes a live socket
+  // and a server-side cursor.
+  Status SaveInternal(IteratorStateWriter* writer) override;
+  Status RestoreInternal(IteratorContext* ctx,
+                         IteratorStateReader* reader) override;
+
+ private:
+  Status GetNextInternalWithValidState(IteratorContext* ctx,
+                                       std::vector<Tensor>* out_tensors,
+                                       bool* end_of_sequence);
+
+  Status EstablishConnection();
+  Status CloseConnection();
+  Status Handshake();
+  Status ScanQuery();
+  Status LoadNextPage();
+  Status ReceivePage(int32_t page_size);
+  Status CheckTypes(const std::vector<int32_t>& types);
+  int32_t JavaHashCode(string str) const;
+
+  std::unique_ptr<Client> client_;
+  BinaryObjectParser parser_;
+
+  const string cache_name_;
+  const bool local_;
+  const int32 part_;
+  const int32 page_size_;
+  const string username_;
+  const string password_;
+  const std::vector<int32> schema_;
+  const std::vector<int32> permutation_;
+
+  // Unparsed bytes left in the current page; -1 before the first query.
+  int32_t remainder_;
+  // Server-side cursor id; -1 when no cursor is open.
+  int64_t cursor_id_;
+  bool last_page_;
+
+  // Cleared after the first failed GetNext; later calls fail fast.
+  bool valid_state_;
+
+  mutex mutex_;
+
+  // NOTE(review): page_ owns an array allocated with new uint8_t[...] (see
+  // ReceivePage) but is declared unique_ptr<uint8_t>, so destruction runs
+  // `delete` instead of `delete[]` — should be unique_ptr<uint8_t[]>
+  // together with a matching change in the .cc file.
+  std::unique_ptr<uint8_t> page_;
+  uint8_t* ptr_;  // Read position inside page_.
+};
+
+// Wire-protocol markers and opcodes used by the iterator.
+constexpr uint8_t kNullVal = 101;
+constexpr uint8_t kStringVal = 9;
+constexpr uint8_t kProtocolMajorVersion = 1;
+constexpr uint8_t kProtocolMinorVersion = 1;
+constexpr uint8_t kProtocolPatchVersion = 0;
+constexpr int16_t kScanQueryOpcode = 2000;
+constexpr int16_t kLoadNextPageOpcode = 2001;
+constexpr int16_t kCloseConnectionOpcode = 0;
+// Fixed request/response sizes in bytes.
+constexpr int32_t kScanQueryReqLength = 25;
+constexpr int32_t kScanQueryResHeaderLength = 25;
+constexpr int32_t kLoadNextPageReqLength = 18;
+constexpr int32_t kLoadNextPageResHeaderLength = 17;
+constexpr int32_t kCloseConnectionReqLength = 18;
+constexpr int32_t kHandshakeReqDefaultLength = 8;
+constexpr int32_t kMinResLength = 12;
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
new file mode 100644
index 0000000000..f75b1c5ff5
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
@@ -0,0 +1,198 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdlib.h>
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
+namespace tensorflow {
+namespace {
+
+// Maps Ignite schema type ids onto TensorFlow dtypes; the scalar and array
+// variants of a type map to the same dtype.
+Status SchemaToTypes(const std::vector<int32>& schema, DataTypeVector* dtypes) {
+  for (auto e : schema) {
+    switch (e) {
+      case BYTE:
+      case BYTE_ARR:
+        dtypes->push_back(DT_UINT8);
+        break;
+      case SHORT:
+      case SHORT_ARR:
+        dtypes->push_back(DT_INT16);
+        break;
+      case INT:
+      case INT_ARR:
+        dtypes->push_back(DT_INT32);
+        break;
+      case LONG:
+      case LONG_ARR:
+        dtypes->push_back(DT_INT64);
+        break;
+      case FLOAT:
+      case FLOAT_ARR:
+        dtypes->push_back(DT_FLOAT);
+        break;
+      case DOUBLE:
+      case DOUBLE_ARR:
+        dtypes->push_back(DT_DOUBLE);
+        break;
+      case USHORT:
+      case USHORT_ARR:
+        // NOTE(review): USHORT maps to DT_UINT8, not DT_UINT16 — preserved
+        // from the original; worth confirming upstream.
+        dtypes->push_back(DT_UINT8);
+        break;
+      case BOOL:
+      case BOOL_ARR:
+        dtypes->push_back(DT_BOOL);
+        break;
+      case STRING:
+      case STRING_ARR:
+        dtypes->push_back(DT_STRING);
+        break;
+      default:
+        return errors::Unknown("Unexpected type in schema [type_id=", e, "]");
+    }
+  }
+
+  return Status::OK();
+}
+
+// Maps Ignite schema type ids onto output shapes: ids 1..9 are scalars,
+// ids 12..20 are variable-length vectors; anything else is rejected.
+Status SchemaToShapes(const std::vector<int32>& schema,
+                      std::vector<PartialTensorShape>* shapes) {
+  for (auto e : schema) {
+    const bool is_scalar = e >= 1 && e < 10;
+    const bool is_array = e >= 12 && e < 21;
+
+    if (!is_scalar && !is_array)
+      return errors::Unknown("Unexpected type in schema [type_id=", e, "]");
+
+    shapes->push_back(is_scalar ? PartialTensorShape({})
+                                : PartialTensorShape({-1}));
+  }
+
+  return Status::OK();
+}
+
+// Op kernel producing an IgniteDataset. Each connection parameter can be
+// overridden by an IGNITE_DATASET_* environment variable; when the variable
+// is unset, the corresponding scalar op input is used instead.
+class IgniteDatasetOp : public DatasetOpKernel {
+ public:
+  using DatasetOpKernel::DatasetOpKernel;
+
+  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+    string cache_name = "";
+    string host = "";
+    int32 port = -1;
+    bool local = false;
+    int32 part = -1;
+    int32 page_size = -1;
+    string username = "";
+    string password = "";
+    string certfile = "";
+    string keyfile = "";
+    string cert_password = "";
+
+    // Environment-variable overrides; NULL when a variable is unset.
+    const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME");
+    const char* env_host = std::getenv("IGNITE_DATASET_HOST");
+    const char* env_port = std::getenv("IGNITE_DATASET_PORT");
+    const char* env_local = std::getenv("IGNITE_DATASET_LOCAL");
+    const char* env_part = std::getenv("IGNITE_DATASET_PART");
+    const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE");
+    const char* env_username = std::getenv("IGNITE_DATASET_USERNAME");
+    const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD");
+    const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE");
+    const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE");
+    const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD");
+
+    if (env_cache_name) {
+      cache_name = string(env_cache_name);
+    } else {
+      OP_REQUIRES_OK(
+          ctx, ParseScalarArgument<string>(ctx, "cache_name", &cache_name));
+    }
+
+    if (env_host) {
+      host = string(env_host);
+    } else {
+      OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "host", &host));
+    }
+
+    if (env_port) {
+      OP_REQUIRES(ctx, strings::safe_strto32(env_port, &port),
+                  errors::InvalidArgument("IGNITE_DATASET_PORT environment "
+                                          "variable is not a valid integer: ",
+                                          env_port));
+    } else {
+      OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "port", &port));
+    }
+
+    if (env_local) {
+      // NOTE: any set value (even "false") enables local mode; the variable's
+      // content is not inspected.
+      local = true;
+    } else {
+      OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "local", &local));
+    }
+
+    if (env_part) {
+      OP_REQUIRES(ctx, strings::safe_strto32(env_part, &part),
+                  errors::InvalidArgument("IGNITE_DATASET_PART environment "
+                                          "variable is not a valid integer: ",
+                                          env_part));
+    } else {
+      OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "part", &part));
+    }
+
+    if (env_page_size) {
+      OP_REQUIRES(ctx, strings::safe_strto32(env_page_size, &page_size),
+                  errors::InvalidArgument("IGNITE_DATASET_PAGE_SIZE "
+                                          "environment variable is not a valid "
+                                          "integer: ",
+                                          env_page_size));
+    } else {
+      OP_REQUIRES_OK(ctx,
+                     ParseScalarArgument<int32>(ctx, "page_size", &page_size));
+    }
+
+    // Credentials and SSL material come only from the environment; they
+    // stay empty (disabled) when the variables are unset.
+    if (env_username) username = string(env_username);
+
+    if (env_password) password = string(env_password);
+
+    if (env_certfile) certfile = string(env_certfile);
+
+    if (env_keyfile) keyfile = string(env_keyfile);
+
+    if (env_cert_password) cert_password = string(env_cert_password);
+
+    const Tensor* schema_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor));
+    OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+                errors::InvalidArgument("`schema` must be a vector."));
+
+    std::vector<int32> schema;
+    schema.reserve(schema_tensor->NumElements());
+    for (int i = 0; i < schema_tensor->NumElements(); i++) {
+      schema.push_back(schema_tensor->flat<int32>()(i));
+    }
+
+    const Tensor* permutation_tensor;
+    OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor));
+    OP_REQUIRES(ctx, permutation_tensor->dims() == 1,
+                errors::InvalidArgument("`permutation` must be a vector."));
+
+    std::vector<int32> permutation;
+    permutation.resize(permutation_tensor->NumElements());
+    for (int i = 0; i < permutation_tensor->NumElements(); i++) {
+      // Invert the permutation: the dataset wants a map from output
+      // position to parsed-tensor position.
+      // NOTE(review): values are used as indices without a bounds check; an
+      // out-of-range entry would write past the vector — worth validating.
+      permutation[permutation_tensor->flat<int32>()(i)] = i;
+    }
+
+    DataTypeVector dtypes;
+    std::vector<PartialTensorShape> shapes;
+
+    OP_REQUIRES_OK(ctx, SchemaToTypes(schema, &dtypes));
+    OP_REQUIRES_OK(ctx, SchemaToShapes(schema, &shapes));
+
+    *output = new IgniteDataset(
+        ctx, std::move(cache_name), std::move(host), port, local, part,
+        page_size, std::move(username), std::move(password),
+        std::move(certfile), std::move(keyfile), std::move(cert_password),
+        std::move(schema), std::move(permutation), std::move(dtypes),
+        std::move(shapes));
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU),
+                        IgniteDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
new file mode 100644
index 0000000000..75424c19ee
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+namespace tensorflow {
+
+// Client transport over a plain (unencrypted) TCP socket. Platform-specific
+// implementations live in ignite_plain_client_unix.cc and
+// ignite_plain_client_windows.cc.
+class PlainClient : public Client {
+ public:
+  PlainClient(string host, int port, bool big_endian);
+  ~PlainClient();
+
+  Status Connect() override;
+  Status Disconnect() override;
+  bool IsConnected() override;
+  int GetSocketDescriptor() override;
+  // Read/Write transfer exactly `length` bytes, looping over partial
+  // transfers, and fail with an error Status otherwise.
+  Status ReadData(uint8_t* buf, const int32_t length) override;
+  Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+  const string host_;
+  const int port_;
+  // Socket descriptor; the "not connected" sentinel is platform-specific
+  // (-1 on unix, INVALID_SOCKET on Windows).
+  int sock_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
new file mode 100644
index 0000000000..cf672942c6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <map>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// sock_ == -1 means "not connected"; the socket is created lazily in
+// Connect().
+PlainClient::PlainClient(string host, int port, bool big_endian)
+    : Client(big_endian), host_(std::move(host)), port_(port), sock_(-1) {}
+
+// A destructor cannot propagate a Status, so a failed disconnect is only
+// logged.
+PlainClient::~PlainClient() {
+  if (IsConnected()) {
+    Status status = Disconnect();
+    if (!status.ok()) LOG(WARNING) << status.ToString();
+  }
+}
+
+// Opens a TCP connection to host_:port_. The host may be a dotted-quad IP
+// literal or a hostname to resolve.
+Status PlainClient::Connect() {
+  if (sock_ == -1) {
+    sock_ = socket(AF_INET, SOCK_STREAM, 0);
+    if (sock_ == -1) return errors::Internal("Failed to create socket");
+  }
+
+  sockaddr_in server = {};  // Zero-init so sin_zero is not left garbage.
+  server.sin_family = AF_INET;
+  server.sin_port = htons(port_);
+
+  // Try to parse as a literal IPv4 address first; INADDR_NONE is the
+  // documented failure value of inet_addr (the original compared the
+  // unsigned s_addr against -1). Fall back to DNS resolution.
+  server.sin_addr.s_addr = inet_addr(host_.c_str());
+  if (server.sin_addr.s_addr == INADDR_NONE) {
+    hostent* he = gethostbyname(host_.c_str());
+    if (he == NULL)
+      return errors::Internal("Failed to resolve hostname \"", host_, "\"");
+
+    in_addr** addr_list = (in_addr**)he->h_addr_list;
+    // The original silently kept the invalid address and tried to connect
+    // when resolution produced an empty list; fail explicitly instead.
+    if (addr_list[0] == NULL)
+      return errors::Internal("Failed to resolve hostname \"", host_, "\"");
+    server.sin_addr = *addr_list[0];
+  }
+
+  if (connect(sock_, (sockaddr*)&server, sizeof(server)) < 0)
+    return errors::Internal("Failed to connect to \"", host_, ":", port_, "\"");
+
+  LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established";
+
+  return Status::OK();
+}
+
+// Closes the socket unconditionally and marks the client as disconnected;
+// an error is reported only when close() itself failed.
+Status PlainClient::Disconnect() {
+  const int close_res = close(sock_);
+  sock_ = -1;
+
+  LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" is closed";
+
+  if (close_res == 0) return Status::OK();
+  return errors::Internal("Failed to correctly close connection");
+}
+
+// -1 is the sentinel Connect()/Disconnect() use for "no open socket".
+bool PlainClient::IsConnected() { return sock_ != -1; }
+
+int PlainClient::GetSocketDescriptor() { return sock_; }
+
+// Reads exactly `length` bytes into `buf`, looping over short reads.
+Status PlainClient::ReadData(uint8_t* buf, const int32_t length) {
+  for (int32_t remaining = length; remaining > 0;) {
+    int res = recv(sock_, buf, remaining, 0);
+
+    if (res < 0)
+      return errors::Internal("Error occurred while reading from socket: ", res,
+                              ", ", string(strerror(errno)));
+
+    if (res == 0) return errors::Internal("Server closed connection");
+
+    remaining -= res;
+    buf += res;
+  }
+
+  return Status::OK();
+}
+
+// Writes exactly `length` bytes from `buf`, looping over short writes.
+Status PlainClient::WriteData(const uint8_t* buf, const int32_t length) {
+  for (int32_t remaining = length; remaining > 0;) {
+    int res = send(sock_, buf, remaining, 0);
+
+    if (res < 0)
+      return errors::Internal("Error occurred while writing into socket: ", res,
+                              ", ", string(strerror(errno)));
+
+    remaining -= res;
+    buf += res;
+  }
+
+  return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
new file mode 100644
index 0000000000..dad5aace5f
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#pragma comment(lib, "Ws2_32.lib")
+#pragma comment(lib, "Mswsock.lib")
+#pragma comment(lib, "AdvApi32.lib")
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+PlainClient::PlainClient(string host, int port, bool big_endian)
+    : Client(big_endian),
+      host_(std::move(host)),
+      port_(port),
+      sock_(INVALID_SOCKET) {}
+
+PlainClient::~PlainClient() {
+  // Best-effort disconnect: destructors cannot propagate a Status, so a
+  // failure is only logged.
+  if (IsConnected()) {
+    Status status = Disconnect();
+    if (!status.ok()) LOG(WARNING) << status.ToString();
+  }
+}
+
+Status PlainClient::Connect() {
+  WSADATA wsaData;
+  addrinfo *result = NULL, *ptr = NULL, hints;
+
+  // Winsock must be initialized per process before any socket call;
+  // Disconnect() pairs this with WSACleanup().
+  int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+  if (res != 0) return errors::Internal("WSAStartup failed with error: ", res);
+
+  ZeroMemory(&hints, sizeof(hints));
+  hints.ai_family = AF_UNSPEC;
+  hints.ai_socktype = SOCK_STREAM;
+  hints.ai_protocol = IPPROTO_TCP;
+
+  res = getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints,
+                    &result);
+  if (res != 0) return errors::Internal("Getaddrinfo failed with error: ", res);
+
+  // Ensure the address list is released on every exit path below.
+  auto clean = gtl::MakeCleanup([result] { freeaddrinfo(result); });
+
+  // Try each resolved address until one accepts the connection.
+  for (ptr = result; ptr != NULL; ptr = ptr->ai_next) {
+    sock_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+    if (sock_ == INVALID_SOCKET) {
+      WSACleanup();
+      return errors::Internal("Socket failed with error: ", WSAGetLastError());
+    }
+
+    res = connect(sock_, ptr->ai_addr, (int)ptr->ai_addrlen);
+    if (res == SOCKET_ERROR) {
+      closesocket(sock_);
+      sock_ = INVALID_SOCKET;
+      continue;
+    }
+
+    break;
+  }
+
+  if (sock_ == INVALID_SOCKET) {
+    WSACleanup();
+    return errors::Internal("Unable to connect to server");
+  }
+
+  LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established";
+
+  return Status::OK();
+}
+
+Status PlainClient::Disconnect() {
+  // Half-close the sending side before releasing the socket so the peer
+  // observes an orderly shutdown.
+  int res = shutdown(sock_, SD_SEND);
+  closesocket(sock_);
+  // Reset the descriptor so IsConnected() reports false afterwards and the
+  // destructor does not attempt a second disconnect (the POSIX
+  // implementation resets its descriptor the same way).
+  sock_ = INVALID_SOCKET;
+  WSACleanup();
+
+  if (res == SOCKET_ERROR)
+    return errors::Internal("Shutdown failed with error: ", WSAGetLastError());
+
+  return Status::OK();
+}
+
+// INVALID_SOCKET marks the "not connected" state.
+bool PlainClient::IsConnected() { return sock_ != INVALID_SOCKET; }
+
+// Exposes the raw descriptor so wrappers (e.g. SslWrapper) can attach to it.
+int PlainClient::GetSocketDescriptor() { return sock_; }
+
+Status PlainClient::ReadData(uint8_t *buf, const int32_t length) {
+  // recv() may return fewer bytes than requested, so loop until the whole
+  // buffer is filled.
+  int received = 0;
+
+  while (received < length) {
+    int res = recv(sock_, (char *)buf, length - received, 0);
+
+    if (res < 0)
+      // Include WSAGetLastError() so failures are diagnosable, matching the
+      // POSIX implementation which reports strerror(errno).
+      return errors::Internal("Error occurred while reading from socket: ", res,
+                              ", ", WSAGetLastError());
+
+    // recv() returning 0 signals an orderly shutdown by the peer.
+    if (res == 0) return errors::Internal("Server closed connection");
+
+    received += res;
+    buf += res;
+  }
+
+  return Status::OK();
+}
+
+Status PlainClient::WriteData(const uint8_t *buf, const int32_t length) {
+  // send() may accept fewer bytes than requested, so loop until everything
+  // has been written.
+  int sent = 0;
+
+  while (sent < length) {
+    int res = send(sock_, (char *)buf, length - sent, 0);
+
+    if (res < 0)
+      // Include WSAGetLastError() so failures are diagnosable, matching the
+      // POSIX implementation which reports strerror(errno).
+      return errors::Internal("Error occurred while writing into socket: ", res,
+                              ", ", WSAGetLastError());
+
+    sent += res;
+    buf += res;
+  }
+
+  return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
new file mode 100644
index 0000000000..ceb479b084
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h"
+
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+// OpenSSL password callback: copies the configured certificate password into
+// `buf`.  The password is truncated to at most `size` - 1 characters and the
+// result is always NUL-terminated.
+static int PasswordCb(char *buf, int size, int rwflag, void *password) {
+  strncpy(buf, (char *)(password), size);
+  buf[size - 1] = '\0';
+  return (strlen(buf));
+}
+
+SslWrapper::SslWrapper(std::shared_ptr<Client> client, string certfile,
+                       string keyfile, string cert_password, bool big_endian)
+    : Client(big_endian),
+      client_(client),
+      certfile_(std::move(certfile)),
+      keyfile_(std::move(keyfile)),
+      cert_password_(std::move(cert_password)),
+      ctx_(nullptr),
+      ssl_(nullptr) {}
+
+SslWrapper::~SslWrapper() {
+  // Best-effort disconnect: destructors cannot propagate a Status, so a
+  // failure is only logged.
+  if (IsConnected()) {
+    Status status = Disconnect();
+    if (!status.ok()) LOG(WARNING) << status.ToString();
+  }
+
+  if (ctx_ != nullptr) {
+    SSL_CTX_free(ctx_);
+    ctx_ = nullptr;
+  }
+
+  if (ssl_ != nullptr) {
+    SSL_free(ssl_);
+    ssl_ = nullptr;
+  }
+}
+
+Status SslWrapper::InitSslContext() {
+  // One-time OpenSSL initialization for this wrapper's context.
+  OpenSSL_add_all_algorithms();
+  SSL_load_error_strings();
+
+  ctx_ = SSL_CTX_new(SSLv23_method());
+  if (ctx_ == NULL) return errors::Internal("Couldn't create SSL context");
+
+  // Provide the certificate password (if any) via callback so encrypted
+  // private keys can be loaded without interactive prompting.
+  SSL_CTX_set_default_passwd_cb(ctx_, PasswordCb);
+  SSL_CTX_set_default_passwd_cb_userdata(ctx_, (void *)cert_password_.c_str());
+
+  if (SSL_CTX_use_certificate_chain_file(ctx_, certfile_.c_str()) != 1)
+    // Fixed typo in the error message ("cetificate" -> "certificate").
+    return errors::Internal("Couldn't load certificate chain (file '",
+                            certfile_, "')");
+
+  // The private key may live in the certificate file when no separate key
+  // file was configured.
+  string private_key_file = keyfile_.empty() ? certfile_ : keyfile_;
+  if (SSL_CTX_use_PrivateKey_file(ctx_, private_key_file.c_str(),
+                                  SSL_FILETYPE_PEM) != 1)
+    return errors::Internal("Couldn't load private key (file '",
+                            private_key_file, "')");
+
+  return Status::OK();
+}
+
+Status SslWrapper::Connect() {
+  // The SSL context is created lazily on the first connect and reused for
+  // later reconnects.
+  if (ctx_ == NULL) {
+    TF_RETURN_IF_ERROR(InitSslContext());
+  }
+
+  ssl_ = SSL_new(ctx_);
+  if (ssl_ == NULL)
+    return errors::Internal("Failed to establish SSL connection");
+
+  // Establish the underlying TCP connection first, then run the TLS
+  // handshake over its socket descriptor.
+  TF_RETURN_IF_ERROR(client_->Connect());
+
+  SSL_set_fd(ssl_, client_->GetSocketDescriptor());
+  if (SSL_connect(ssl_) != 1)
+    return errors::Internal("Failed to establish SSL connection");
+
+  LOG(INFO) << "SSL connection established";
+
+  return Status::OK();
+}
+
+Status SslWrapper::Disconnect() {
+  // Release the per-connection SSL handle; the shared context ctx_ is kept
+  // for possible reconnects and freed in the destructor.
+  SSL_free(ssl_);
+  ssl_ = nullptr;
+
+  LOG(INFO) << "SSL connection closed";
+
+  return client_->Disconnect();
+}
+
+// Connection state and the raw descriptor are delegated to the wrapped
+// transport client.
+bool SslWrapper::IsConnected() { return client_->IsConnected(); }
+
+int SslWrapper::GetSocketDescriptor() { return client_->GetSocketDescriptor(); }
+
+Status SslWrapper::ReadData(uint8_t *buf, const int32_t length) {
+  // SSL_read may return fewer bytes than requested, so loop until the whole
+  // buffer is filled.
+  int received = 0;
+
+  while (received < length) {
+    int res = SSL_read(ssl_, buf, length - received);
+
+    if (res < 0)
+      return errors::Internal("Error occurred while reading from SSL socket: ",
+                              res);
+
+    // SSL_read returning 0 indicates the peer shut the connection down.
+    if (res == 0) return errors::Internal("Server closed SSL connection");
+
+    received += res;
+    buf += res;
+  }
+
+  return Status::OK();
+}
+
+Status SslWrapper::WriteData(const uint8_t *buf, const int32_t length) {
+  // SSL_write may accept fewer bytes than requested, so loop until
+  // everything has been written.
+  int sent = 0;
+
+  while (sent < length) {
+    int res = SSL_write(ssl_, buf, length - sent);
+
+    if (res < 0)
+      // "SSL socket" (was plain "socket") for consistency with ReadData.
+      return errors::Internal("Error occurred while writing into SSL socket: ",
+                              res);
+
+    sent += res;
+    buf += res;
+  }
+
+  return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
new file mode 100644
index 0000000000..0406644bba
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+#include <openssl/ssl.h>
+
+namespace tensorflow {
+
+// Client decorator that layers an OpenSSL TLS session on top of another
+// Client's socket.  Connection management is delegated to the wrapped
+// client; all reads and writes go through the SSL handle.
+class SslWrapper : public Client {
+ public:
+  SslWrapper(std::shared_ptr<Client> client, string certfile, string keyfile,
+             string cert_password, bool big_endian);
+  ~SslWrapper();
+
+  Status Connect() override;
+  Status Disconnect() override;
+  bool IsConnected() override;
+  int GetSocketDescriptor() override;
+  Status ReadData(uint8_t* buf, const int32_t length) override;
+  Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+  // Creates ctx_ and loads the certificate chain and private key into it.
+  Status InitSslContext();
+
+  std::shared_ptr<Client> client_;  // Underlying transport (owns the socket).
+  string certfile_;                 // PEM certificate chain file.
+  string keyfile_;                  // Private key file ("" = use certfile_).
+  string cert_password_;            // Password for an encrypted private key.
+  SSL_CTX* ctx_;                    // Lazily created by InitSslContext().
+  SSL* ssl_;                        // Per-connection SSL handle.
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
diff --git a/tensorflow/contrib/ignite/ops/dataset_ops.cc b/tensorflow/contrib/ignite/ops/dataset_ops.cc
new file mode 100644
index 0000000000..3d6fbe00e6
--- /dev/null
+++ b/tensorflow/contrib/ignite/ops/dataset_ops.cc
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+// Registers the IgniteDataset op: scalar inputs describing the connection,
+// query and cache schema, producing a variant dataset handle.
+REGISTER_OP("IgniteDataset")
+    .Input("cache_name: string")
+    .Input("host: string")
+    .Input("port: int32")
+    .Input("local: bool")
+    .Input("part: int32")
+    .Input("page_size: int32")
+    .Input("schema: int32")
+    .Input("permutation: int32")
+    .Output("handle: variant")
+    .SetIsStateful()
+    .SetShapeFn(shape_inference::ScalarShape)
+    .Doc(R"doc(
+IgniteDataset that allows to get data from Apache Ignite.
+
+Apache Ignite is a memory-centric distributed database, caching, and processing
+platform for transactional, analytical, and streaming workloads, delivering
+in-memory speeds at petabyte scale. This contrib package contains an
+integration between Apache Ignite and TensorFlow. The integration is based on
+tf.data from TensorFlow side and Binary Client Protocol from Apache Ignite side.
+It allows to use Apache Ignite as a datasource for neural network training,
+inference and all other computations supported by TensorFlow. Ignite Dataset
+is based on Apache Ignite Binary Client Protocol.
+
+cache_name: Ignite Cache Name.
+host: Ignite Thin Client Host.
+port: Ignite Thin Client Port.
+local: Local flag that defines that data should be fetched from local host only.
+part: Partition data should be fetched from.
+page_size: Page size for Ignite Thin Client.
+schema: Internal structure that defines schema of cache objects.
+permutation: Internal structure that defines permutation of cache objects.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
new file mode 100644
index 0000000000..288d485320
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
@@ -0,0 +1,772 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignite Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import socket
+import ssl
+import struct
+
+from tensorflow.contrib.ignite.python.ops import gen_dataset_ops
+from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class Readable(object):
+  """Abstract base class exposing methods to do reading-related operations.
+
+  Subclasses implement `read_data`; the typed readers below decode
+  little-endian values on top of it.
+  """
+
+  @abc.abstractmethod
+  def __init__(self):
+    pass
+
+  def read_byte(self):
+    """Reads and returns byte."""
+    return self._read("b", 1)
+
+  def read_short(self):
+    """Reads and returns short (2 bytes, little-endian)."""
+    return self._read("h", 2)
+
+  def read_int(self):
+    """Reads and returns int (4 bytes, little-endian)."""
+    return self._read("i", 4)
+
+  def read_long(self):
+    """Reads and returns long (8 bytes, little-endian)."""
+    return self._read("q", 8)
+
+  def skip(self, length):
+    """Skips the specified number of bytes."""
+    self.read_data(length)
+
+  @abc.abstractmethod
+  def read_data(self, length):
+    """Reads the specified number of bytes and returns them as a buffer."""
+    return None
+
+  def _read(self, data_type, length):
+    """Reads, unpacks and returns specified type (little-endian)."""
+    data_buffer = self.read_data(length)
+    return struct.unpack("<" + data_type, data_buffer)[0]
+
+
+class DataBuffer(Readable):
+  """DataBuffer class that exposes methods to read data from a byte buffer."""
+
+  def __init__(self, data_buffer):
+    """Constructs a new instance based on the specified byte buffer.
+
+    Args:
+      data_buffer: Buffer to be read.
+    """
+    Readable.__init__(self)
+    self.buffer = data_buffer
+    # Read cursor: offset of the next unread byte in self.buffer.
+    self.ptr = 0
+
+  def read_data(self, length):
+    """Reads the specified number of bytes and returns them as a buffer."""
+    data_buffer = self.buffer[self.ptr:][:length]
+    self.ptr += length
+    return data_buffer
+
+
+class TcpClient(Readable):
+ """TcpClient class that exposes methods to read data from a socket."""
+
+ def __init__(self, host, port, certfile=None, keyfile=None, password=None):
+ """Constructs a new instance based on the specified host and port.
+
+ Args:
+ host: Host to be connected.
+ port: Port to be connected.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ password: Password to be used if the private key is encrypted and a
+ password is necessary.
+
+ Raises:
+ ValueError: If the wrong combination of arguments is provided.
+ """
+ Readable.__init__(self)
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+ if certfile is not None:
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(certfile, keyfile, password)
+ self.sock = context.wrap_socket(self.sock)
+ else:
+ if keyfile is not None:
+ raise ValueError("SSL is disabled, keyfile must not be specified "
+ "(to enable SSL specify certfile)")
+ if password is not None:
+ raise ValueError("SSL is disabled, password must not be specified "
+ "(to enable SSL specify certfile)")
+
+ self.host = host
+ self.port = port
+
+ def __enter__(self):
+ """Connects to host and port specified in the constructor."""
+ self.sock.connect((self.host, self.port))
+ return self
+
+ def __exit__(self, t, v, traceback):
+ """Disconnects the socket."""
+ self.sock.close()
+
+ def write_byte(self, v):
+ """Writes the specified byte."""
+ self._write(v, "b")
+
+ def write_short(self, v):
+ """Writes the specified short (2 bytes, little-endian)."""
+ self._write(v, "h")
+
+ def write_int(self, v):
+ """Writes the specified short (4 bytes, little-endian)."""
+ self._write(v, "i")
+
+ def write_long(self, v):
+ """Writes the specified int (8 bytes, little-endian)."""
+ self._write(v, "q")
+
+ def write_string(self, v):
+ """Writes the specified string."""
+ self.sock.sendall(v.encode("UTF-8"))
+
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ data_buffer = None
+ rem = length
+ while rem > 0:
+ buf = self.sock.recv(rem)
+ rem = rem - len(buf)
+ if data_buffer is None:
+ data_buffer = buf
+ else:
+ data_buffer += buf
+ return data_buffer
+
+ def _write(self, value, data_type):
+ """Packs and writes data using the specified type (little-endian)."""
+ data_buffer = struct.pack("<" + data_type, value)
+ self.sock.sendall(data_buffer)
+
+
+class BinaryType(object):
+  """BinaryType class that encapsulated type id, type name and fields."""
+
+  def __init__(self, type_id, type_name, fields):
+    """Constructs a new instance of BinaryType."""
+    self.type_id = type_id
+    self.type_name = type_name
+    # List of BinaryField objects declared on the type.
+    self.fields = fields
+
+
+class BinaryField(object):
+  """BinaryField class that encapsulated field name, type id and field id."""
+
+  def __init__(self, field_name, type_id, field_id):
+    """Constructs a new instance of BinaryField."""
+    self.field_name = field_name
+    self.type_id = type_id
+    self.field_id = field_id
+
+
+# Binary types defined in Apache Ignite Thin client and supported by
+# TensorFlow on Apache Ignite, see
+# https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+# Maps Ignite binary type id -> (TensorFlow dtype, is_array).
+# True means that type is a vector, False means type is scalar.
+types = {
+    1: (dtypes.uint8, False),
+    2: (dtypes.int16, False),
+    3: (dtypes.int32, False),
+    4: (dtypes.int64, False),
+    5: (dtypes.float32, False),
+    6: (dtypes.float64, False),
+    7: (dtypes.uint16, False),
+    8: (dtypes.bool, False),
+    9: (dtypes.string, False),
+    12: (dtypes.uint8, True),
+    13: (dtypes.int16, True),
+    14: (dtypes.int32, True),
+    15: (dtypes.int64, True),
+    16: (dtypes.float32, True),
+    17: (dtypes.float64, True),
+    18: (dtypes.uint16, True),
+    19: (dtypes.bool, True),
+    20: (dtypes.string, True)
+}
+
+
+class TypeTreeNode(object):
+  """Node of the object type tree; formats object tree structure data.
+
+  Leaf nodes (fields is None) describe scalar/array values; inner nodes
+  describe complex objects with child fields.
+  """
+
+  def __init__(self, name, type_id, fields=None, permutation=None):
+    """Constructs a new instance of TypeTreeNode.
+
+    Args:
+      name: Name of the object tree node.
+      type_id: Type id of the object tree node.
+      fields: List of fields (children of the object tree node).
+      permutation: Permutation that should be applied to order object children.
+    """
+    self.name = name
+    self.type_id = type_id
+    self.fields = fields
+    self.permutation = permutation
+
+  def to_output_classes(self):
+    """Formats the tree object as required by `Dataset.output_classes`."""
+    if self.fields is None:
+      return ops.Tensor
+    output_classes = {}
+    for field in self.fields:
+      output_classes[field.name] = field.to_output_classes()
+    return output_classes
+
+  def to_output_shapes(self):
+    """Formats the tree object as required by `Dataset.output_shapes`."""
+    if self.fields is None:
+      if self.type_id in types:
+        object_type = types[self.type_id]
+        is_array = object_type[1]
+        if is_array:
+          # Array length is not known statically.
+          return tensor_shape.TensorShape([None])
+        return tensor_shape.TensorShape([])
+      raise ValueError("Unsupported type [type_id=%d]" % self.type_id)
+    output_shapes = {}
+    for field in self.fields:
+      output_shapes[field.name] = field.to_output_shapes()
+    return output_shapes
+
+  def to_output_types(self):
+    """Formats the tree object as required by `Dataset.output_types`."""
+    if self.fields is None:
+      if self.type_id in types:
+        object_type = types[self.type_id]
+        return object_type[0]
+      raise ValueError("Unsupported type [type_id=%d]" % self.type_id)
+    else:
+      output_types = {}
+      for field in self.fields:
+        output_types[field.name] = field.to_output_types()
+      return output_types
+
+  def to_flat(self):
+    """Returns a list of node types."""
+    return self.to_flat_rec([])
+
+  def to_permutation(self):
+    """Returns a permutation that should be applied to order object leaves."""
+    correct_order_dict = {}
+    self.traversal_rec(correct_order_dict, 0)
+    object_order = []
+    self.traversal_permutation_rec(object_order)
+    return [correct_order_dict[o] for o in object_order]
+
+  def to_flat_rec(self, flat):
+    """Formats a list of leaf node types in pre-order."""
+    if self.fields is None:
+      flat.append(self.type_id)
+    else:
+      for field in self.fields:
+        field.to_flat_rec(flat)
+    return flat
+
+  def traversal_permutation_rec(self, permutation):
+    """Collects nodes in accordance with permutation."""
+    if self.fields is None:
+      permutation.append(self)
+    else:
+      for idx in self.permutation:
+        field = self.fields[idx]
+        field.traversal_permutation_rec(permutation)
+
+  def traversal_rec(self, d, i):
+    """Collects nodes in pre-order traversal."""
+    if self.fields is None:
+      d[self] = i
+      i += 1
+    else:
+      for field in self.fields:
+        i = field.traversal_rec(d, i)
+    return i
+
+
+class IgniteClient(TcpClient):
+  """IgniteClient enables working with Apache Ignite using a thin client.
+
+  This client works with assumption that all object in the cache
+  have the same structure (homogeneous objects) and the cache contains at
+  least one object.
+  """
+
+  def __init__(self,
+               host,
+               port,
+               username=None,
+               password=None,
+               certfile=None,
+               keyfile=None,
+               cert_password=None):
+    """Constructs a new instance of IgniteClient.
+
+    Args:
+      host: Apache Ignite Thin client host to be connected.
+      port: Apache Ignite Thin client port to be connected.
+      username: Apache Ignite Thin Client authentication username.
+      password: Apache Ignite Thin Client authentication password.
+      certfile: File in PEM format containing the certificate as well as any
+        number of CA certificates needed to establish the certificate's
+        authenticity.
+      keyfile: File containing the private key (otherwise the private key will
+        be taken from certfile as well).
+      cert_password: Password to be used if the private key is encrypted and a
+        password is necessary.
+    """
+    TcpClient.__init__(self, host, port, certfile, keyfile, cert_password)
+    # Credentials are sent during handshake(); None means no authentication.
+    self.username = username
+    self.password = password
+
+  def handshake(self):
+    """Makes a handshake after connect and before any other calls.
+
+    Sends the thin-client handshake for protocol version 1.1.0, with
+    optional authentication credentials, and raises RuntimeError if the
+    server rejects it.
+    """
+    # Base message is 8 bytes; a credential adds either a 1-byte NULL marker
+    # or a 5-byte string header plus the string itself.
+    msg_len = 8
+
+    if self.username is None:
+      msg_len += 1
+    else:
+      msg_len += 5 + len(self.username)
+
+    if self.password is None:
+      msg_len += 1
+    else:
+      msg_len += 5 + len(self.password)
+
+    self.write_int(msg_len)  # Message length
+    self.write_byte(1)  # Handshake operation
+    self.write_short(1)  # Version (1.1.0)
+    self.write_short(1)
+    self.write_short(0)
+    self.write_byte(2)  # Thin client
+
+    if self.username is None:  # Username
+      self.write_byte(101)
+    else:
+      self.write_byte(9)
+      self.write_int(len(self.username))
+      self.write_string(self.username)
+
+    if self.password is None:  # Password
+      self.write_byte(101)
+    else:
+      self.write_byte(9)
+      self.write_int(len(self.password))
+      self.write_string(self.password)
+
+    self.read_int()  # Result length
+    res = self.read_byte()
+
+    # A result byte of 1 means success; anything else is followed by the
+    # server version and an optional error message.
+    if res != 1:
+      serv_ver_major = self.read_short()
+      serv_ver_minor = self.read_short()
+      serv_ver_patch = self.read_short()
+      err_msg = self._parse_string()
+      if err_msg is None:
+        raise RuntimeError(
+            "Handshake Error [result=%d, version=%d.%d.%d]" %
+            (res, serv_ver_major, serv_ver_minor, serv_ver_patch))
+      else:
+        raise RuntimeError(
+            "Handshake Error [result=%d, version=%d.%d.%d, message='%s']" %
+            (res, serv_ver_major, serv_ver_minor, serv_ver_patch, err_msg))
+
+  def get_cache_type(self, cache_name):
+    """Collects type information about objects stored in the specified cache.
+
+    Issues a Scan Query for a single object and derives the object tree
+    structure from its binary representation.
+    """
+    cache_name_hash = self._java_hash_code(cache_name)
+    self.write_int(25)  # Message length
+    self.write_short(2000)  # Operation code
+    self.write_long(0)  # Request ID
+    self.write_int(cache_name_hash)  # Cache name
+    self.write_byte(0)  # Flags
+    self.write_byte(101)  # Filter (NULL)
+    self.write_int(1)  # Cursor page size
+    self.write_int(-1)  # Partition to query
+    self.write_byte(0)  # Local flag
+
+    result_length = self.read_int()
+    self.read_long()  # Request id
+    status = self.read_int()
+
+    if status != 0:
+      err_msg = self._parse_string()
+      if err_msg is None:
+        raise RuntimeError("Scan Query Error [status=%s]" % status)
+      else:
+        raise RuntimeError(
+            "Scan Query Error [status=%s, message='%s']" % (status, err_msg))
+
+    self.read_long()  # Cursor id
+    row_count = self.read_int()
+
+    if row_count == 0:
+      raise RuntimeError("Scan Query returned empty result, so it's "
+                         "impossible to derive the cache type")
+
+    # 24 header bytes were consumed above and the final byte is the
+    # next-page flag read below, leaving result_length - 25 payload bytes.
+    payload = DataBuffer(self.read_data(result_length - 25))
+
+    self.read_byte()  # Next page
+
+    res = TypeTreeNode("root", 0, [
+        self._collect_types("key", payload),
+        self._collect_types("val", payload)
+    ], [0, 1])
+
+    return res
+
+  def _java_hash_code(self, s):
+    """Computes hash code of the specified string using Java code.
+
+    Reproduces java.lang.String.hashCode (h = 31*h + c with signed 32-bit
+    overflow); get_cache_type identifies caches by this hash.
+    """
+    h = 0
+    for c in s:
+      h = (31 * h + ord(c)) & 0xFFFFFFFF
+    return ((h + 0x80000000) & 0xFFFFFFFF) - 0x80000000
+
+ def _collect_types(self, field_name, data):
+ """Extracts type information from the specified object."""
+ type_id = data.read_byte()
+
+ # Byte scalar.
+ if type_id == 1:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short scalar.
+ if type_id == 2:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer scalar.
+ if type_id == 3:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long scalar.
+ if type_id == 4:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float scalar.
+ if type_id == 5:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double scalar.
+ if type_id == 6:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char scalar.
+ if type_id == 7:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool scalar.
+ if type_id == 8:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # String scalar.
+ if type_id == 9:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID scalar.
+ if type_id == 10:
+ data.skip(16)
+ return TypeTreeNode(field_name, type_id)
+
+ # Date scalar.
+ if type_id == 11:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Byte array.
+ if type_id == 12:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short array.
+ if type_id == 13:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer array.
+ if type_id == 14:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long array.
+ if type_id == 15:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float array.
+ if type_id == 16:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double array.
+ if type_id == 17:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char array.
+ if type_id == 18:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool array.
+ if type_id == 19:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # String array.
+ if type_id == 20:
+ length = data.read_int()
+ for _ in range(length):
+ header = data.read_byte()
+ if header == 9:
+ str_length = data.read_int()
+ data.skip(str_length)
+ elif header == 101:
+ pass
+ else:
+ raise RuntimeError(
+ "Unknown binary type when expected string [type_id=%d]" % header)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID array.
+ if type_id == 21:
+ length = data.read_int()
+ data.skip(length * 16) # TODO(dmitrievanthony): support NULL values.
+ return TypeTreeNode(field_name, type_id)
+
+ # Date array.
+ if type_id == 22:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Wrapped Binary Object.
+ if type_id == 27:
+ length = data.read_int()
+ inner_data = data.read_data(length)
+ data.read_int() # Offset
+ return self._collect_types(field_name, DataBuffer(inner_data))
+
+ # Complex Object.
+ if type_id == 103:
+ data.read_byte() # Object version
+ data.read_short() # Object flags
+ obj_type_id = data.read_int()
+ data.read_int() # Object hash code
+ obj_length = data.read_int()
+ data.read_int() # Object schema id
+ obj_schema_offset = data.read_int()
+
+ obj_type = self._get_type(obj_type_id)
+ children = []
+
+ for obj_field in obj_type.fields:
+ child = self._collect_types(obj_field.field_name, data)
+ children.append(child)
+
+ children_sorted = sorted(children, key=lambda child: child.name)
+ permutation = [children_sorted.index(child) for child in children]
+ children = children_sorted
+
+ data.skip(obj_length - obj_schema_offset)
+
+ return TypeTreeNode(field_name, type_id, children, permutation)
+
+ raise RuntimeError("Unknown binary type [type_id=%d]" % type_id)
+
+ def _get_type(self, type_id):
+ """Queries Apache Ignite information about type by type id."""
+ self.write_int(14) # Message length
+ self.write_short(3002) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(type_id) # Type ID
+
+ self.read_int() # Result length
+ self.read_long() # Request id
+ status = self.read_int()
+
+ if status != 0:
+ err_msg = self._parse_string()
+ if err_msg is None:
+ raise RuntimeError("Get Binary Type Error [status=%d, message='%s']" %
+ (status, err_msg))
+ else:
+ raise RuntimeError("Get Binary Type Error [status=%d]" % status)
+
+ binary_type_exists = self.read_byte()
+
+ if binary_type_exists == 0:
+ raise RuntimeError("Binary type not found [type_id=%d] " % type_id)
+
+ binary_type_id = self.read_int()
+ binary_type_name = self._parse_string()
+ self._parse_string() # Affinity field name
+
+ fields = []
+ for _ in range(self.read_int()):
+ field_name = self._parse_string()
+ field_type_id = self.read_int()
+ field_id = self.read_int()
+
+ field = BinaryField(field_name, field_type_id, field_id)
+ fields.append(field)
+
+ is_enum = self.read_byte()
+ if is_enum == 1:
+ raise RuntimeError("Enum fields are not supported yet")
+
+ schema_cnt = self.read_int()
+ for _ in range(schema_cnt):
+ self.read_int() # Schema id
+ field_cnt = self.read_int()
+ self.skip(field_cnt * 4)
+
+ return BinaryType(binary_type_id, binary_type_name, fields)
+
+  def _parse_string(self):
+    """Parses string.
+
+    Returns the decoded string for a string header (9), or None for a NULL
+    marker (101).
+    """
+    header = self.read_byte()
+    if header == 9:
+      length = self.read_int()
+      return self.read_data(length).decode("utf-8")
+    if header == 101:
+      return None
+    raise RuntimeError(
+        "Unknown binary type when expected string [type_id=%d]" % header)
+
+
+class IgniteDataset(dataset_ops.DatasetSource):
+  """A tf.data Dataset that reads objects from an Apache Ignite cache.
+
+  Apache Ignite is a memory-centric distributed database, caching, and
+  processing platform for transactional, analytical, and streaming workloads,
+  delivering in-memory speeds at petabyte scale. This contrib package
+  contains an integration between Apache Ignite and TensorFlow. The
+  integration is based on tf.data from TensorFlow side and Binary Client
+  Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+  datasource for neural network training, inference and all other
+  computations supported by TensorFlow. Ignite Dataset is based on Apache
+  Ignite Binary Client Protocol.
+  """
+
+  def __init__(self,
+               cache_name,
+               host="localhost",
+               port=10800,
+               local=False,
+               part=-1,
+               page_size=100,
+               username=None,
+               password=None,
+               certfile=None,
+               keyfile=None,
+               cert_password=None):
+    """Create a IgniteDataset.
+
+    Args:
+      cache_name: Cache name to be used as datasource.
+      host: Apache Ignite Thin Client host to be connected.
+      port: Apache Ignite Thin Client port to be connected.
+      local: Local flag that defines to query only local data.
+      part: Number of partitions to be queried.
+      page_size: Apache Ignite Thin Client page size.
+      username: Apache Ignite Thin Client authentication username.
+      password: Apache Ignite Thin Client authentication password.
+      certfile: File in PEM format containing the certificate as well as any
+        number of CA certificates needed to establish the certificate's
+        authenticity.
+      keyfile: File containing the private key (otherwise the private key will
+        be taken from certfile as well).
+      cert_password: Password to be used if the private key is encrypted and a
+        password is necessary.
+    """
+    super(IgniteDataset, self).__init__()
+
+    # Probe the cache at graph-construction time to learn the object schema;
+    # the dataset op itself only receives the derived flat schema below.
+    with IgniteClient(host, port, username, password, certfile, keyfile,
+                      cert_password) as client:
+      client.handshake()
+      self.cache_type = client.get_cache_type(cache_name)
+
+    self.cache_name = ops.convert_to_tensor(
+        cache_name, dtype=dtypes.string, name="cache_name")
+    self.host = ops.convert_to_tensor(host, dtype=dtypes.string, name="host")
+    self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port")
+    self.local = ops.convert_to_tensor(local, dtype=dtypes.bool, name="local")
+    self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part")
+    self.page_size = ops.convert_to_tensor(
+        page_size, dtype=dtypes.int32, name="page_size")
+    self.schema = ops.convert_to_tensor(
+        self.cache_type.to_flat(), dtype=dtypes.int32, name="schema")
+    self.permutation = ops.convert_to_tensor(
+        self.cache_type.to_permutation(),
+        dtype=dtypes.int32,
+        name="permutation")
+
+  def _as_variant_tensor(self):
+    return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port,
+                                          self.local, self.part, self.page_size,
+                                          self.schema, self.permutation)
+
+  @property
+  def output_classes(self):
+    return self.cache_type.to_output_classes()
+
+  @property
+  def output_shapes(self):
+    return self.cache_type.to_output_shapes()
+
+  @property
+  def output_types(self):
+    return self.cache_type.to_output_types()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
index d5c03495e3..c9af7386cf 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
+++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
@@ -12,28 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the OptimizeDataset serialization."""
+"""Python helper for loading Ignite ops and kernels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
-
-class OptimizeDatasetSerializationTest(
- dataset_serialization_test_base.DatasetSerializationTestBase):
-
- def testCore(self):
-
- def build_dataset(num_elements, batch_size):
- return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
- batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
-
- self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
-
-
-if __name__ == "__main__":
- test.main()
+_dataset_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
new file mode 100755
index 0000000000..f4607ce8ad
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-plain.xml &
+sleep 5 # Wait for Apache Ignite to start
+
+./apache-ignite-fabric/bin/sqlline.sh \
+-u "jdbc:ignite:thin://127.0.0.1/" \
+--run=/data/sql/init.sql
+
+tail -f nohup.out
diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
new file mode 100644
index 0000000000..d900174a8a
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<beans xmlns="http://www.springframework.org/schema/beans"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:util="http://www.springframework.org/schema/util"
+ xsi:schemaLocation="http://www.springframework.org/schema/beans
+ http://www.springframework.org/schema/beans/spring-beans.xsd
+ http://www.springframework.org/schema/util
+ http://www.springframework.org/schema/util/spring-util.xsd">
+
+ <bean class="org.apache.ignite.configuration.IgniteConfiguration">
+ <property name="discoverySpi">
+ <bean class="org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi">
+ <property name="ipFinder">
+ <bean class="org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder">
+ <property name="addresses">
+ <list>
+ <value>127.0.0.1</value>
+ </list>
+ </property>
+ </bean>
+ </property>
+ </bean>
+ </property>
+ </bean>
+
+</beans>
diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
new file mode 100644
index 0000000000..ef29b5f14a
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
@@ -0,0 +1,82 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for IgniteDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.ignite import IgniteDataset
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class IgniteDatasetTest(test.TestCase):
+  """The Apache Ignite servers have to be set up and torn down manually.
+
+  The test requires the servers to be running and the Docker engine installed.
+
+ To setup Apache Ignite servers:
+ $ bash start_ignite.sh
+
+ To tear down Apache Ignite servers:
+ $ bash stop_ignite.sh
+ """
+
+ def test_ignite_dataset_with_plain_client(self):
+    """Test Ignite Dataset with a plain client.
+
+ """
+ self._clear_env()
+ ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300)
+ self._check_dataset(ds)
+
+ def _clear_env(self):
+ """Clears environment variables used by Ignite Dataset.
+
+ """
+ if "IGNITE_DATASET_USERNAME" in os.environ:
+ del os.environ["IGNITE_DATASET_USERNAME"]
+ if "IGNITE_DATASET_PASSWORD" in os.environ:
+ del os.environ["IGNITE_DATASET_PASSWORD"]
+ if "IGNITE_DATASET_CERTFILE" in os.environ:
+ del os.environ["IGNITE_DATASET_CERTFILE"]
+ if "IGNITE_DATASET_CERT_PASSWORD" in os.environ:
+ del os.environ["IGNITE_DATASET_CERT_PASSWORD"]
+
+ def _check_dataset(self, dataset):
+ """Checks that dataset provides correct data."""
+ self.assertEqual(dtypes.int64, dataset.output_types["key"])
+ self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"])
+ self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"])
+
+ it = dataset.make_one_shot_iterator()
+ ne = it.get_next()
+
+ with session.Session() as sess:
+ rows = [sess.run(ne), sess.run(ne), sess.run(ne)]
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(ne)
+
+ self.assertEqual({"key": 1, "val": {"NAME": b"TEST1", "VAL": 42}}, rows[0])
+ self.assertEqual({"key": 2, "val": {"NAME": b"TEST2", "VAL": 43}}, rows[1])
+ self.assertEqual({"key": 3, "val": {"NAME": b"TEST3", "VAL": 44}}, rows[2])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/ignite/python/tests/sql/init.sql b/tensorflow/contrib/ignite/python/tests/sql/init.sql
new file mode 100644
index 0000000000..5a192aef17
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/sql/init.sql
@@ -0,0 +1,20 @@
+-- Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ==============================================================================
+
+CREATE TABLE TEST_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR, VAL LONG);
+
+INSERT INTO TEST_CACHE VALUES (1, 'TEST1', 42);
+INSERT INTO TEST_CACHE VALUES (2, 'TEST2', 43);
+INSERT INTO TEST_CACHE VALUES (3, 'TEST3', 44);
diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
new file mode 100755
index 0000000000..a67bd44f2f
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+IGNITE_VERSION=2.6.0
+SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )"
+
+# Start Apache Ignite with plain client listener.
+docker run -itd --name ignite-plain -p 42300:10800 \
+-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh
diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
new file mode 100755
index 0000000000..8f03dbd1ed
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+docker rm -f ignite-plain
+docker rm -f ignite-ssl
+docker rm -f ignite-ssl-auth
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 370a8caf6a..788bf04b28 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -156,6 +156,7 @@ namespace functor {
TF_CALL_uint8(DECLARE_FUNCTOR);
TF_CALL_int32(DECLARE_FUNCTOR);
TF_CALL_int64(DECLARE_FUNCTOR);
+TF_CALL_half(DECLARE_FUNCTOR);
TF_CALL_float(DECLARE_FUNCTOR);
TF_CALL_double(DECLARE_FUNCTOR);
@@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR);
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 6b63eed130..7fac774d07 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -71,14 +71,7 @@ class ProjectiveGenerator {
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection;
- // TODO(ringwalt): Add a fill value input.
-#if (defined __CUDA_ARCH__) && (CUDART_VERSION < 8000)
- // On CUDA versions previous to 8.0, only __shared__ variables
- // could be declared as static in the device code.
const T fill_value = T(0);
-#else
- static const T fill_value = T(0);
-#endif
switch (interpolation_) {
case INTERPOLATION_NEAREST:
// Switch the order of x and y again for indexing into the image.
diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
index 8743a5ff72..36b9a236a6 100644
--- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
+++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
@@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
template class FillProjectiveTransform<GPUDevice, uint8>;
template class FillProjectiveTransform<GPUDevice, int32>;
template class FillProjectiveTransform<GPUDevice, int64>;
+template class FillProjectiveTransform<GPUDevice, Eigen::half>;
template class FillProjectiveTransform<GPUDevice, float>;
template class FillProjectiveTransform<GPUDevice, double>;
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 376c0751ee..4997c31a7f 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -272,6 +272,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
with self.cached_session():
self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+ def test_transform_data_types(self):
+ for dtype in _DTYPES:
+ image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype)
+ value = image_ops.transform(image, [1] * 8)
+ with self.test_session(use_gpu=True):
+ self.assertAllEqual(
+ value.eval(),
+ np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()))
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f3ebe3b245..787a85644c 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -4,6 +4,7 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
exports_files(glob([
@@ -165,10 +166,6 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
- defines = select({
- ":with_tflite_flex": ["TFLITE_FLEX"],
- "//conditions:default": [],
- }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -276,6 +273,7 @@ cc_test(
"testdata/0_subgraphs.bin",
"testdata/2_subgraphs.bin",
"testdata/empty_model.bin",
+ "testdata/multi_add_flex.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
],
@@ -283,6 +281,26 @@ cc_test(
":framework",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test model framework with the flex library linked into the target.
+tf_cc_test(
+ name = "model_flex_test",
+ size = "small",
+ srcs = ["model_flex_test.cc"],
+ data = [
+ "testdata/multi_add_flex.bin",
+ ],
+ tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC.
+ deps = [
+ ":framework",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index be9d551ee4..44daf7adaa 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -99,6 +99,12 @@ typedef struct {
TfLiteFusedActivation activation;
} TfLiteSequenceRNNParams;
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+ bool merge_outputs;
+} TfLiteBidirectionalSequenceRNNParams;
+
typedef enum {
kTfLiteFullyConnectedWeightsFormatDefault = 0,
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
@@ -181,6 +187,16 @@ typedef struct {
} TfLiteLSTMParams;
typedef struct {
+ // Parameters for the LSTM kernel.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // If true, store the outputs of both directions in the first output.
+ bool merge_outputs;
+} TfLiteBidirectionalSequenceLSTMParams;
+
+typedef struct {
bool align_corners;
} TfLiteResizeBilinearParams;
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
index 4d0ba75e68..ba458b4252 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data_test.cc
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -73,6 +73,8 @@ TEST(IntArray, CanCompileStructs) {
TfLiteFakeQuantParams fake_quant_params;
TfLitePackParams pack_params;
TfLiteOneHotParams one_hot_params;
+ TfLiteBidirectionalSequenceRNNParams bidi_sequence_rnn_params;
+ TfLiteBidirectionalSequenceLSTMParams bidi_sequence_lstm_params;
}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index e6900e0950..eac7db9a88 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -224,10 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params =
- allocator->AllocatePOD<TfLiteSequenceRNNParams>();
+ auto params = allocator->AllocatePOD<TfLiteSequenceRNNParams>();
if (auto* sequence_rnn_params =
op->builtin_options_as_SequenceRNNOptions()) {
params->activation =
@@ -237,6 +235,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: {
+ auto params =
+ allocator->AllocatePOD<TfLiteBidirectionalSequenceRNNParams>();
+ if (auto* bidi_sequence_rnn_params =
+ op->builtin_options_as_BidirectionalSequenceRNNOptions()) {
+ params->activation = parse_activation(
+ bidi_sequence_rnn_params->fused_activation_function());
+ params->time_major = bidi_sequence_rnn_params->time_major();
+ params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_RNN: {
TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
@@ -360,10 +371,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>();
+ auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
params->activation =
parse_activation(lstm_params->fused_activation_function());
@@ -381,6 +391,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
+ auto params =
+ allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
+ if (auto* bidi_lstm_params =
+ op->builtin_options_as_BidirectionalSequenceLSTMOptions()) {
+ params->activation =
+ parse_activation(bidi_lstm_params->fused_activation_function());
+ params->cell_clip = bidi_lstm_params->cell_clip();
+ params->proj_clip = bidi_lstm_params->proj_clip();
+ params->merge_outputs = bidi_lstm_params->merge_outputs();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_RESIZE_BILINEAR: {
auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD
index bf5d91899c..9b89ed4f84 100644
--- a/tensorflow/contrib/lite/delegates/flex/BUILD
+++ b/tensorflow/contrib/lite/delegates/flex/BUILD
@@ -2,7 +2,7 @@
# This is a TF Lite delegate that is powered by TensorFlow's Eager.
#
package(default_visibility = [
- "//visibility:public",
+ "//visibility:private",
])
licenses(["notice"]) # Apache 2.0
@@ -20,7 +20,7 @@ cc_library(
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
+ "//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
@@ -50,6 +50,7 @@ cc_library(
hdrs = [
"delegate.h",
],
+ visibility = ["//visibility:public"],
deps = [
":buffer_map",
":delegate_data",
@@ -60,12 +61,13 @@ cc_library(
"//tensorflow/contrib/lite:util",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
+ "//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:lib",
],
}),
+ alwayslink = 1,
)
tf_cc_test(
@@ -178,7 +180,7 @@ cc_library(
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
+ "//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc
index ba065a8ff5..c72b0cf513 100644
--- a/tensorflow/contrib/lite/delegates/flex/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc
@@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
} // namespace delegate
} // namespace flex
+// Corresponding weak declaration found in lite/model.cc.
+std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
+AcquireFlexDelegate() {
+ return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
+ tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
+ delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
+ });
+}
+
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
std::unique_ptr<flex::DelegateData> delegate_data;
if (!flex::DelegateData::Create(&delegate_data).ok()) {
diff --git a/tensorflow/contrib/lite/experimental/micro/BUILD b/tensorflow/contrib/lite/experimental/micro/BUILD
new file mode 100644
index 0000000000..df1036bc8b
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/BUILD
@@ -0,0 +1,76 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl",
+ "tflite_micro_cc_test",
+)
+
+cc_library(
+ name = "micro_framework",
+ srcs = [
+ "micro_error_reporter.cc",
+ "micro_interpreter.cc",
+ "micro_mutable_op_resolver.cc",
+ "simple_tensor_allocator.cc",
+ ],
+ hdrs = [
+ "compatibility.h",
+ "micro_error_reporter.h",
+ "micro_interpreter.h",
+ "micro_mutable_op_resolver.h",
+ "simple_tensor_allocator.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "micro_error_reporter_test",
+ srcs = [
+ "micro_error_reporter_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "micro_mutable_op_resolver_test",
+ srcs = [
+ "micro_mutable_op_resolver_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "micro_interpreter_test",
+ srcs = [
+ "micro_interpreter_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "simple_tensor_allocator_test",
+ srcs = [
+ "simple_tensor_allocator_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/README.md b/tensorflow/contrib/lite/experimental/micro/README.md
new file mode 100644
index 0000000000..414cafde4d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/README.md
@@ -0,0 +1,114 @@
+# TensorFlow Lite for Microcontrollers
+
+This is an experimental port of TensorFlow Lite aimed at micro controllers and other devices with only kilobytes of memory. It doesn't require any operating system support, any standard C or C++ libraries, or dynamic memory allocation, so it's designed to be portable even to 'bare metal' systems. The core runtime fits in 16KB on a Cortex M3, and with enough operators to run a speech keyword detection model, takes up a total of 22KB.
+
+The design goals are for the framework to be:
+
+- **Readable**: We want embedded software engineers to be able to understand what's required to run ML inference without having to study research papers. We've tried to keep the code base small, modular, and have reference implementations of all operations to help with this.
+
+- **Easy to modify**: We know that there are a lot of different platforms and requirements in the embedded world, and we don't expect to cover all of them in one framework. Instead, we're hoping that it can be a good starting point for developers to build on top of to meet their own needs. For example, we tried to make it easy to replace the implementations of key computational operators that are often crucial for performance, without having to touch the data flow and other runtime code. We want it to make more sense to use our workflow to handle things like model import and less-important operations, and customize the parts that matter, rather than having to reimplement everything in your own engine.
+
+- **Well-tested**: If you're modifying code, you need to know if your changes are correct. Having an easy way to test lets you develop much faster. To help there, we've written tests for all the components, and we've made sure that the tests can be run on almost any platform, with no dependencies apart from the ability to log text to a debug console somewhere. We also provide an easy way to run all the tests on-device as part of an automated test framework, and we use qemu/Renode emulation so that tests can be run even without physical devices present.
+
+- **Easy to integrate**: We want to be as open a system as possible, and use the best code available for each platform. To do that, we're going to rely on projects like [CMSIS-NN](https://www.keil.com/pack/doc/CMSIS/NN/html/index.html), [uTensor](https://github.com/uTensor/uTensor), and other vendor libraries to handle as much performance-critical code as possible. We know that there are an increasing number of options to accelerate neural networks on microcontrollers, so we're aiming to be a good host for deploying those hardware technologies too.
+
+- **Compatible**: We're using the same file schema, interpreter API, and kernel interface as regular TensorFlow Lite, so we leverage the large existing set of tools, documentation, and examples for the project. The biggest barrier to deploying ML models is getting them from a training environment into a form that's easy to run inference on, so we see reusing this rich ecosystem as being crucial to being easily usable. We also hope to integrate this experimental work back into the main codebase in the future.
+
+To meet those goals, we've made some tradeoffs:
+
+- **Simple C++**: To help with readability, our code is written in a modern version of C++, but we generally treat it as a "better C", rather than relying on more complex features such as template meta-programming. As mentioned earlier, we avoid any use of dynamic memory allocation (new/delete) or the standard C/C++ libraries, so we believe this should still be fairly portable. It does mean that some older devices with C-only toolchains won't be supported, but we're hoping that the reference operator implementations (which are simple C-like functions) can still be useful in those cases. The interfaces are also designed to be C-only, so it should be possible to integrate the resulting library with pure C projects.
+
+- **Interpreted**: Code generation is a popular pattern for embedded code, because it gives standalone code that's easy to modify and step through, but we've chosen to go with an interpreted approach. In our internal microcontroller work we've found that using an extremely stripped-down interpreter with almost no dependencies gives us a lot of the same advantages, but is easier to maintain. For example, when new updates come out for the underlying library, you can just merge your local modifications in a single step, rather than having to regenerate new code and then patch in any changes you subsequently made. The coarse granularity of the interpreted primitives means that each operation call typically takes hundreds of thousands of instruction cycles at least, so we don't see noticeable performance gains from avoiding what's essentially a single switch statement at the interpreter level to call each operation. We're still working on improving the packaging though, for example we're considering having the ability to snapshot all the source files and headers used for a particular model, being able to compile the code and data together as a library, and then access it through a minimal set of C interface calls which hide the underlying complexity.
+
+- **Flatbuffers**: We represent our models using [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs), with the difference that we always keep it in read-only program memory (typically flash) rather than relying on having a file system to read it from. This is a good fit because flatbuffer's serialized format is designed to be mapped into memory without requiring any extra memory allocations or modifications to access it. All of the functions to read model values work directly on the serialized bytes, and large sections of data like weights are directly accessible as sequential C-style arrays of their data type, with no strides or unpacking needed. We do get a lot of value from using flatbuffers, but there is a cost in complexity. The flat buffer library code is all inline [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema_generated.h), but it isn't straightforward to inspect their implementations, and the model data structures aren't easy to comprehend from the debugger. The header for the schema itself also has to be periodically updated when new information is added to the file format, though we try to handle that transparently for most developers by checking in a pre-generated version.
+
+- **Code Duplication**: Some of the code in this prototype largely duplicates the logic in other parts of the TensorFlow Lite code base, for example the operator wrappers. We've tried to share as much as we can between the two interpreters, but there are some assumptions built into the original runtime that make this difficult. We'll be working on modularizing the main interpreter so that we can move to an entirely shared system.
+
+This initial preview release is designed to get early feedback, and is not intended to be a final product. It only includes enough operations to run a simple keyword recognition model, and the implementations are not optimized. We're hoping this will be a good way to get feedback and collaborate to improve the framework.
+
+## Getting Started
+
+Building requires a Linux or OS X machine.
+
+ - Open a terminal
+ - Download the TensorFlow source with `git clone https://github.com/tensorflow/tensorflow`
+ - Enter the source root directory by running `cd tensorflow`
+ - Download the dependencies by running `tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh`. This may take a few minutes
+ - Build and test the library with `make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test`
+
+You should see a series of compilation steps, followed by "~~~ALL TESTS PASSED~~~" for the various tests of the code that it will run. If there's an error, you should get an informative message from make about what went wrong.
+
+These tests are all built as simple binaries with few dependencies, so you can run them manually. For example, here's how to run the depthwise convolution test, and its output:
+
+```
+tensorflow/contrib/lite/experimental/micro/tools/make/gen/linux_x86_64/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test
+
+Testing SimpleTest
+Testing SimpleTestQuantized
+Testing SimpleTestRelu
+Testing SimpleTestReluQuantized
+4/4 tests passed
+~~~ALL TESTS PASSED~~~
+```
+
+Looking at the [depthwise_conv_test.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc) code, you'll see a sequence that looks like this:
+
+```
+...
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+...
+}
+...
+TF_LITE_MICRO_TESTS_END
+```
+
+These macros work a lot like [the Google test framework](https://github.com/google/googletest), but they don't require any dependencies and just write results to stderr, rather than aborting the program. If all the tests pass, then "~~~ALL TESTS PASSED~~~" is output, and the test harness that runs the binary during the make process knows that everything ran correctly. If there's an error, the lack of the expected string lets the harness know that the test failed.
+
+So, why are we running tests in this complicated way? So far, we've been building binaries that run locally on the Mac OS or Linux machine you're building on, but this approach becomes important when we're targeting simple micro controller devices.
+
+## Building for the "Blue Pill" STM32F103
+
+The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the ["Blue Pill" STM32F103-compatible development board](https://github.com/google/stm32_bare_lib) as a cheap and popular platform. It only has 20KB of RAM and 64KB of flash, so it's a good device to ensure we can run efficiently on small chips.
+
+It's fairly easy to [buy and wire up a physical board](https://github.com/google/stm32_bare_lib#wiring-up-your-blue-pill), but even if you don't have an actual device, the [Renode project](https://renode.io/) makes it easy to run a faithful emulation on your desktop machine. You'll need [Docker](https://www.docker.com/) installed, but once you have that set up, try running the following command:
+
+`make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test`
+
+You should see a similar set of outputs as you did in the previous section, with the addition of some extra Docker logging messages. These are because we're using Docker to run the Renode micro controller emulation tool, and the tests themselves are being run on a simulated STM32F103 device. The communication channels between an embedded device and the host are quite limited, so the test harness looks at the output of the debug log to see if tests have passed, just as it did in the previous section. This makes it a very flexible way to run cross-platform tests, even when a platform has no operating system facilities, as long as it can output debugging text logs.
+
+To understand what's happening here, try running the same depthwise convolution test, but through the emulated device test harness, with the following command:
+
+```
+tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh \
+tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test
+
+```
+
+You should see output that looks something like this:
+
+```
+Sending build context to Docker daemon 21.5kB
+Step 1/2 : FROM antmicro/renode:latest
+ ---> 1b670a243e8f
+Step 2/2 : LABEL maintainer="Pete Warden <petewarden@google.com>"
+ ---> Using cache
+ ---> 3afcd410846d
+Successfully built 3afcd410846d
+Successfully tagged renode_bluepill:latest
+LOGS:
+...
+03:27:32.4340 [INFO] machine-0: Machine started.
+03:27:32.4790 [DEBUG] cpu.uartSemihosting: [+0.22s host +0s virt 0s virt from start] Testing SimpleTest
+03:27:32.4812 [DEBUG] cpu.uartSemihosting: [+2.21ms host +0s virt 0s virt from start] Testing SimpleTestQuantized
+03:27:32.4833 [DEBUG] cpu.uartSemihosting: [+2.14ms host +0s virt 0s virt from start] Testing SimpleTestRelu
+03:27:32.4834 [DEBUG] cpu.uartSemihosting: [+0.18ms host +0s virt 0s virt from start] Testing SimpleTestReluQuantized
+03:27:32.4838 [DEBUG] cpu.uartSemihosting: [+0.4ms host +0s virt 0s virt from start] 4/4 tests passed
+03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+41µs host +0s virt 0s virt from start] ~~~ALL TESTS PASSED~~~
+03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+5µs host +0s virt 0s virt from start]
+...
+tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test: PASS
+```
+
+There's a lot of output here, but you should be able to see that the same tests that were covered when we ran locally on the development machine show up in the debug logs here, along with the magic string "~~~ALL TESTS PASSED~~~". This is the exact same code as before, just compiled and run on the STM32F103 rather than your desktop. We hope that the simplicity of this testing approach will help make adding support for new platforms as easy as possible.
diff --git a/tensorflow/contrib/lite/experimental/micro/compatibility.h b/tensorflow/contrib/lite/experimental/micro/compatibility.h
new file mode 100644
index 0000000000..4f0fd9f312
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/compatibility.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_
+
+// C++ will automatically create class-specific delete operators for virtual
+// objects, which by default call the global delete function. For embedded
+// applications we want to avoid this, and won't be calling new/delete on these
+// objects, so we need to override the default implementation with one that does
+// nothing to avoid linking in ::delete().
+// This macro needs to be included in all subclasses of a virtual base class in
+// the private section.
+#ifdef TF_LITE_STATIC_MEMORY
+#define TF_LITE_REMOVE_VIRTUAL_DELETE \
+ void operator delete(void* p) {}
+#else
+#define TF_LITE_REMOVE_VIRTUAL_DELETE
+#endif
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD
new file mode 100644
index 0000000000..dad58b6c1c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD
@@ -0,0 +1,31 @@
+# Description:
+# TensorFlow Lite microcontroller example.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl",
+ "tflite_micro_cc_test",
+)
+
+tflite_micro_cc_test(
+ name = "micro_speech_test",
+ srcs = [
+ "micro_speech_test.cc",
+ "tiny_conv_model_data.cc",
+ "tiny_conv_model_data.h",
+ ],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/kernels:all_ops_resolver",
+ "//tensorflow/contrib/lite/experimental/micro/kernels:micro_ops",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc
new file mode 100644
index 0000000000..86cd056a72
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestInvoke) {
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = &micro_error_reporter;
+
+ const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data);
+ if (model->version() != TFLITE_SCHEMA_VERSION) {
+ error_reporter->Report(
+ "Model provided is schema version %d not equal "
+ "to supported version %d.\n",
+ model->version(), TFLITE_SCHEMA_VERSION);
+ }
+ tflite::ops::micro::AllOpsResolver resolver;
+
+ const int tensor_arena_size = 10 * 1024;
+ uint8_t tensor_arena[tensor_arena_size];
+ tflite::SimpleTensorAllocator tensor_allocator(tensor_arena,
+ tensor_arena_size);
+
+ tflite::MicroInterpreter interpreter(model, resolver, &tensor_allocator,
+ error_reporter);
+ TfLiteStatus invoke_status = interpreter.Invoke();
+ if (invoke_status != kTfLiteOk) {
+ error_reporter->Report("Invoke failed\n");
+ }
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
+
+ error_reporter->Report("Ran successfully\n");
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
new file mode 100644
index 0000000000..f1f9e0e219
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
@@ -0,0 +1,1672 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Automatically created from a TensorFlow Lite flatbuffer using the command:
+// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc
+
+#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
+
+const unsigned char g_tiny_conv_model_data[] = {
+ 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00,
+ 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00,
+ 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x4d, 0x00, 0x00,
+ 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0xf4, 0x47, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00,
+ 0x54, 0x4f, 0x43, 0x4f, 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74,
+ 0x65, 0x64, 0x2e, 0x00, 0x09, 0x00, 0x00, 0x00, 0xd4, 0x47, 0x00, 0x00,
+ 0x04, 0x03, 0x00, 0x00, 0xfc, 0x02, 0x00, 0x00, 0xf4, 0x02, 0x00, 0x00,
+ 0x64, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00,
+ 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb8, 0xb3, 0xff, 0xff,
+ 0x16, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xd7, 0x02, 0x00, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe8, 0xb3, 0xff, 0xff,
+ 0x46, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+ 0xab, 0x00, 0x00, 0x00, 0x1e, 0xff, 0xff, 0xff, 0xed, 0xff, 0xff, 0xff,
+ 0x4a, 0x00, 0x00, 0x00, 0x62, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
+ 0x80, 0x02, 0x00, 0x00, 0xce, 0xad, 0xaf, 0x3c, 0xc8, 0xe9, 0xb0, 0x83,
+ 0xa1, 0xbf, 0xb2, 0xb1, 0xab, 0xd0, 0xa7, 0x53, 0xa5, 0xe9, 0xb5, 0xac,
+ 0xa2, 0xd3, 0xc4, 0x9e, 0x8b, 0xb2, 0x64, 0xb3, 0x9d, 0xa2, 0xae, 0xa6,
+ 0xd5, 0xbe, 0x43, 0x9f, 0x9c, 0x54, 0xb5, 0xa8, 0x49, 0x78, 0x86, 0xa2,
+ 0xa3, 0x55, 0x35, 0x96, 0x3d, 0x7f, 0xe2, 0xb5, 0xb0, 0x47, 0x28, 0xa9,
+ 0x9d, 0xbb, 0xd6, 0xff, 0xb7, 0x79, 0x63, 0xb5, 0xaf, 0xa7, 0xab, 0x7e,
+ 0xbc, 0xc7, 0xa0, 0xc3, 0xb1, 0xb6, 0xb2, 0xa1, 0xc2, 0xbb, 0x79, 0x57,
+ 0xbe, 0xc1, 0xb7, 0xb0, 0x6b, 0xb7, 0xa5, 0x75, 0x97, 0xb8, 0xe7, 0xac,
+ 0xad, 0x7e, 0xb1, 0x9b, 0xc3, 0xba, 0x6b, 0xa2, 0x7f, 0x58, 0xb9, 0x7a,
+ 0x4c, 0x91, 0x74, 0x9e, 0xa7, 0x3d, 0xc2, 0x94, 0x75, 0xa1, 0xa4, 0xac,
+ 0xab, 0x45, 0x2e, 0xb4, 0xb6, 0xbf, 0xc1, 0xdb, 0xaf, 0x6c, 0x67, 0xb1,
+ 0xa9, 0xa6, 0xa8, 0xca, 0xc2, 0xc4, 0xb9, 0xbf, 0xb4, 0xb9, 0xaa, 0x9d,
+ 0x9f, 0xb9, 0xb2, 0x71, 0xb2, 0xca, 0xbe, 0xaf, 0x5f, 0xbc, 0xa0, 0x5b,
+ 0xa8, 0xb4, 0xa4, 0xa8, 0xd8, 0x69, 0xb7, 0x8a, 0xbc, 0xb8, 0xaf, 0x9c,
+ 0x7c, 0x5d, 0xb3, 0x6b, 0x49, 0x95, 0x64, 0xa0, 0xa2, 0x49, 0xcb, 0x87,
+ 0xa5, 0xb5, 0xa1, 0xb2, 0xa3, 0x40, 0x6d, 0x9f, 0xc5, 0xb6, 0xbb, 0xd4,
+ 0x9c, 0x6d, 0x69, 0xa9, 0xa8, 0x91, 0xad, 0xb8, 0xd2, 0xc6, 0xaf, 0xb8,
+ 0xac, 0xa9, 0xa2, 0xa7, 0x60, 0xa6, 0xa1, 0xc9, 0xb8, 0xd6, 0xcf, 0xb1,
+ 0x56, 0xb4, 0xac, 0x40, 0xae, 0xbd, 0xbf, 0xa2, 0x54, 0x72, 0x9b, 0x8c,
+ 0xc2, 0xb5, 0xc2, 0x9b, 0x64, 0x6d, 0xb4, 0x62, 0x4e, 0x9b, 0x6c, 0xa6,
+ 0x8f, 0x4c, 0xca, 0x95, 0xb6, 0xbf, 0x92, 0xae, 0x9c, 0x49, 0xae, 0xb2,
+ 0xc0, 0xb6, 0xbc, 0xd1, 0xa4, 0x7b, 0x64, 0xa0, 0xa6, 0x81, 0xac, 0xa6,
+ 0xbd, 0xc8, 0xbc, 0xae, 0xaa, 0x9e, 0x61, 0xb1, 0x57, 0xac, 0xbf, 0xbf,
+ 0xbb, 0xe0, 0xa6, 0xae, 0x47, 0xc9, 0xbc, 0x57, 0xb0, 0xb5, 0xc7, 0x98,
+ 0xf4, 0x93, 0xb6, 0x70, 0xc3, 0xb3, 0xca, 0xab, 0x77, 0x9a, 0xac, 0x45,
+ 0x5c, 0x9e, 0x9a, 0xa9, 0x9b, 0x35, 0xc0, 0x6f, 0xc6, 0xc7, 0x91, 0xb4,
+ 0xa8, 0x3c, 0xce, 0xb8, 0xad, 0xb9, 0xb5, 0xdd, 0x9c, 0x6d, 0xbf, 0x91,
+ 0xb2, 0x7d, 0xa0, 0xaf, 0x9f, 0xbd, 0xb9, 0xcf, 0x9b, 0x5d, 0x3f, 0xac,
+ 0x64, 0xae, 0xaf, 0xb8, 0xbc, 0xb8, 0x86, 0xb5, 0x36, 0xcf, 0xb4, 0xa9,
+ 0xad, 0xcd, 0xdb, 0xa4, 0x68, 0xa6, 0xa4, 0x67, 0xc8, 0xb7, 0xe5, 0xa4,
+ 0x76, 0xb8, 0xa8, 0x28, 0x6b, 0xa5, 0xba, 0xad, 0x9f, 0x3a, 0xa5, 0x42,
+ 0xc5, 0xb0, 0x88, 0xad, 0xa5, 0x4d, 0xea, 0x8a, 0xb8, 0xb5, 0xb3, 0xd9,
+ 0xa0, 0x77, 0xbb, 0x92, 0x9e, 0x80, 0xbd, 0xbd, 0x6d, 0xcc, 0xab, 0x99,
+ 0x88, 0x58, 0x4d, 0xb0, 0x6c, 0xbc, 0x96, 0xbd, 0xae, 0xab, 0x5b, 0xac,
+ 0x2f, 0xc3, 0x9a, 0xbe, 0xac, 0xb3, 0x84, 0x9b, 0xe3, 0xaf, 0x95, 0x6b,
+ 0xc2, 0xb5, 0xca, 0xb7, 0x4e, 0xbc, 0x9d, 0x24, 0x75, 0xa9, 0xd2, 0xae,
+ 0xa0, 0x2b, 0x90, 0x34, 0xd1, 0xb5, 0x96, 0xae, 0xaa, 0x4d, 0xc1, 0xa3,
+ 0xb1, 0xb4, 0xaa, 0xd2, 0x9c, 0x7d, 0xc0, 0x91, 0x91, 0x7a, 0xb8, 0x83,
+ 0x44, 0xcb, 0xaf, 0x9b, 0x6b, 0x5b, 0x75, 0xb2, 0x62, 0xb6, 0xaa, 0xcb,
+ 0x99, 0xa8, 0x63, 0xae, 0x24, 0xc7, 0x8a, 0xbe, 0xa9, 0xb6, 0xa0, 0xa1,
+ 0x41, 0xac, 0x84, 0xb5, 0xb9, 0xb3, 0x9b, 0xad, 0x77, 0xbf, 0xa8, 0x7e,
+ 0x82, 0xb9, 0xbe, 0xaa, 0xa3, 0x47, 0x6d, 0xb5, 0xc3, 0xb1, 0xbf, 0xa7,
+ 0xb1, 0x57, 0x75, 0xb5, 0xb0, 0xb6, 0xb9, 0xce, 0xa4, 0x86, 0xb0, 0xa4,
+ 0x98, 0x80, 0xc5, 0x3e, 0x90, 0xca, 0x9b, 0xa2, 0x5a, 0x50, 0xc5, 0xa5,
+ 0xad, 0xc1, 0x9c, 0x91, 0x83, 0x8f, 0x21, 0xab, 0xac, 0xba, 0x70, 0xb4,
+ 0xae, 0x85, 0x7e, 0xa7, 0xbd, 0xba, 0x7c, 0xb2, 0xb5, 0xb2, 0x7e, 0xb3,
+ 0xc3, 0xcd, 0x82, 0xac, 0x9b, 0xb3, 0xa6, 0xb0, 0xbc, 0x6f, 0x52, 0xb9,
+ 0xbf, 0xb1, 0xa6, 0xa4, 0xc1, 0x7a, 0x90, 0xc0, 0xae, 0xab, 0x94, 0xd8,
+ 0xab, 0xa4, 0x98, 0xbb, 0x8b, 0x86, 0x94, 0x01, 0xad, 0xe7, 0xb1, 0x9b,
+ 0x57, 0x48, 0xc1, 0x88, 0xbf, 0xcc, 0xb4, 0x4b, 0x62, 0x8b, 0x48, 0xa7,
+ 0xbe, 0xe1, 0x80, 0xa6, 0xb3, 0x64, 0xaa, 0xa4, 0xcf, 0xba, 0x6d, 0xa6,
+ 0xb8, 0xa0, 0x8f, 0xb3, 0xce, 0xc3, 0x87, 0xb2, 0xa0, 0xc0, 0x78, 0xb0,
+ 0xb9, 0xaa, 0x40, 0xb8, 0xd8, 0xa3, 0x9a, 0xaa, 0xcc, 0xa2, 0x9f, 0xb9,
+ 0xbe, 0xc2, 0x89, 0xd6, 0xc6, 0x9c, 0xa3, 0xc7, 0x94, 0xb6, 0xff, 0xff,
+ 0x98, 0xb6, 0xff, 0xff, 0xf6, 0xb6, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
+ 0xc0, 0x44, 0x00, 0x00, 0x4a, 0x4d, 0x59, 0x60, 0x5a, 0x45, 0x3d, 0x50,
+ 0x4a, 0x43, 0x3d, 0x59, 0x3e, 0x49, 0x4a, 0x59, 0x45, 0x44, 0x41, 0x5d,
+ 0x50, 0x2f, 0x4e, 0x34, 0x46, 0x48, 0x41, 0x4a, 0x4c, 0x3b, 0x4b, 0x3e,
+ 0x49, 0x49, 0x43, 0x4b, 0x3e, 0x49, 0x47, 0x41, 0x3e, 0x4a, 0x46, 0x43,
+ 0x41, 0x43, 0x47, 0x49, 0x4a, 0x4c, 0x46, 0x58, 0x3f, 0x4c, 0x4b, 0x4c,
+ 0x4d, 0x4b, 0x45, 0x52, 0x45, 0x42, 0x52, 0x52, 0x48, 0x40, 0x46, 0x5f,
+ 0x4c, 0x41, 0x47, 0x48, 0x48, 0x4c, 0x43, 0x61, 0x50, 0x4b, 0x49, 0x49,
+ 0x46, 0x3f, 0x40, 0x67, 0x40, 0x4d, 0x45, 0x40, 0x40, 0x45, 0x47, 0x56,
+ 0x44, 0x3a, 0x4a, 0x4c, 0x52, 0x48, 0x46, 0x50, 0x4b, 0x44, 0x51, 0x45,
+ 0x40, 0x45, 0x45, 0x48, 0x4e, 0x4e, 0x43, 0x48, 0x44, 0x4b, 0x45, 0x4a,
+ 0x53, 0x45, 0x4a, 0x4b, 0x3f, 0x43, 0x45, 0x53, 0x4d, 0x43, 0x46, 0x3f,
+ 0x47, 0x4e, 0x51, 0x50, 0x48, 0x4f, 0x4f, 0x4a, 0x4a, 0x4e, 0x45, 0x4e,
+ 0x46, 0x41, 0x4a, 0x46, 0x45, 0x47, 0x45, 0x4b, 0x50, 0x4c, 0x46, 0x45,
+ 0x41, 0x47, 0x41, 0x47, 0x46, 0x4f, 0x3f, 0x4f, 0x4a, 0x51, 0x4f, 0x53,
+ 0x54, 0x48, 0x51, 0x43, 0x4b, 0x48, 0x4d, 0x46, 0x48, 0x4f, 0x49, 0x44,
+ 0x43, 0x53, 0x50, 0x59, 0x56, 0x3d, 0x45, 0x44, 0x48, 0x38, 0x3b, 0x5f,
+ 0x39, 0x43, 0x43, 0x52, 0x46, 0x3e, 0x43, 0x58, 0x43, 0x1e, 0x50, 0x3c,
+ 0x46, 0x4b, 0x46, 0x50, 0x3c, 0x37, 0x4c, 0x47, 0x47, 0x4b, 0x47, 0x54,
+ 0x43, 0x3e, 0x47, 0x4f, 0x4b, 0x41, 0x53, 0x50, 0x42, 0x46, 0x4f, 0x4b,
+ 0x4e, 0x3f, 0x49, 0x52, 0x4a, 0x4a, 0x49, 0x53, 0x52, 0x47, 0x52, 0x5a,
+ 0x40, 0x42, 0x4d, 0x4b, 0x50, 0x43, 0x49, 0x59, 0x47, 0x4c, 0x4d, 0x50,
+ 0x4e, 0x3c, 0x44, 0x61, 0x51, 0x49, 0x49, 0x46, 0x49, 0x47, 0x4b, 0x5a,
+ 0x45, 0x4b, 0x43, 0x40, 0x44, 0x52, 0x4d, 0x54, 0x49, 0x47, 0x44, 0x48,
+ 0x46, 0x48, 0x3e, 0x40, 0x45, 0x4f, 0x4d, 0x4b, 0x4c, 0x40, 0x3d, 0x40,
+ 0x3e, 0x48, 0x50, 0x4e, 0x4c, 0x42, 0x48, 0x4b, 0x3d, 0x48, 0x4b, 0x44,
+ 0x52, 0x4b, 0x49, 0x4f, 0x49, 0x3f, 0x47, 0x43, 0x4d, 0x3f, 0x53, 0x4e,
+ 0x4a, 0x4f, 0x4e, 0x4e, 0x53, 0x42, 0x46, 0x4c, 0x44, 0x4c, 0x46, 0x51,
+ 0x45, 0x48, 0x4a, 0x50, 0x47, 0x41, 0x45, 0x54, 0x4a, 0x44, 0x50, 0x49,
+ 0x48, 0x50, 0x51, 0x4b, 0x50, 0x4c, 0x4a, 0x49, 0x43, 0x47, 0x50, 0x4a,
+ 0x4d, 0x4c, 0x4e, 0x49, 0x42, 0x50, 0x52, 0x48, 0x45, 0x5a, 0x4e, 0x55,
+ 0x51, 0x3d, 0x3d, 0x4d, 0x42, 0x32, 0x36, 0x64, 0x39, 0x4c, 0x41, 0x48,
+ 0x44, 0x35, 0x43, 0x56, 0x47, 0x1e, 0x4b, 0x3e, 0x47, 0x3f, 0x43, 0x52,
+ 0x51, 0x34, 0x41, 0x4d, 0x3e, 0x41, 0x41, 0x48, 0x3c, 0x4b, 0x45, 0x3b,
+ 0x40, 0x43, 0x4c, 0x46, 0x46, 0x47, 0x3e, 0x4f, 0x4b, 0x48, 0x42, 0x47,
+ 0x4e, 0x3e, 0x49, 0x47, 0x43, 0x43, 0x4e, 0x52, 0x51, 0x45, 0x3f, 0x54,
+ 0x46, 0x44, 0x48, 0x5d, 0x3e, 0x4a, 0x47, 0x52, 0x53, 0x3a, 0x4f, 0x5d,
+ 0x41, 0x4c, 0x48, 0x51, 0x43, 0x4b, 0x4b, 0x67, 0x48, 0x4b, 0x45, 0x4d,
+ 0x4b, 0x43, 0x4a, 0x54, 0x4c, 0x46, 0x43, 0x4a, 0x4d, 0x43, 0x4c, 0x47,
+ 0x4a, 0x48, 0x4d, 0x42, 0x4d, 0x48, 0x3f, 0x43, 0x4c, 0x44, 0x4e, 0x4c,
+ 0x40, 0x45, 0x4b, 0x48, 0x47, 0x47, 0x3e, 0x4c, 0x52, 0x41, 0x44, 0x4e,
+ 0x4d, 0x44, 0x49, 0x4d, 0x3d, 0x45, 0x48, 0x4f, 0x4c, 0x4a, 0x55, 0x51,
+ 0x4d, 0x4c, 0x45, 0x4e, 0x46, 0x45, 0x44, 0x49, 0x4e, 0x44, 0x40, 0x48,
+ 0x49, 0x44, 0x53, 0x51, 0x42, 0x41, 0x51, 0x49, 0x51, 0x45, 0x51, 0x3f,
+ 0x4b, 0x3f, 0x52, 0x3c, 0x50, 0x4d, 0x4f, 0x4b, 0x44, 0x4f, 0x40, 0x52,
+ 0x49, 0x4a, 0x50, 0x3f, 0x3d, 0x54, 0x4c, 0x53, 0x52, 0x45, 0x41, 0x43,
+ 0x47, 0x2d, 0x40, 0x63, 0x3a, 0x51, 0x43, 0x4e, 0x40, 0x2b, 0x36, 0x5b,
+ 0x4b, 0x12, 0x4d, 0x35, 0x4b, 0x3f, 0x44, 0x4a, 0x46, 0x31, 0x54, 0x48,
+ 0x43, 0x42, 0x3d, 0x51, 0x41, 0x45, 0x49, 0x4b, 0x47, 0x49, 0x3d, 0x3e,
+ 0x46, 0x3d, 0x4d, 0x48, 0x3d, 0x45, 0x48, 0x4b, 0x49, 0x52, 0x44, 0x4c,
+ 0x45, 0x44, 0x45, 0x49, 0x50, 0x48, 0x45, 0x46, 0x45, 0x44, 0x52, 0x55,
+ 0x46, 0x45, 0x4b, 0x3d, 0x42, 0x4a, 0x3e, 0x57, 0x48, 0x4b, 0x3c, 0x42,
+ 0x4a, 0x46, 0x47, 0x6c, 0x54, 0x4b, 0x41, 0x49, 0x49, 0x50, 0x43, 0x56,
+ 0x44, 0x43, 0x4d, 0x3e, 0x44, 0x41, 0x47, 0x40, 0x4a, 0x4b, 0x4d, 0x4d,
+ 0x3e, 0x46, 0x45, 0x47, 0x3e, 0x42, 0x4a, 0x45, 0x49, 0x3d, 0x3f, 0x43,
+ 0x40, 0x44, 0x47, 0x4a, 0x45, 0x4d, 0x4b, 0x4c, 0x43, 0x40, 0x3d, 0x3e,
+ 0x4c, 0x4c, 0x42, 0x4d, 0x48, 0x4d, 0x49, 0x42, 0x51, 0x51, 0x4c, 0x4b,
+ 0x53, 0x4f, 0x48, 0x4d, 0x40, 0x46, 0x45, 0x4b, 0x47, 0x47, 0x4b, 0x46,
+ 0x54, 0x42, 0x42, 0x46, 0x46, 0x4a, 0x4c, 0x55, 0x3f, 0x3c, 0x52, 0x4b,
+ 0x4b, 0x4d, 0x4e, 0x48, 0x53, 0x4c, 0x4b, 0x42, 0x52, 0x54, 0x50, 0x4b,
+ 0x40, 0x5f, 0x58, 0x53, 0x50, 0x42, 0x35, 0x48, 0x39, 0x24, 0x3c, 0x5e,
+ 0x41, 0x50, 0x3c, 0x51, 0x42, 0x26, 0x42, 0x56, 0x41, 0x0c, 0x3e, 0x3d,
+ 0x48, 0x3e, 0x50, 0x4b, 0x3a, 0x2c, 0x43, 0x3d, 0x48, 0x3e, 0x43, 0x48,
+ 0x4c, 0x3f, 0x4a, 0x3e, 0x51, 0x4a, 0x4f, 0x40, 0x47, 0x43, 0x50, 0x4c,
+ 0x43, 0x4d, 0x3f, 0x45, 0x4d, 0x3e, 0x4c, 0x44, 0x51, 0x47, 0x4b, 0x51,
+ 0x45, 0x49, 0x44, 0x3f, 0x46, 0x46, 0x46, 0x57, 0x49, 0x4c, 0x49, 0x4e,
+ 0x47, 0x4c, 0x47, 0x5e, 0x43, 0x46, 0x45, 0x4b, 0x52, 0x49, 0x45, 0x5f,
+ 0x47, 0x41, 0x46, 0x43, 0x4f, 0x3b, 0x43, 0x51, 0x46, 0x53, 0x4a, 0x4e,
+ 0x4b, 0x43, 0x4e, 0x40, 0x48, 0x49, 0x46, 0x3f, 0x48, 0x50, 0x4b, 0x41,
+ 0x4a, 0x47, 0x4b, 0x3d, 0x46, 0x49, 0x4b, 0x43, 0x43, 0x42, 0x3e, 0x47,
+ 0x47, 0x4a, 0x45, 0x46, 0x51, 0x48, 0x51, 0x4e, 0x3f, 0x50, 0x44, 0x4b,
+ 0x4d, 0x4e, 0x44, 0x4d, 0x3d, 0x49, 0x4a, 0x4e, 0x42, 0x51, 0x43, 0x42,
+ 0x46, 0x3e, 0x48, 0x4b, 0x4f, 0x50, 0x3d, 0x48, 0x4c, 0x4f, 0x46, 0x44,
+ 0x44, 0x48, 0x42, 0x4b, 0x48, 0x41, 0x43, 0x46, 0x4d, 0x49, 0x4f, 0x43,
+ 0x41, 0x44, 0x3f, 0x3d, 0x45, 0x4f, 0x45, 0x41, 0x40, 0x58, 0x4f, 0x54,
+ 0x5b, 0x4b, 0x3a, 0x47, 0x3d, 0x28, 0x3d, 0x57, 0x3e, 0x51, 0x3f, 0x47,
+ 0x3f, 0x2e, 0x3e, 0x54, 0x4e, 0x0b, 0x41, 0x3d, 0x3b, 0x3d, 0x43, 0x47,
+ 0x47, 0x28, 0x4d, 0x43, 0x43, 0x3b, 0x4e, 0x4a, 0x4d, 0x42, 0x51, 0x46,
+ 0x4f, 0x3d, 0x4c, 0x3a, 0x49, 0x49, 0x4a, 0x43, 0x42, 0x4b, 0x47, 0x42,
+ 0x42, 0x49, 0x3f, 0x4d, 0x46, 0x4a, 0x49, 0x4e, 0x42, 0x3c, 0x4a, 0x41,
+ 0x4c, 0x40, 0x4d, 0x5a, 0x49, 0x46, 0x51, 0x46, 0x4b, 0x4c, 0x46, 0x62,
+ 0x45, 0x42, 0x51, 0x4e, 0x4d, 0x3e, 0x4d, 0x5b, 0x4d, 0x43, 0x45, 0x50,
+ 0x4b, 0x40, 0x50, 0x53, 0x4f, 0x4f, 0x51, 0x53, 0x46, 0x41, 0x4e, 0x3a,
+ 0x4b, 0x47, 0x3f, 0x3e, 0x4d, 0x48, 0x53, 0x3f, 0x45, 0x42, 0x4c, 0x45,
+ 0x55, 0x4c, 0x4b, 0x39, 0x4a, 0x45, 0x48, 0x4d, 0x47, 0x40, 0x48, 0x4f,
+ 0x4d, 0x49, 0x3e, 0x41, 0x46, 0x4e, 0x40, 0x49, 0x4b, 0x47, 0x4c, 0x45,
+ 0x44, 0x51, 0x4f, 0x4b, 0x48, 0x49, 0x44, 0x41, 0x43, 0x46, 0x51, 0x45,
+ 0x40, 0x48, 0x4b, 0x42, 0x44, 0x4f, 0x53, 0x4d, 0x44, 0x46, 0x4e, 0x4c,
+ 0x48, 0x50, 0x41, 0x45, 0x42, 0x48, 0x4d, 0x4d, 0x47, 0x45, 0x41, 0x45,
+ 0x48, 0x58, 0x4e, 0x46, 0x43, 0x53, 0x57, 0x52, 0x5e, 0x42, 0x45, 0x4e,
+ 0x39, 0x24, 0x32, 0x56, 0x47, 0x56, 0x49, 0x52, 0x46, 0x26, 0x3a, 0x51,
+ 0x4b, 0x05, 0x3e, 0x43, 0x3f, 0x38, 0x4d, 0x4b, 0x4f, 0x27, 0x51, 0x46,
+ 0x47, 0x41, 0x4a, 0x47, 0x4a, 0x3e, 0x44, 0x51, 0x3f, 0x3a, 0x43, 0x46,
+ 0x4d, 0x49, 0x46, 0x52, 0x43, 0x48, 0x49, 0x3e, 0x47, 0x46, 0x4a, 0x4d,
+ 0x47, 0x46, 0x52, 0x50, 0x44, 0x48, 0x4c, 0x47, 0x45, 0x41, 0x49, 0x5b,
+ 0x4d, 0x4b, 0x47, 0x4c, 0x4a, 0x47, 0x45, 0x5b, 0x49, 0x46, 0x52, 0x47,
+ 0x47, 0x3d, 0x55, 0x59, 0x40, 0x4b, 0x3e, 0x50, 0x42, 0x43, 0x40, 0x4f,
+ 0x48, 0x3f, 0x47, 0x53, 0x4d, 0x44, 0x4e, 0x37, 0x4c, 0x43, 0x51, 0x4d,
+ 0x46, 0x4e, 0x40, 0x41, 0x52, 0x44, 0x43, 0x4a, 0x50, 0x48, 0x47, 0x42,
+ 0x48, 0x45, 0x50, 0x4d, 0x42, 0x52, 0x44, 0x43, 0x45, 0x43, 0x4c, 0x4d,
+ 0x44, 0x51, 0x47, 0x48, 0x51, 0x4f, 0x48, 0x45, 0x49, 0x4a, 0x3e, 0x43,
+ 0x4d, 0x4e, 0x4e, 0x46, 0x54, 0x4d, 0x49, 0x4d, 0x47, 0x46, 0x4b, 0x41,
+ 0x4a, 0x49, 0x44, 0x45, 0x4d, 0x3e, 0x53, 0x50, 0x47, 0x4d, 0x4e, 0x43,
+ 0x4f, 0x45, 0x4e, 0x4a, 0x47, 0x49, 0x4c, 0x4c, 0x4d, 0x54, 0x42, 0x4c,
+ 0x43, 0x5d, 0x59, 0x50, 0x5e, 0x4b, 0x44, 0x43, 0x3c, 0x25, 0x31, 0x5b,
+ 0x46, 0x5a, 0x50, 0x4d, 0x41, 0x2a, 0x41, 0x4f, 0x44, 0x00, 0x41, 0x3d,
+ 0x43, 0x4b, 0x47, 0x45, 0x4e, 0x2e, 0x44, 0x46, 0x53, 0x3d, 0x43, 0x41,
+ 0x44, 0x46, 0x49, 0x42, 0x45, 0x4f, 0x4d, 0x3a, 0x43, 0x3c, 0x47, 0x53,
+ 0x43, 0x4e, 0x3f, 0x41, 0x4d, 0x50, 0x4b, 0x4c, 0x51, 0x47, 0x53, 0x4f,
+ 0x45, 0x4a, 0x44, 0x45, 0x41, 0x46, 0x47, 0x50, 0x51, 0x3f, 0x3e, 0x41,
+ 0x48, 0x45, 0x46, 0x5d, 0x45, 0x4a, 0x4c, 0x46, 0x4a, 0x49, 0x50, 0x51,
+ 0x51, 0x4c, 0x4f, 0x47, 0x47, 0x42, 0x45, 0x47, 0x4e, 0x48, 0x46, 0x40,
+ 0x45, 0x46, 0x4d, 0x3b, 0x4d, 0x52, 0x4c, 0x51, 0x49, 0x51, 0x47, 0x3d,
+ 0x4d, 0x42, 0x4f, 0x4e, 0x43, 0x43, 0x45, 0x3a, 0x42, 0x50, 0x4c, 0x4a,
+ 0x41, 0x53, 0x4c, 0x45, 0x51, 0x3f, 0x54, 0x43, 0x4b, 0x54, 0x56, 0x4d,
+ 0x4f, 0x4a, 0x50, 0x4b, 0x44, 0x45, 0x4f, 0x4f, 0x47, 0x3e, 0x50, 0x4f,
+ 0x4b, 0x48, 0x4d, 0x49, 0x55, 0x4d, 0x45, 0x4d, 0x4a, 0x53, 0x43, 0x46,
+ 0x4c, 0x45, 0x41, 0x46, 0x49, 0x49, 0x4f, 0x4b, 0x49, 0x50, 0x52, 0x49,
+ 0x41, 0x54, 0x44, 0x4c, 0x44, 0x63, 0x4a, 0x49, 0x40, 0x59, 0x52, 0x52,
+ 0x59, 0x3f, 0x3e, 0x3e, 0x40, 0x25, 0x3c, 0x5c, 0x4f, 0x57, 0x44, 0x50,
+ 0x41, 0x2a, 0x48, 0x4f, 0x43, 0x08, 0x47, 0x43, 0x49, 0x48, 0x4d, 0x49,
+ 0x46, 0x2b, 0x48, 0x44, 0x4e, 0x47, 0x47, 0x43, 0x44, 0x3e, 0x4a, 0x52,
+ 0x3f, 0x4a, 0x53, 0x42, 0x49, 0x47, 0x4c, 0x50, 0x43, 0x46, 0x46, 0x3c,
+ 0x4c, 0x47, 0x4e, 0x4d, 0x42, 0x41, 0x53, 0x52, 0x4f, 0x40, 0x54, 0x50,
+ 0x46, 0x43, 0x50, 0x56, 0x51, 0x48, 0x48, 0x48, 0x49, 0x39, 0x47, 0x5e,
+ 0x4e, 0x4b, 0x4f, 0x4e, 0x43, 0x45, 0x42, 0x58, 0x4a, 0x3b, 0x48, 0x4d,
+ 0x43, 0x3e, 0x4b, 0x43, 0x3c, 0x45, 0x46, 0x4b, 0x42, 0x42, 0x4e, 0x3d,
+ 0x4b, 0x4e, 0x51, 0x52, 0x48, 0x3e, 0x4b, 0x3f, 0x4c, 0x4a, 0x4b, 0x4c,
+ 0x46, 0x48, 0x3e, 0x48, 0x47, 0x4d, 0x4a, 0x46, 0x49, 0x4d, 0x4a, 0x48,
+ 0x50, 0x4b, 0x40, 0x48, 0x4b, 0x52, 0x46, 0x50, 0x4f, 0x3e, 0x42, 0x44,
+ 0x44, 0x42, 0x43, 0x49, 0x4f, 0x4f, 0x46, 0x42, 0x4a, 0x54, 0x42, 0x48,
+ 0x50, 0x4f, 0x4f, 0x4c, 0x4c, 0x47, 0x52, 0x49, 0x4c, 0x45, 0x4a, 0x4d,
+ 0x4a, 0x41, 0x47, 0x4a, 0x4d, 0x4a, 0x4c, 0x46, 0x51, 0x44, 0x4b, 0x49,
+ 0x53, 0x5e, 0x45, 0x4a, 0x3b, 0x57, 0x5a, 0x4c, 0x59, 0x43, 0x3e, 0x4a,
+ 0x3e, 0x20, 0x36, 0x5d, 0x47, 0x5b, 0x3f, 0x55, 0x3e, 0x24, 0x41, 0x52,
+ 0x3f, 0x01, 0x49, 0x41, 0x40, 0x45, 0x42, 0x46, 0x49, 0x2a, 0x47, 0x40,
+ 0x44, 0x3f, 0x42, 0x47, 0x4e, 0x42, 0x4b, 0x3d, 0x45, 0x4c, 0x47, 0x3d,
+ 0x4c, 0x44, 0x48, 0x43, 0x43, 0x41, 0x4a, 0x3d, 0x48, 0x4b, 0x46, 0x4e,
+ 0x4c, 0x45, 0x48, 0x4d, 0x54, 0x4d, 0x3e, 0x46, 0x3e, 0x47, 0x44, 0x4e,
+ 0x48, 0x49, 0x53, 0x4b, 0x41, 0x45, 0x4c, 0x57, 0x52, 0x4e, 0x40, 0x48,
+ 0x4d, 0x43, 0x44, 0x5a, 0x4a, 0x4c, 0x48, 0x4d, 0x3f, 0x52, 0x41, 0x50,
+ 0x4a, 0x47, 0x3e, 0x43, 0x4c, 0x42, 0x48, 0x3e, 0x4f, 0x4b, 0x41, 0x43,
+ 0x49, 0x40, 0x43, 0x36, 0x3f, 0x4b, 0x49, 0x49, 0x51, 0x43, 0x48, 0x40,
+ 0x4c, 0x51, 0x4d, 0x4a, 0x49, 0x3f, 0x4b, 0x3d, 0x4f, 0x4b, 0x43, 0x4d,
+ 0x46, 0x40, 0x46, 0x4d, 0x49, 0x48, 0x4d, 0x4c, 0x52, 0x4c, 0x49, 0x4f,
+ 0x53, 0x40, 0x49, 0x53, 0x47, 0x43, 0x4c, 0x45, 0x42, 0x48, 0x42, 0x4e,
+ 0x49, 0x43, 0x42, 0x40, 0x4f, 0x46, 0x50, 0x47, 0x51, 0x4a, 0x52, 0x45,
+ 0x4c, 0x51, 0x48, 0x47, 0x40, 0x41, 0x52, 0x4f, 0x41, 0x5a, 0x53, 0x47,
+ 0x42, 0x5f, 0x55, 0x4f, 0x53, 0x3e, 0x41, 0x49, 0x3d, 0x20, 0x3f, 0x54,
+ 0x42, 0x5b, 0x49, 0x4d, 0x3d, 0x22, 0x3e, 0x48, 0x41, 0x01, 0x4c, 0x3d,
+ 0x43, 0x4a, 0x46, 0x43, 0x4f, 0x2b, 0x49, 0x46, 0x47, 0x4a, 0x51, 0x3d,
+ 0x4b, 0x44, 0x49, 0x41, 0x47, 0x47, 0x45, 0x3a, 0x44, 0x42, 0x40, 0x52,
+ 0x46, 0x51, 0x4a, 0x41, 0x4a, 0x52, 0x44, 0x52, 0x4a, 0x40, 0x46, 0x45,
+ 0x52, 0x4c, 0x4e, 0x42, 0x42, 0x48, 0x40, 0x4f, 0x4b, 0x4f, 0x51, 0x4c,
+ 0x4e, 0x48, 0x4a, 0x5a, 0x46, 0x3d, 0x41, 0x50, 0x52, 0x4c, 0x44, 0x53,
+ 0x4b, 0x4d, 0x4f, 0x49, 0x47, 0x4c, 0x48, 0x45, 0x48, 0x4a, 0x44, 0x4e,
+ 0x4c, 0x40, 0x4d, 0x35, 0x40, 0x49, 0x4a, 0x51, 0x49, 0x4a, 0x46, 0x36,
+ 0x46, 0x47, 0x4a, 0x4c, 0x40, 0x4e, 0x42, 0x38, 0x48, 0x45, 0x42, 0x49,
+ 0x54, 0x4c, 0x3f, 0x49, 0x4c, 0x39, 0x47, 0x45, 0x4e, 0x4a, 0x42, 0x44,
+ 0x4b, 0x53, 0x43, 0x40, 0x46, 0x51, 0x3d, 0x50, 0x4b, 0x43, 0x4a, 0x4c,
+ 0x55, 0x54, 0x4a, 0x43, 0x48, 0x40, 0x44, 0x3f, 0x47, 0x45, 0x3e, 0x41,
+ 0x49, 0x44, 0x4d, 0x49, 0x44, 0x41, 0x4a, 0x50, 0x44, 0x49, 0x4d, 0x47,
+ 0x4a, 0x49, 0x46, 0x49, 0x40, 0x5b, 0x4d, 0x51, 0x47, 0x57, 0x49, 0x4f,
+ 0x56, 0x46, 0x3a, 0x4a, 0x3e, 0x22, 0x36, 0x5c, 0x44, 0x56, 0x46, 0x48,
+ 0x3a, 0x2d, 0x4a, 0x48, 0x44, 0x17, 0x41, 0x42, 0x40, 0x3d, 0x4e, 0x45,
+ 0x40, 0x26, 0x43, 0x52, 0x41, 0x40, 0x44, 0x4a, 0x48, 0x42, 0x4f, 0x47,
+ 0x46, 0x4c, 0x4a, 0x3b, 0x42, 0x3e, 0x3e, 0x49, 0x4e, 0x44, 0x4e, 0x49,
+ 0x47, 0x41, 0x47, 0x44, 0x4c, 0x45, 0x4d, 0x49, 0x49, 0x48, 0x55, 0x3d,
+ 0x4a, 0x45, 0x50, 0x4f, 0x46, 0x4c, 0x46, 0x45, 0x3c, 0x51, 0x4b, 0x5a,
+ 0x46, 0x47, 0x54, 0x41, 0x44, 0x40, 0x4f, 0x53, 0x49, 0x46, 0x46, 0x48,
+ 0x44, 0x40, 0x50, 0x49, 0x49, 0x43, 0x50, 0x41, 0x52, 0x4b, 0x46, 0x3e,
+ 0x44, 0x44, 0x46, 0x4e, 0x47, 0x48, 0x3e, 0x38, 0x4c, 0x4c, 0x48, 0x43,
+ 0x48, 0x3e, 0x50, 0x42, 0x51, 0x50, 0x4a, 0x48, 0x4a, 0x42, 0x44, 0x3d,
+ 0x4a, 0x46, 0x46, 0x3d, 0x4e, 0x47, 0x3d, 0x48, 0x4c, 0x46, 0x50, 0x4d,
+ 0x49, 0x45, 0x4a, 0x4c, 0x4c, 0x47, 0x4a, 0x42, 0x4a, 0x45, 0x50, 0x52,
+ 0x4b, 0x4d, 0x4c, 0x43, 0x42, 0x53, 0x41, 0x45, 0x49, 0x41, 0x4b, 0x4c,
+ 0x52, 0x54, 0x4b, 0x41, 0x48, 0x4c, 0x47, 0x4c, 0x41, 0x49, 0x4a, 0x47,
+ 0x50, 0x59, 0x4e, 0x45, 0x3c, 0x5d, 0x53, 0x4c, 0x5a, 0x3e, 0x3a, 0x51,
+ 0x3a, 0x22, 0x35, 0x59, 0x40, 0x5a, 0x43, 0x46, 0x41, 0x32, 0x44, 0x4b,
+ 0x47, 0x04, 0x4c, 0x3a, 0x4a, 0x49, 0x48, 0x3d, 0x45, 0x2b, 0x50, 0x41,
+ 0x3e, 0x44, 0x4f, 0x43, 0x4a, 0x3f, 0x48, 0x4b, 0x53, 0x49, 0x4b, 0x38,
+ 0x44, 0x40, 0x48, 0x4c, 0x41, 0x3f, 0x47, 0x3e, 0x47, 0x49, 0x45, 0x42,
+ 0x43, 0x3e, 0x46, 0x44, 0x53, 0x4d, 0x48, 0x44, 0x45, 0x42, 0x43, 0x53,
+ 0x55, 0x49, 0x4d, 0x4b, 0x45, 0x44, 0x47, 0x5f, 0x48, 0x44, 0x4a, 0x48,
+ 0x45, 0x4d, 0x4f, 0x5e, 0x4e, 0x46, 0x49, 0x49, 0x4d, 0x49, 0x44, 0x48,
+ 0x4d, 0x41, 0x50, 0x48, 0x3d, 0x3f, 0x4d, 0x38, 0x46, 0x4a, 0x50, 0x4a,
+ 0x45, 0x3e, 0x43, 0x36, 0x42, 0x48, 0x53, 0x54, 0x49, 0x43, 0x4b, 0x3a,
+ 0x45, 0x48, 0x50, 0x45, 0x4a, 0x4c, 0x4a, 0x4d, 0x43, 0x4c, 0x55, 0x4e,
+ 0x4c, 0x42, 0x45, 0x52, 0x52, 0x45, 0x46, 0x40, 0x54, 0x4c, 0x3d, 0x4e,
+ 0x49, 0x4e, 0x44, 0x47, 0x45, 0x48, 0x4b, 0x50, 0x49, 0x4b, 0x44, 0x4b,
+ 0x4f, 0x49, 0x47, 0x47, 0x53, 0x3f, 0x4b, 0x42, 0x45, 0x3e, 0x4d, 0x4d,
+ 0x48, 0x51, 0x45, 0x40, 0x43, 0x43, 0x4e, 0x44, 0x51, 0x55, 0x4a, 0x3e,
+ 0x45, 0x55, 0x58, 0x50, 0x50, 0x38, 0x44, 0x4f, 0x3b, 0x23, 0x3c, 0x55,
+ 0x3c, 0x54, 0x49, 0x42, 0x44, 0x2f, 0x3e, 0x47, 0x42, 0x01, 0x42, 0x37,
+ 0x3f, 0x42, 0x45, 0x45, 0x47, 0x2a, 0x52, 0x4b, 0x45, 0x3c, 0x47, 0x44,
+ 0x44, 0x40, 0x50, 0x53, 0x48, 0x42, 0x4d, 0x36, 0x50, 0x3d, 0x49, 0x44,
+ 0x4f, 0x4c, 0x4a, 0x42, 0x4d, 0x3e, 0x3d, 0x3f, 0x4e, 0x44, 0x4d, 0x4e,
+ 0x54, 0x3d, 0x42, 0x46, 0x49, 0x47, 0x4b, 0x53, 0x45, 0x46, 0x47, 0x4a,
+ 0x45, 0x3d, 0x4a, 0x5f, 0x51, 0x3e, 0x45, 0x45, 0x44, 0x3a, 0x4d, 0x57,
+ 0x45, 0x47, 0x4d, 0x45, 0x4e, 0x4b, 0x51, 0x48, 0x4b, 0x4a, 0x3c, 0x4e,
+ 0x51, 0x41, 0x4d, 0x36, 0x47, 0x4a, 0x46, 0x51, 0x4e, 0x4c, 0x52, 0x41,
+ 0x55, 0x47, 0x41, 0x47, 0x4d, 0x47, 0x4b, 0x3d, 0x4a, 0x4a, 0x46, 0x49,
+ 0x4d, 0x48, 0x46, 0x46, 0x4d, 0x52, 0x52, 0x48, 0x49, 0x3f, 0x4b, 0x4e,
+ 0x4c, 0x49, 0x45, 0x47, 0x41, 0x4b, 0x44, 0x48, 0x52, 0x4b, 0x53, 0x44,
+ 0x46, 0x4e, 0x44, 0x49, 0x52, 0x50, 0x46, 0x4b, 0x44, 0x43, 0x50, 0x49,
+ 0x4a, 0x53, 0x45, 0x49, 0x52, 0x3f, 0x4a, 0x4e, 0x49, 0x4c, 0x4d, 0x4d,
+ 0x40, 0x40, 0x3f, 0x4a, 0x47, 0x56, 0x51, 0x43, 0x40, 0x5a, 0x58, 0x52,
+ 0x4f, 0x3d, 0x3d, 0x45, 0x38, 0x29, 0x33, 0x59, 0x45, 0x54, 0x3c, 0x42,
+ 0x3f, 0x27, 0x3e, 0x49, 0x48, 0x06, 0x4a, 0x3f, 0x41, 0x49, 0x4c, 0x48,
+ 0x46, 0x2b, 0x4a, 0x4f, 0x44, 0x46, 0x4c, 0x46, 0x4a, 0x3b, 0x4d, 0x4a,
+ 0x40, 0x41, 0x45, 0x38, 0x51, 0x39, 0x46, 0x46, 0x41, 0x51, 0x4e, 0x41,
+ 0x49, 0x44, 0x48, 0x4a, 0x4b, 0x46, 0x47, 0x46, 0x4a, 0x4c, 0x47, 0x48,
+ 0x3d, 0x42, 0x50, 0x4f, 0x50, 0x4a, 0x4a, 0x48, 0x4a, 0x45, 0x45, 0x61,
+ 0x4a, 0x4c, 0x49, 0x3d, 0x4b, 0x4a, 0x4a, 0x5a, 0x48, 0x49, 0x50, 0x4f,
+ 0x42, 0x48, 0x3e, 0x44, 0x43, 0x3b, 0x4f, 0x54, 0x4b, 0x4a, 0x47, 0x31,
+ 0x4a, 0x49, 0x47, 0x4e, 0x48, 0x48, 0x46, 0x42, 0x4a, 0x45, 0x4c, 0x49,
+ 0x4b, 0x4e, 0x53, 0x43, 0x4c, 0x49, 0x4f, 0x4b, 0x46, 0x4c, 0x4b, 0x4e,
+ 0x51, 0x4b, 0x49, 0x52, 0x44, 0x55, 0x45, 0x49, 0x4b, 0x4a, 0x50, 0x4c,
+ 0x4d, 0x4a, 0x4b, 0x48, 0x41, 0x46, 0x47, 0x43, 0x4b, 0x3f, 0x54, 0x4a,
+ 0x46, 0x49, 0x51, 0x48, 0x4e, 0x4a, 0x41, 0x52, 0x52, 0x4e, 0x53, 0x47,
+ 0x42, 0x48, 0x43, 0x44, 0x54, 0x51, 0x40, 0x49, 0x4c, 0x48, 0x49, 0x44,
+ 0x4c, 0x56, 0x52, 0x49, 0x3d, 0x59, 0x4f, 0x56, 0x56, 0x42, 0x46, 0x45,
+ 0x3e, 0x28, 0x3f, 0x5b, 0x3f, 0x5a, 0x4c, 0x42, 0x44, 0x22, 0x3f, 0x46,
+ 0x47, 0x0d, 0x3e, 0x41, 0x45, 0x49, 0x4a, 0x3b, 0x45, 0x2d, 0x4d, 0x4a,
+ 0x44, 0x43, 0x49, 0x46, 0x4b, 0x47, 0x49, 0x45, 0x4e, 0x40, 0x4c, 0x3c,
+ 0x42, 0x3e, 0x4b, 0x50, 0x48, 0x49, 0x4c, 0x42, 0x3c, 0x43, 0x50, 0x43,
+ 0x49, 0x4e, 0x4e, 0x43, 0x46, 0x4c, 0x48, 0x4a, 0x43, 0x4c, 0x49, 0x4e,
+ 0x47, 0x44, 0x50, 0x4c, 0x4a, 0x48, 0x47, 0x5f, 0x3f, 0x3e, 0x48, 0x4f,
+ 0x4f, 0x49, 0x4a, 0x5f, 0x4e, 0x40, 0x4e, 0x48, 0x47, 0x44, 0x40, 0x4d,
+ 0x3f, 0x4a, 0x53, 0x45, 0x3e, 0x50, 0x3f, 0x39, 0x50, 0x45, 0x45, 0x4b,
+ 0x43, 0x41, 0x46, 0x41, 0x49, 0x47, 0x4b, 0x41, 0x3c, 0x4b, 0x46, 0x3f,
+ 0x41, 0x4a, 0x4e, 0x4c, 0x49, 0x4c, 0x3f, 0x44, 0x53, 0x4c, 0x45, 0x49,
+ 0x48, 0x4d, 0x48, 0x4a, 0x48, 0x4f, 0x45, 0x4d, 0x48, 0x4c, 0x41, 0x49,
+ 0x42, 0x48, 0x53, 0x46, 0x4a, 0x46, 0x4b, 0x4f, 0x4c, 0x52, 0x4c, 0x51,
+ 0x41, 0x4d, 0x49, 0x41, 0x49, 0x4f, 0x49, 0x42, 0x4a, 0x48, 0x51, 0x4a,
+ 0x44, 0x4d, 0x55, 0x48, 0x47, 0x4d, 0x4d, 0x45, 0x42, 0x60, 0x4a, 0x51,
+ 0x42, 0x54, 0x56, 0x56, 0x50, 0x4a, 0x3f, 0x4a, 0x40, 0x25, 0x3a, 0x59,
+ 0x46, 0x58, 0x52, 0x46, 0x41, 0x28, 0x3d, 0x3e, 0x45, 0x13, 0x47, 0x41,
+ 0x3d, 0x44, 0x48, 0x45, 0x49, 0x26, 0x46, 0x4c, 0x3b, 0x4a, 0x42, 0x47,
+ 0x46, 0x41, 0x44, 0x52, 0x50, 0x4a, 0x4f, 0x40, 0x4b, 0x39, 0x42, 0x45,
+ 0x4a, 0x4d, 0x4f, 0x3f, 0x42, 0x4f, 0x49, 0x45, 0x42, 0x4a, 0x46, 0x47,
+ 0x48, 0x40, 0x4a, 0x46, 0x41, 0x3b, 0x48, 0x55, 0x4b, 0x4e, 0x4e, 0x48,
+ 0x4b, 0x44, 0x46, 0x53, 0x48, 0x45, 0x4b, 0x53, 0x49, 0x43, 0x4a, 0x5c,
+ 0x46, 0x45, 0x45, 0x49, 0x49, 0x49, 0x4c, 0x43, 0x4e, 0x4a, 0x41, 0x4a,
+ 0x42, 0x43, 0x4a, 0x38, 0x44, 0x4a, 0x4b, 0x3f, 0x45, 0x49, 0x45, 0x38,
+ 0x43, 0x40, 0x45, 0x4c, 0x47, 0x42, 0x3f, 0x42, 0x3e, 0x4a, 0x43, 0x50,
+ 0x4a, 0x4e, 0x4f, 0x47, 0x4d, 0x49, 0x49, 0x47, 0x4a, 0x4d, 0x46, 0x4c,
+ 0x4f, 0x3d, 0x52, 0x4a, 0x41, 0x44, 0x4b, 0x50, 0x4c, 0x52, 0x49, 0x50,
+ 0x4b, 0x45, 0x49, 0x4d, 0x48, 0x55, 0x50, 0x47, 0x4e, 0x50, 0x4f, 0x48,
+ 0x46, 0x4d, 0x4d, 0x41, 0x48, 0x51, 0x4b, 0x4c, 0x47, 0x51, 0x42, 0x42,
+ 0x4d, 0x47, 0x43, 0x4c, 0x4c, 0x5a, 0x4e, 0x47, 0x3b, 0x59, 0x51, 0x57,
+ 0x4c, 0x40, 0x46, 0x4c, 0x37, 0x2a, 0x35, 0x58, 0x44, 0x5b, 0x4c, 0x44,
+ 0x3e, 0x2e, 0x3f, 0x43, 0x46, 0x23, 0x49, 0x3e, 0x41, 0x3f, 0x4b, 0x3e,
+ 0x4e, 0x2f, 0x4d, 0x4a, 0x4e, 0x40, 0x4e, 0x41, 0x40, 0x3f, 0x4a, 0x42,
+ 0x4d, 0x4c, 0x44, 0x47, 0x4e, 0x44, 0x40, 0x43, 0x4d, 0x49, 0x4f, 0x3d,
+ 0x49, 0x3f, 0x51, 0x48, 0x42, 0x4a, 0x49, 0x47, 0x49, 0x46, 0x4a, 0x45,
+ 0x45, 0x49, 0x53, 0x4d, 0x4c, 0x4e, 0x44, 0x50, 0x4b, 0x43, 0x4e, 0x5f,
+ 0x3c, 0x40, 0x44, 0x46, 0x48, 0x4b, 0x42, 0x62, 0x4e, 0x50, 0x4c, 0x49,
+ 0x4a, 0x4f, 0x44, 0x53, 0x42, 0x43, 0x49, 0x48, 0x4b, 0x3c, 0x4a, 0x37,
+ 0x4c, 0x41, 0x49, 0x46, 0x46, 0x47, 0x43, 0x40, 0x4d, 0x4d, 0x4a, 0x48,
+ 0x50, 0x4b, 0x50, 0x41, 0x44, 0x3e, 0x51, 0x47, 0x44, 0x4a, 0x44, 0x45,
+ 0x48, 0x4d, 0x52, 0x4e, 0x44, 0x48, 0x4d, 0x43, 0x42, 0x45, 0x48, 0x52,
+ 0x44, 0x42, 0x50, 0x42, 0x4d, 0x45, 0x48, 0x4d, 0x4f, 0x4e, 0x45, 0x49,
+ 0x51, 0x48, 0x4f, 0x53, 0x4d, 0x4c, 0x48, 0x50, 0x4e, 0x4d, 0x50, 0x48,
+ 0x49, 0x42, 0x4c, 0x42, 0x4b, 0x4b, 0x49, 0x48, 0x48, 0x49, 0x4a, 0x54,
+ 0x44, 0x57, 0x4d, 0x4b, 0x3f, 0x56, 0x53, 0x5c, 0x50, 0x4e, 0x46, 0x49,
+ 0x40, 0x24, 0x44, 0x58, 0x49, 0x54, 0x48, 0x49, 0x41, 0x22, 0x44, 0x3f,
+ 0x48, 0x1c, 0x4d, 0x39, 0x3e, 0x4c, 0x3d, 0x4a, 0x48, 0x2d, 0x48, 0x3e,
+ 0x3f, 0x3a, 0x46, 0x4e, 0x44, 0x43, 0x49, 0x51, 0x4d, 0x3c, 0x44, 0x41,
+ 0x4e, 0x44, 0x42, 0x4c, 0x45, 0x48, 0x45, 0x46, 0x42, 0x46, 0x47, 0x42,
+ 0x4f, 0x45, 0x47, 0x44, 0x48, 0x47, 0x4a, 0x42, 0x4d, 0x48, 0x3e, 0x53,
+ 0x47, 0x4b, 0x44, 0x4b, 0x45, 0x4a, 0x50, 0x55, 0x4c, 0x45, 0x48, 0x43,
+ 0x53, 0x3d, 0x4e, 0x5f, 0x42, 0x44, 0x4a, 0x4f, 0x3f, 0x48, 0x4e, 0x4b,
+ 0x43, 0x48, 0x43, 0x41, 0x4a, 0x4b, 0x51, 0x39, 0x52, 0x46, 0x44, 0x49,
+ 0x48, 0x45, 0x4c, 0x40, 0x45, 0x49, 0x51, 0x48, 0x45, 0x42, 0x45, 0x48,
+ 0x40, 0x43, 0x3d, 0x47, 0x53, 0x54, 0x4d, 0x4a, 0x4a, 0x47, 0x48, 0x43,
+ 0x4c, 0x46, 0x43, 0x4f, 0x49, 0x4c, 0x3f, 0x3d, 0x4b, 0x41, 0x40, 0x48,
+ 0x4e, 0x4c, 0x4b, 0x40, 0x4c, 0x43, 0x49, 0x4d, 0x47, 0x4f, 0x47, 0x42,
+ 0x47, 0x4a, 0x4d, 0x4f, 0x46, 0x4d, 0x51, 0x49, 0x48, 0x4d, 0x4e, 0x46,
+ 0x47, 0x41, 0x44, 0x4d, 0x4b, 0x55, 0x4b, 0x4c, 0x41, 0x5e, 0x50, 0x45,
+ 0x40, 0x55, 0x4b, 0x60, 0x55, 0x47, 0x3d, 0x4a, 0x42, 0x22, 0x46, 0x5a,
+ 0x47, 0x53, 0x49, 0x44, 0x44, 0x27, 0x41, 0x4f, 0x3e, 0x22, 0x4a, 0x44,
+ 0x49, 0x3e, 0x4e, 0x4d, 0x3f, 0x3a, 0x4c, 0x44, 0x4a, 0x44, 0x46, 0x51,
+ 0x4f, 0x42, 0x4c, 0x4e, 0x39, 0x4b, 0x42, 0x39, 0x4b, 0x3e, 0x4f, 0x47,
+ 0x4a, 0x4f, 0x3f, 0x4d, 0x43, 0x4c, 0x4a, 0x4b, 0x4b, 0x3d, 0x51, 0x46,
+ 0x49, 0x4c, 0x47, 0x44, 0x43, 0x3d, 0x3c, 0x54, 0x4a, 0x47, 0x4d, 0x50,
+ 0x4a, 0x46, 0x51, 0x62, 0x46, 0x4d, 0x4b, 0x46, 0x49, 0x3c, 0x50, 0x57,
+ 0x47, 0x40, 0x3e, 0x4c, 0x4b, 0x3f, 0x55, 0x46, 0x3d, 0x45, 0x42, 0x4e,
+ 0x50, 0x49, 0x46, 0x3a, 0x4c, 0x47, 0x4a, 0x49, 0x42, 0x42, 0x4a, 0x44,
+ 0x42, 0x40, 0x49, 0x54, 0x46, 0x4b, 0x47, 0x45, 0x51, 0x47, 0x41, 0x42,
+ 0x49, 0x50, 0x4e, 0x48, 0x4b, 0x4b, 0x47, 0x4a, 0x47, 0x49, 0x4b, 0x45,
+ 0x4b, 0x54, 0x48, 0x54, 0x4b, 0x49, 0x51, 0x4a, 0x4a, 0x40, 0x46, 0x42,
+ 0x44, 0x44, 0x4d, 0x4b, 0x47, 0x43, 0x45, 0x41, 0x3e, 0x49, 0x43, 0x51,
+ 0x3e, 0x4b, 0x52, 0x46, 0x48, 0x3f, 0x4e, 0x51, 0x51, 0x49, 0x3f, 0x48,
+ 0x4c, 0x4c, 0x52, 0x47, 0x43, 0x57, 0x44, 0x42, 0x40, 0x52, 0x50, 0x5d,
+ 0x4f, 0x40, 0x42, 0x45, 0x46, 0x26, 0x3c, 0x51, 0x4b, 0x4e, 0x4b, 0x49,
+ 0x46, 0x35, 0x49, 0x53, 0x49, 0x2b, 0x4d, 0x3e, 0x50, 0x44, 0x4f, 0x54,
+ 0x46, 0x34, 0x49, 0x4d, 0x42, 0x45, 0x44, 0x4b, 0x52, 0x44, 0x52, 0x41,
+ 0x4d, 0x4c, 0x52, 0x41, 0x49, 0x3a, 0x4e, 0x49, 0x40, 0x4b, 0x45, 0x4d,
+ 0x4b, 0x4a, 0x47, 0x49, 0x45, 0x49, 0x4d, 0x50, 0x3e, 0x47, 0x44, 0x51,
+ 0x4c, 0x41, 0x45, 0x50, 0x47, 0x41, 0x4a, 0x52, 0x4b, 0x3d, 0x4b, 0x5b,
+ 0x4c, 0x4c, 0x4d, 0x3f, 0x47, 0x44, 0x49, 0x5d, 0x4a, 0x53, 0x44, 0x45,
+ 0x45, 0x46, 0x3d, 0x4f, 0x50, 0x3b, 0x44, 0x4e, 0x40, 0x41, 0x4c, 0x3a,
+ 0x4a, 0x45, 0x49, 0x48, 0x45, 0x4a, 0x45, 0x36, 0x45, 0x4d, 0x4c, 0x49,
+ 0x3f, 0x47, 0x4d, 0x40, 0x53, 0x48, 0x49, 0x4c, 0x47, 0x4f, 0x42, 0x44,
+ 0x45, 0x40, 0x4a, 0x4c, 0x49, 0x4f, 0x4b, 0x4d, 0x42, 0x45, 0x3e, 0x4a,
+ 0x48, 0x4a, 0x49, 0x50, 0x4c, 0x53, 0x50, 0x45, 0x4b, 0x4c, 0x46, 0x4f,
+ 0x44, 0x43, 0x54, 0x50, 0x3f, 0x48, 0x42, 0x4b, 0x43, 0x3f, 0x4d, 0x4c,
+ 0x43, 0x49, 0x4a, 0x47, 0x54, 0x4b, 0x4f, 0x4d, 0x44, 0x47, 0x49, 0x4e,
+ 0x4e, 0x55, 0x40, 0x46, 0x44, 0x56, 0x4e, 0x65, 0x4f, 0x3f, 0x43, 0x48,
+ 0x39, 0x27, 0x43, 0x55, 0x4b, 0x4c, 0x44, 0x46, 0x42, 0x34, 0x44, 0x52,
+ 0x43, 0x22, 0x4e, 0x41, 0x49, 0x48, 0x49, 0x51, 0x3b, 0x37, 0x4b, 0x40,
+ 0x4f, 0x45, 0x53, 0x4c, 0x47, 0x46, 0x47, 0x4c, 0x3e, 0x44, 0x45, 0x49,
+ 0x48, 0x50, 0x45, 0x40, 0x46, 0x4c, 0x47, 0x4d, 0x44, 0x48, 0x49, 0x50,
+ 0x4f, 0x4a, 0x46, 0x55, 0x4e, 0x42, 0x4c, 0x4c, 0x50, 0x48, 0x3d, 0x55,
+ 0x46, 0x3e, 0x4a, 0x4b, 0x4f, 0x46, 0x46, 0x60, 0x50, 0x3f, 0x55, 0x40,
+ 0x42, 0x44, 0x48, 0x63, 0x50, 0x3d, 0x45, 0x4f, 0x4e, 0x41, 0x47, 0x48,
+ 0x4a, 0x3c, 0x3d, 0x46, 0x3f, 0x42, 0x43, 0x37, 0x4f, 0x4f, 0x50, 0x47,
+ 0x47, 0x4b, 0x52, 0x40, 0x3f, 0x44, 0x4a, 0x40, 0x4d, 0x44, 0x4e, 0x37,
+ 0x43, 0x48, 0x47, 0x3f, 0x51, 0x4d, 0x45, 0x42, 0x41, 0x46, 0x3d, 0x53,
+ 0x4f, 0x4b, 0x54, 0x45, 0x51, 0x40, 0x4a, 0x4a, 0x48, 0x4f, 0x43, 0x4a,
+ 0x4f, 0x4c, 0x4c, 0x4f, 0x48, 0x4c, 0x44, 0x4e, 0x43, 0x46, 0x4f, 0x4a,
+ 0x43, 0x41, 0x49, 0x49, 0x47, 0x53, 0x45, 0x49, 0x4e, 0x46, 0x4c, 0x4e,
+ 0x3c, 0x49, 0x44, 0x45, 0x4c, 0x42, 0x49, 0x41, 0x48, 0x58, 0x54, 0x4d,
+ 0x35, 0x52, 0x4e, 0x5b, 0x4f, 0x40, 0x3e, 0x46, 0x46, 0x36, 0x3d, 0x60,
+ 0x4d, 0x49, 0x4a, 0x43, 0x44, 0x36, 0x49, 0x67, 0x4a, 0x2d, 0x4b, 0x40,
+ 0x3f, 0x49, 0x43, 0x5f, 0x45, 0x3c, 0x49, 0x4c, 0x4a, 0x43, 0x48, 0x55,
+ 0x49, 0x46, 0x49, 0x46, 0x44, 0x4e, 0x42, 0x4e, 0x40, 0x45, 0x42, 0x52,
+ 0x4a, 0x40, 0x4a, 0x44, 0x40, 0x45, 0x54, 0x3d, 0x4c, 0x3e, 0x4c, 0x55,
+ 0x4d, 0x45, 0x4d, 0x51, 0x4a, 0x4b, 0x44, 0x5b, 0x48, 0x3d, 0x3e, 0x46,
+ 0x4f, 0x4d, 0x3f, 0x62, 0x4d, 0x45, 0x3f, 0x47, 0x47, 0x47, 0x44, 0x5b,
+ 0x4b, 0x4f, 0x51, 0x4c, 0x4a, 0x47, 0x48, 0x5b, 0x47, 0x40, 0x4a, 0x47,
+ 0x42, 0x44, 0x46, 0x46, 0x45, 0x48, 0x4a, 0x3f, 0x40, 0x4f, 0x48, 0x3a,
+ 0x49, 0x52, 0x4a, 0x53, 0x43, 0x4c, 0x4b, 0x4a, 0x4a, 0x4a, 0x4e, 0x42,
+ 0x4b, 0x46, 0x3d, 0x50, 0x51, 0x4b, 0x4b, 0x4f, 0x50, 0x4c, 0x4f, 0x4c,
+ 0x4d, 0x41, 0x41, 0x3c, 0x40, 0x43, 0x54, 0x51, 0x48, 0x3d, 0x48, 0x51,
+ 0x42, 0x42, 0x4c, 0x4e, 0x4d, 0x4b, 0x49, 0x43, 0x48, 0x47, 0x4b, 0x49,
+ 0x49, 0x4e, 0x4d, 0x46, 0x4c, 0x52, 0x49, 0x49, 0x51, 0x4e, 0x45, 0x47,
+ 0x44, 0x47, 0x42, 0x4a, 0x46, 0x59, 0x48, 0x48, 0x4b, 0x4f, 0x4c, 0x5e,
+ 0x5c, 0x45, 0x3f, 0x48, 0x3d, 0x3f, 0x37, 0x5a, 0x4b, 0x4b, 0x45, 0x49,
+ 0x3e, 0x42, 0x41, 0x6b, 0x49, 0x2d, 0x45, 0x43, 0x47, 0x45, 0x49, 0x61,
+ 0x3d, 0x3b, 0x49, 0x43, 0x49, 0x4b, 0x4b, 0x55, 0x4b, 0x47, 0x46, 0x46,
+ 0x48, 0x4d, 0x49, 0x4f, 0x4a, 0x4c, 0x42, 0x51, 0x41, 0x44, 0x45, 0x4f,
+ 0x4e, 0x44, 0x3f, 0x55, 0x3e, 0x4a, 0x45, 0x50, 0x46, 0x42, 0x41, 0x49,
+ 0x49, 0x47, 0x49, 0x61, 0x47, 0x40, 0x41, 0x4e, 0x4d, 0x4b, 0x4a, 0x5e,
+ 0x52, 0x49, 0x4b, 0x52, 0x51, 0x55, 0x42, 0x61, 0x53, 0x4c, 0x48, 0x4a,
+ 0x4e, 0x48, 0x48, 0x57, 0x4c, 0x40, 0x40, 0x48, 0x45, 0x43, 0x3e, 0x46,
+ 0x43, 0x4a, 0x45, 0x45, 0x44, 0x4f, 0x44, 0x40, 0x49, 0x48, 0x4e, 0x49,
+ 0x4a, 0x4e, 0x49, 0x51, 0x46, 0x4f, 0x47, 0x44, 0x42, 0x4d, 0x43, 0x4e,
+ 0x4f, 0x4d, 0x44, 0x51, 0x47, 0x49, 0x40, 0x57, 0x4b, 0x49, 0x47, 0x4c,
+ 0x4d, 0x4d, 0x3e, 0x47, 0x45, 0x41, 0x50, 0x4b, 0x4b, 0x45, 0x42, 0x4e,
+ 0x48, 0x47, 0x4e, 0x4b, 0x56, 0x4c, 0x4f, 0x52, 0x51, 0x49, 0x4d, 0x4a,
+ 0x4b, 0x52, 0x4d, 0x55, 0x4b, 0x4e, 0x4e, 0x4b, 0x51, 0x57, 0x47, 0x42,
+ 0x49, 0x48, 0x56, 0x44, 0x52, 0x56, 0x53, 0x5a, 0x63, 0x53, 0x4c, 0x4c,
+ 0x43, 0x56, 0x3c, 0x57, 0x47, 0x47, 0x4d, 0x52, 0x43, 0x48, 0x45, 0x5f,
+ 0x45, 0x29, 0x47, 0x45, 0x48, 0x40, 0x41, 0x4b, 0x3f, 0x39, 0x49, 0x4e,
+ 0x47, 0x55, 0x42, 0x56, 0x4d, 0x43, 0x48, 0x44, 0x45, 0x53, 0x43, 0x46,
+ 0x49, 0x43, 0x49, 0x4a, 0x40, 0x4e, 0x4a, 0x4a, 0x47, 0x43, 0x45, 0x4d,
+ 0x4a, 0x47, 0x3f, 0x53, 0x45, 0x43, 0x4b, 0x4c, 0x42, 0x47, 0x47, 0x5f,
+ 0x48, 0x48, 0x46, 0x44, 0x50, 0x47, 0x41, 0x64, 0x4e, 0x46, 0x49, 0x4a,
+ 0x4d, 0x55, 0x42, 0x55, 0x46, 0x3d, 0x49, 0x43, 0x52, 0x52, 0x47, 0x52,
+ 0x4e, 0x46, 0x47, 0x41, 0x49, 0x4d, 0x50, 0x47, 0x42, 0x49, 0x41, 0x42,
+ 0x4b, 0x48, 0x49, 0x42, 0x4d, 0x48, 0x51, 0x54, 0x43, 0x56, 0x4c, 0x52,
+ 0x53, 0x4d, 0x54, 0x4a, 0x51, 0x50, 0x48, 0x4c, 0x4e, 0x48, 0x4c, 0x4c,
+ 0x52, 0x49, 0x4a, 0x4e, 0x4e, 0x41, 0x4f, 0x53, 0x49, 0x52, 0x42, 0x4b,
+ 0x50, 0x46, 0x50, 0x4a, 0x53, 0x56, 0x46, 0x4f, 0x4b, 0x49, 0x3d, 0x41,
+ 0x4c, 0x52, 0x42, 0x50, 0x4d, 0x45, 0x4e, 0x51, 0x4b, 0x4c, 0x46, 0x42,
+ 0x41, 0x4b, 0x40, 0x4a, 0x42, 0x57, 0x4f, 0x43, 0x40, 0x50, 0x4c, 0x51,
+ 0x4f, 0x48, 0x3a, 0x4e, 0x51, 0x40, 0x49, 0x66, 0x4b, 0x42, 0x48, 0x3c,
+ 0x5b, 0x47, 0x53, 0x40, 0x4a, 0x48, 0x35, 0x44, 0x5f, 0x50, 0x4a, 0x3c,
+ 0x41, 0x45, 0x48, 0x3b, 0x42, 0x59, 0x43, 0x4b, 0x48, 0x49, 0x4a, 0x40,
+ 0x4f, 0x5c, 0x50, 0x54, 0x53, 0x55, 0x4c, 0x4a, 0x43, 0x46, 0x49, 0x47,
+ 0x49, 0x48, 0x4b, 0x43, 0x42, 0x44, 0x42, 0x46, 0x44, 0x3f, 0x4b, 0x42,
+ 0x4d, 0x49, 0x41, 0x46, 0x47, 0x51, 0x51, 0x44, 0x4c, 0x54, 0x4e, 0x4b,
+ 0x42, 0x52, 0x4e, 0x4c, 0x4b, 0x4a, 0x50, 0x4e, 0x44, 0x4b, 0x4e, 0x4e,
+ 0x4f, 0x42, 0x4b, 0x48, 0x46, 0x43, 0x48, 0x54, 0x4b, 0x4e, 0x48, 0x4f,
+ 0x4a, 0x4d, 0x43, 0x4e, 0x47, 0x50, 0x4a, 0x44, 0x47, 0x52, 0x46, 0x53,
+ 0x4a, 0x40, 0x46, 0x54, 0x50, 0x4a, 0x47, 0x51, 0x49, 0x45, 0x4b, 0x4e,
+ 0x4b, 0x46, 0x4c, 0x4c, 0x52, 0x47, 0x45, 0x45, 0x4a, 0x47, 0x4c, 0x52,
+ 0x44, 0x51, 0x47, 0x42, 0x47, 0x43, 0x43, 0x49, 0x52, 0x5a, 0x55, 0x3e,
+ 0x45, 0x4b, 0x4c, 0x46, 0x4f, 0x4b, 0x45, 0x49, 0x4a, 0x4e, 0x4a, 0x50,
+ 0x3e, 0x4e, 0x42, 0x4e, 0x44, 0x55, 0x3d, 0x4a, 0x4d, 0x49, 0x4d, 0x42,
+ 0x49, 0x4e, 0x50, 0x44, 0x4b, 0x3c, 0x41, 0x49, 0x51, 0x49, 0x3c, 0x4e,
+ 0x4c, 0x39, 0x4c, 0x72, 0x44, 0x4b, 0x49, 0x42, 0x5f, 0x48, 0x4a, 0x48,
+ 0x41, 0x4c, 0x43, 0x40, 0x62, 0x5e, 0x47, 0x3c, 0x4a, 0x4c, 0x55, 0x49,
+ 0x4b, 0x52, 0x4e, 0x4b, 0x4d, 0x48, 0x4c, 0x3c, 0x3f, 0x4f, 0x4e, 0x48,
+ 0x45, 0x55, 0x4a, 0x46, 0x48, 0x3d, 0x45, 0x44, 0x4b, 0x4a, 0x46, 0x3a,
+ 0x4e, 0x44, 0x4d, 0x49, 0x49, 0x49, 0x40, 0x3e, 0x40, 0x47, 0x48, 0x43,
+ 0x3f, 0x51, 0x46, 0x4c, 0x45, 0x4c, 0x49, 0x44, 0x3e, 0x57, 0x49, 0x4e,
+ 0x48, 0x3f, 0x48, 0x47, 0x53, 0x4d, 0x50, 0x51, 0x49, 0x42, 0x45, 0x44,
+ 0x49, 0x49, 0x46, 0x4b, 0x45, 0x49, 0x4f, 0x49, 0x46, 0x48, 0x4c, 0x55,
+ 0x46, 0x51, 0x48, 0x4a, 0x48, 0x54, 0x4b, 0x5a, 0x4c, 0x47, 0x40, 0x47,
+ 0x40, 0x55, 0x50, 0x52, 0x4a, 0x4b, 0x4f, 0x49, 0x4b, 0x50, 0x4b, 0x5b,
+ 0x51, 0x53, 0x4f, 0x4e, 0x49, 0x48, 0x44, 0x52, 0x46, 0x4e, 0x47, 0x48,
+ 0x44, 0x43, 0x49, 0x55, 0x48, 0x58, 0x4f, 0x46, 0x45, 0x53, 0x45, 0x4a,
+ 0x4c, 0x4c, 0x49, 0x46, 0x47, 0x4d, 0x41, 0x4d, 0x4f, 0x59, 0x4a, 0x49,
+ 0x46, 0x4e, 0x44, 0x49, 0x4d, 0x48, 0x54, 0x47, 0x48, 0x4e, 0x48, 0x43,
+ 0x46, 0x41, 0x46, 0x44, 0x52, 0x46, 0x42, 0x4c, 0x4c, 0x31, 0x4d, 0x6f,
+ 0x51, 0x4f, 0x4d, 0x43, 0x5c, 0x48, 0x49, 0x49, 0x46, 0x4c, 0x43, 0x3b,
+ 0x5d, 0x63, 0x58, 0x46, 0x49, 0x45, 0x4e, 0x48, 0x49, 0x5d, 0x45, 0x50,
+ 0x56, 0x4d, 0x57, 0x37, 0x40, 0x55, 0x43, 0x4b, 0x4e, 0x46, 0x4c, 0x3b,
+ 0x3d, 0x4b, 0x49, 0x4b, 0x52, 0x47, 0x4d, 0x34, 0x4c, 0x4c, 0x47, 0x4e,
+ 0x4d, 0x4c, 0x3d, 0x3f, 0x4a, 0x49, 0x44, 0x45, 0x4a, 0x54, 0x43, 0x44,
+ 0x50, 0x4b, 0x4d, 0x4c, 0x4e, 0x48, 0x46, 0x51, 0x43, 0x48, 0x48, 0x48,
+ 0x42, 0x44, 0x4e, 0x48, 0x47, 0x45, 0x48, 0x51, 0x53, 0x4a, 0x4f, 0x58,
+ 0x42, 0x4d, 0x48, 0x4f, 0x4c, 0x45, 0x4a, 0x57, 0x4b, 0x43, 0x4d, 0x4b,
+ 0x4a, 0x4e, 0x4c, 0x5f, 0x3f, 0x4f, 0x4a, 0x42, 0x4b, 0x48, 0x4d, 0x62,
+ 0x4f, 0x4b, 0x50, 0x4c, 0x45, 0x49, 0x44, 0x53, 0x4a, 0x4f, 0x45, 0x56,
+ 0x4b, 0x44, 0x41, 0x53, 0x49, 0x48, 0x4d, 0x49, 0x47, 0x4b, 0x46, 0x4c,
+ 0x49, 0x4b, 0x4c, 0x54, 0x4f, 0x4b, 0x47, 0x49, 0x44, 0x4a, 0x4e, 0x53,
+ 0x4f, 0x49, 0x54, 0x4e, 0x4a, 0x48, 0x42, 0x54, 0x51, 0x46, 0x4b, 0x52,
+ 0x45, 0x48, 0x51, 0x4a, 0x40, 0x4a, 0x50, 0x45, 0x4a, 0x46, 0x49, 0x46,
+ 0x54, 0x46, 0x42, 0x48, 0x50, 0x36, 0x4a, 0x6b, 0x46, 0x59, 0x51, 0x47,
+ 0x5f, 0x4d, 0x43, 0x4d, 0x44, 0x4d, 0x42, 0x3b, 0x65, 0x6a, 0x56, 0x48,
+ 0x4d, 0x4c, 0x52, 0x4a, 0x4d, 0x61, 0x52, 0x4b, 0x47, 0x4f, 0x48, 0x49,
+ 0x3f, 0x5b, 0x45, 0x51, 0x48, 0x48, 0x4b, 0x3c, 0x3b, 0x4c, 0x54, 0x52,
+ 0x4f, 0x51, 0x53, 0x31, 0x47, 0x4c, 0x45, 0x4a, 0x42, 0x4b, 0x47, 0x40,
+ 0x41, 0x49, 0x4c, 0x46, 0x4b, 0x53, 0x46, 0x49, 0x44, 0x4b, 0x4e, 0x4b,
+ 0x48, 0x51, 0x49, 0x4d, 0x4b, 0x3f, 0x42, 0x44, 0x45, 0x43, 0x46, 0x56,
+ 0x42, 0x4b, 0x49, 0x4e, 0x4e, 0x53, 0x42, 0x5c, 0x4b, 0x46, 0x49, 0x46,
+ 0x4e, 0x41, 0x42, 0x67, 0x41, 0x49, 0x4d, 0x48, 0x49, 0x4e, 0x3f, 0x61,
+ 0x48, 0x4a, 0x40, 0x42, 0x4c, 0x51, 0x50, 0x63, 0x49, 0x44, 0x49, 0x47,
+ 0x45, 0x4d, 0x49, 0x61, 0x3f, 0x48, 0x40, 0x41, 0x49, 0x49, 0x45, 0x57,
+ 0x45, 0x46, 0x4d, 0x46, 0x4c, 0x4a, 0x4d, 0x4b, 0x43, 0x54, 0x4b, 0x49,
+ 0x4c, 0x49, 0x41, 0x49, 0x4b, 0x47, 0x45, 0x4b, 0x44, 0x43, 0x46, 0x3f,
+ 0x47, 0x47, 0x43, 0x4c, 0x49, 0x4c, 0x3d, 0x4d, 0x4b, 0x54, 0x4a, 0x4f,
+ 0x44, 0x4c, 0x4b, 0x47, 0x4c, 0x45, 0x3d, 0x52, 0x58, 0x4b, 0x45, 0x4e,
+ 0x48, 0x39, 0x53, 0x70, 0x4a, 0x5d, 0x4c, 0x4e, 0x5a, 0x4f, 0x46, 0x4b,
+ 0x3e, 0x4f, 0x44, 0x3d, 0x66, 0x6b, 0x50, 0x4d, 0x4d, 0x57, 0x52, 0x4a,
+ 0x4c, 0x5b, 0x4e, 0x53, 0x4d, 0x54, 0x50, 0x42, 0x3c, 0x5d, 0x4a, 0x4c,
+ 0x56, 0x52, 0x50, 0x40, 0x48, 0x4c, 0x4d, 0x49, 0x49, 0x4f, 0x51, 0x38,
+ 0x42, 0x49, 0x4d, 0x4f, 0x45, 0x40, 0x4d, 0x41, 0x4b, 0x4a, 0x47, 0x51,
+ 0x4b, 0x53, 0x4c, 0x4a, 0x51, 0x4c, 0x42, 0x56, 0x48, 0x4a, 0x47, 0x58,
+ 0x49, 0x46, 0x52, 0x4a, 0x45, 0x47, 0x51, 0x54, 0x4f, 0x50, 0x50, 0x53,
+ 0x49, 0x4a, 0x4d, 0x56, 0x56, 0x4b, 0x4d, 0x45, 0x40, 0x4d, 0x48, 0x60,
+ 0x4e, 0x56, 0x48, 0x4b, 0x47, 0x45, 0x47, 0x62, 0x4e, 0x4f, 0x41, 0x49,
+ 0x48, 0x57, 0x44, 0x64, 0x4f, 0x4f, 0x49, 0x44, 0x49, 0x4c, 0x3f, 0x53,
+ 0x40, 0x41, 0x4e, 0x4b, 0x4d, 0x54, 0x42, 0x53, 0x4e, 0x41, 0x49, 0x44,
+ 0x41, 0x45, 0x4d, 0x4f, 0x47, 0x51, 0x45, 0x4a, 0x42, 0x45, 0x4e, 0x40,
+ 0x4b, 0x52, 0x48, 0x47, 0x4e, 0x4f, 0x47, 0x41, 0x48, 0x53, 0x47, 0x47,
+ 0x46, 0x42, 0x48, 0x4b, 0x42, 0x4c, 0x49, 0x4c, 0x45, 0x4c, 0x54, 0x45,
+ 0x4c, 0x43, 0x4e, 0x49, 0x56, 0x47, 0x45, 0x4f, 0x4d, 0x3a, 0x58, 0x74,
+ 0x49, 0x5b, 0x4c, 0x4f, 0x64, 0x4e, 0x45, 0x43, 0x44, 0x5b, 0x43, 0x41,
+ 0x63, 0x70, 0x55, 0x45, 0x4a, 0x4a, 0x4d, 0x51, 0x4b, 0x5a, 0x51, 0x57,
+ 0x54, 0x5b, 0x55, 0x44, 0x38, 0x57, 0x4e, 0x50, 0x4e, 0x56, 0x57, 0x3a,
+ 0x3a, 0x4b, 0x57, 0x4c, 0x51, 0x53, 0x4d, 0x3b, 0x44, 0x43, 0x47, 0x4c,
+ 0x48, 0x59, 0x51, 0x41, 0x43, 0x44, 0x51, 0x51, 0x4a, 0x54, 0x51, 0x4b,
+ 0x4e, 0x45, 0x51, 0x4a, 0x49, 0x4a, 0x4f, 0x52, 0x4c, 0x3e, 0x4e, 0x55,
+ 0x42, 0x46, 0x46, 0x4a, 0x42, 0x52, 0x49, 0x47, 0x4a, 0x56, 0x4f, 0x50,
+ 0x46, 0x4f, 0x43, 0x51, 0x53, 0x46, 0x40, 0x60, 0x44, 0x4d, 0x46, 0x54,
+ 0x3d, 0x49, 0x43, 0x64, 0x45, 0x4d, 0x50, 0x49, 0x4f, 0x4d, 0x53, 0x60,
+ 0x4a, 0x52, 0x49, 0x47, 0x48, 0x5a, 0x48, 0x58, 0x4e, 0x4f, 0x43, 0x4f,
+ 0x50, 0x51, 0x41, 0x52, 0x4c, 0x4d, 0x45, 0x42, 0x41, 0x4c, 0x44, 0x54,
+ 0x4e, 0x4d, 0x4a, 0x47, 0x40, 0x4a, 0x3e, 0x47, 0x4c, 0x58, 0x46, 0x46,
+ 0x55, 0x4c, 0x4d, 0x45, 0x49, 0x51, 0x53, 0x46, 0x46, 0x43, 0x43, 0x48,
+ 0x52, 0x3d, 0x4b, 0x4e, 0x49, 0x47, 0x3f, 0x3d, 0x4f, 0x45, 0x44, 0x3f,
+ 0x5a, 0x43, 0x4b, 0x4d, 0x51, 0x35, 0x54, 0x76, 0x4f, 0x5e, 0x4c, 0x50,
+ 0x5a, 0x51, 0x46, 0x49, 0x44, 0x61, 0x4f, 0x41, 0x67, 0x72, 0x56, 0x4f,
+ 0x42, 0x48, 0x4b, 0x52, 0x46, 0x60, 0x50, 0x4e, 0x4a, 0x5b, 0x5f, 0x46,
+ 0x31, 0x5b, 0x4a, 0x48, 0x4b, 0x58, 0x51, 0x41, 0x37, 0x4e, 0x4f, 0x55,
+ 0x51, 0x5c, 0x4f, 0x42, 0x4b, 0x4e, 0x4f, 0x54, 0x4f, 0x52, 0x43, 0x43,
+ 0x48, 0x53, 0x53, 0x41, 0x4b, 0x49, 0x4e, 0x50, 0x46, 0x4c, 0x4f, 0x49,
+ 0x42, 0x49, 0x4c, 0x4c, 0x4c, 0x41, 0x4e, 0x48, 0x47, 0x4c, 0x49, 0x53,
+ 0x44, 0x46, 0x51, 0x53, 0x45, 0x52, 0x4e, 0x53, 0x50, 0x58, 0x42, 0x45,
+ 0x44, 0x42, 0x48, 0x58, 0x4e, 0x4d, 0x54, 0x56, 0x4c, 0x46, 0x4a, 0x58,
+ 0x48, 0x4f, 0x47, 0x51, 0x47, 0x4f, 0x4f, 0x5b, 0x41, 0x4e, 0x45, 0x45,
+ 0x4a, 0x50, 0x3e, 0x57, 0x48, 0x4e, 0x41, 0x4c, 0x45, 0x51, 0x46, 0x4c,
+ 0x46, 0x4f, 0x42, 0x45, 0x4b, 0x4c, 0x49, 0x4c, 0x44, 0x4f, 0x4e, 0x4d,
+ 0x48, 0x56, 0x43, 0x48, 0x42, 0x54, 0x48, 0x43, 0x3e, 0x51, 0x43, 0x47,
+ 0x47, 0x47, 0x49, 0x4d, 0x46, 0x4e, 0x52, 0x42, 0x48, 0x4e, 0x4c, 0x4a,
+ 0x4d, 0x3e, 0x43, 0x40, 0x48, 0x41, 0x47, 0x4f, 0x5e, 0x49, 0x40, 0x4c,
+ 0x50, 0x42, 0x56, 0x75, 0x51, 0x5e, 0x51, 0x4e, 0x62, 0x58, 0x49, 0x47,
+ 0x51, 0x59, 0x46, 0x46, 0x6c, 0x72, 0x55, 0x44, 0x4c, 0x4a, 0x4d, 0x59,
+ 0x53, 0x64, 0x4d, 0x51, 0x55, 0x5e, 0x59, 0x50, 0x30, 0x58, 0x50, 0x4c,
+ 0x4c, 0x60, 0x59, 0x42, 0x32, 0x53, 0x50, 0x55, 0x4d, 0x53, 0x59, 0x43,
+ 0x3e, 0x49, 0x4f, 0x52, 0x4d, 0x51, 0x47, 0x45, 0x4d, 0x4e, 0x53, 0x4e,
+ 0x54, 0x4f, 0x4d, 0x4d, 0x4e, 0x40, 0x47, 0x53, 0x53, 0x49, 0x56, 0x4d,
+ 0x4d, 0x3a, 0x4c, 0x4e, 0x45, 0x4a, 0x47, 0x45, 0x53, 0x4a, 0x4e, 0x52,
+ 0x4d, 0x4e, 0x48, 0x56, 0x4e, 0x4a, 0x4d, 0x52, 0x49, 0x4e, 0x4e, 0x58,
+ 0x47, 0x50, 0x4c, 0x54, 0x49, 0x42, 0x46, 0x54, 0x50, 0x54, 0x54, 0x46,
+ 0x40, 0x49, 0x4b, 0x57, 0x4b, 0x59, 0x44, 0x46, 0x52, 0x55, 0x51, 0x55,
+ 0x4f, 0x50, 0x4d, 0x4d, 0x48, 0x50, 0x4e, 0x49, 0x4e, 0x42, 0x45, 0x3f,
+ 0x4d, 0x4f, 0x51, 0x47, 0x4a, 0x4c, 0x4b, 0x4b, 0x46, 0x4d, 0x44, 0x52,
+ 0x4d, 0x44, 0x40, 0x4d, 0x54, 0x46, 0x54, 0x44, 0x4b, 0x46, 0x47, 0x45,
+ 0x50, 0x45, 0x45, 0x4b, 0x4c, 0x48, 0x3f, 0x55, 0x4a, 0x45, 0x49, 0x4e,
+ 0x40, 0x49, 0x4a, 0x41, 0x56, 0x4b, 0x49, 0x4e, 0x4a, 0x41, 0x50, 0x70,
+ 0x56, 0x59, 0x4b, 0x55, 0x58, 0x59, 0x49, 0x47, 0x4a, 0x5a, 0x4c, 0x46,
+ 0x62, 0x7b, 0x58, 0x51, 0x44, 0x47, 0x44, 0x57, 0x4f, 0x65, 0x4e, 0x50,
+ 0x4d, 0x67, 0x5c, 0x4a, 0x2b, 0x61, 0x48, 0x4b, 0x4b, 0x5d, 0x5c, 0x48,
+ 0x39, 0x50, 0x45, 0x4d, 0x53, 0x60, 0x53, 0x46, 0x42, 0x46, 0x50, 0x45,
+ 0x4f, 0x4e, 0x46, 0x4a, 0x4d, 0x51, 0x54, 0x47, 0x59, 0x4b, 0x58, 0x4a,
+ 0x50, 0x3d, 0x59, 0x48, 0x45, 0x4e, 0x4e, 0x47, 0x4f, 0x47, 0x4d, 0x4b,
+ 0x52, 0x42, 0x4c, 0x48, 0x4a, 0x4f, 0x47, 0x43, 0x4e, 0x4c, 0x4d, 0x51,
+ 0x49, 0x4f, 0x4c, 0x47, 0x47, 0x48, 0x47, 0x59, 0x4f, 0x4f, 0x53, 0x49,
+ 0x4e, 0x4b, 0x4f, 0x5a, 0x50, 0x42, 0x47, 0x50, 0x4a, 0x54, 0x47, 0x5a,
+ 0x43, 0x49, 0x47, 0x4e, 0x49, 0x4d, 0x43, 0x54, 0x4c, 0x53, 0x4e, 0x4e,
+ 0x42, 0x43, 0x48, 0x46, 0x4f, 0x43, 0x43, 0x45, 0x51, 0x47, 0x4b, 0x4f,
+ 0x56, 0x48, 0x48, 0x49, 0x46, 0x45, 0x4d, 0x52, 0x47, 0x4b, 0x46, 0x50,
+ 0x3e, 0x4e, 0x4c, 0x43, 0x45, 0x4d, 0x53, 0x43, 0x46, 0x45, 0x44, 0x52,
+ 0x45, 0x49, 0x49, 0x51, 0x3d, 0x4a, 0x4d, 0x46, 0x42, 0x41, 0x4e, 0x48,
+ 0x5a, 0x49, 0x49, 0x49, 0x4f, 0x3d, 0x56, 0x68, 0x56, 0x67, 0x4b, 0x57,
+ 0x5f, 0x5c, 0x40, 0x4a, 0x4a, 0x54, 0x4c, 0x47, 0x64, 0x7a, 0x54, 0x48,
+ 0x46, 0x45, 0x46, 0x57, 0x4e, 0x61, 0x4f, 0x50, 0x4d, 0x64, 0x5b, 0x43,
+ 0x2d, 0x60, 0x55, 0x51, 0x4c, 0x54, 0x4f, 0x4e, 0x2f, 0x50, 0x4f, 0x52,
+ 0x50, 0x61, 0x54, 0x4b, 0x3d, 0x4c, 0x47, 0x51, 0x4a, 0x54, 0x4b, 0x42,
+ 0x3b, 0x55, 0x47, 0x50, 0x4f, 0x49, 0x4a, 0x46, 0x43, 0x44, 0x45, 0x47,
+ 0x46, 0x4b, 0x4f, 0x46, 0x43, 0x47, 0x4a, 0x4e, 0x51, 0x43, 0x55, 0x47,
+ 0x4d, 0x46, 0x4c, 0x4c, 0x49, 0x4d, 0x43, 0x51, 0x47, 0x51, 0x52, 0x4a,
+ 0x46, 0x4f, 0x49, 0x52, 0x50, 0x4a, 0x43, 0x53, 0x46, 0x4e, 0x50, 0x54,
+ 0x45, 0x3a, 0x4a, 0x4a, 0x4c, 0x50, 0x4b, 0x54, 0x43, 0x4f, 0x4e, 0x45,
+ 0x49, 0x4f, 0x46, 0x53, 0x4d, 0x51, 0x52, 0x53, 0x3d, 0x4a, 0x47, 0x4e,
+ 0x43, 0x4a, 0x53, 0x48, 0x4a, 0x4c, 0x4a, 0x4a, 0x42, 0x53, 0x3e, 0x43,
+ 0x4f, 0x4c, 0x47, 0x48, 0x54, 0x4d, 0x48, 0x48, 0x4e, 0x4c, 0x43, 0x51,
+ 0x42, 0x49, 0x44, 0x3e, 0x49, 0x51, 0x4a, 0x4d, 0x4f, 0x49, 0x45, 0x44,
+ 0x4e, 0x41, 0x48, 0x4b, 0x4c, 0x49, 0x46, 0x47, 0x5d, 0x4c, 0x4d, 0x50,
+ 0x45, 0x40, 0x4e, 0x6a, 0x4f, 0x62, 0x53, 0x50, 0x5c, 0x5e, 0x4a, 0x4c,
+ 0x50, 0x56, 0x52, 0x42, 0x60, 0x7e, 0x5b, 0x4b, 0x43, 0x41, 0x4c, 0x56,
+ 0x46, 0x5f, 0x4d, 0x49, 0x43, 0x65, 0x5c, 0x4d, 0x2c, 0x61, 0x48, 0x4c,
+ 0x44, 0x55, 0x5c, 0x49, 0x37, 0x54, 0x4e, 0x57, 0x52, 0x5c, 0x50, 0x49,
+ 0x3e, 0x4d, 0x4f, 0x4f, 0x51, 0x4c, 0x48, 0x43, 0x4a, 0x5a, 0x4d, 0x4b,
+ 0x4e, 0x58, 0x54, 0x49, 0x51, 0x42, 0x49, 0x4f, 0x46, 0x45, 0x52, 0x3d,
+ 0x4b, 0x4b, 0x43, 0x54, 0x47, 0x47, 0x4c, 0x42, 0x4b, 0x49, 0x45, 0x46,
+ 0x46, 0x4a, 0x51, 0x47, 0x47, 0x4f, 0x48, 0x4a, 0x3f, 0x4c, 0x4b, 0x57,
+ 0x4a, 0x3f, 0x52, 0x4a, 0x56, 0x52, 0x4b, 0x54, 0x4c, 0x3e, 0x3f, 0x4f,
+ 0x4b, 0x50, 0x4c, 0x53, 0x4a, 0x49, 0x46, 0x4e, 0x50, 0x48, 0x4f, 0x4b,
+ 0x4a, 0x4e, 0x3e, 0x49, 0x45, 0x42, 0x42, 0x41, 0x47, 0x4b, 0x4f, 0x42,
+ 0x49, 0x4c, 0x55, 0x4c, 0x4e, 0x42, 0x47, 0x42, 0x4b, 0x48, 0x46, 0x41,
+ 0x46, 0x4e, 0x4d, 0x3f, 0x4f, 0x46, 0x4f, 0x4b, 0x4b, 0x4d, 0x50, 0x3e,
+ 0x42, 0x43, 0x44, 0x4a, 0x49, 0x40, 0x4e, 0x43, 0x3e, 0x52, 0x3e, 0x44,
+ 0x49, 0x43, 0x4d, 0x44, 0x62, 0x51, 0x42, 0x53, 0x51, 0x40, 0x4c, 0x64,
+ 0x4f, 0x63, 0x4e, 0x5c, 0x5b, 0x5c, 0x48, 0x4d, 0x4a, 0x57, 0x4f, 0x42,
+ 0x65, 0xfe, 0x5c, 0x4e, 0x47, 0x43, 0x4a, 0x58, 0x4e, 0x5e, 0x48, 0x4c,
+ 0x51, 0x5e, 0x60, 0x56, 0x2f, 0x62, 0x54, 0x58, 0x51, 0x52, 0x55, 0x51,
+ 0x36, 0x4b, 0x46, 0x51, 0x53, 0x5f, 0x46, 0x4c, 0x37, 0x4d, 0x4a, 0x45,
+ 0x4b, 0x3f, 0x41, 0x42, 0x3f, 0x53, 0x4a, 0x48, 0x49, 0x4a, 0x4a, 0x45,
+ 0x52, 0x3f, 0x52, 0x52, 0x45, 0x4d, 0x4f, 0x45, 0x46, 0x4a, 0x51, 0x48,
+ 0x56, 0x47, 0x50, 0x3e, 0x46, 0x49, 0x4c, 0x51, 0x49, 0x54, 0x45, 0x4f,
+ 0x4b, 0x4b, 0x49, 0x46, 0x4b, 0x4d, 0x49, 0x5c, 0x4d, 0x43, 0x47, 0x49,
+ 0x48, 0x52, 0x46, 0x50, 0x51, 0x37, 0x50, 0x52, 0x4c, 0x4d, 0x4f, 0x51,
+ 0x4f, 0x42, 0x50, 0x47, 0x48, 0x4e, 0x4d, 0x4c, 0x48, 0x48, 0x4a, 0x51,
+ 0x49, 0x42, 0x50, 0x4f, 0x43, 0x4e, 0x47, 0x4b, 0x47, 0x4a, 0x44, 0x44,
+ 0x4c, 0x51, 0x49, 0x44, 0x45, 0x45, 0x45, 0x48, 0x3f, 0x4a, 0x43, 0x49,
+ 0x46, 0x49, 0x4c, 0x4d, 0x45, 0x50, 0x44, 0x45, 0x44, 0x55, 0x4a, 0x45,
+ 0x48, 0x47, 0x4c, 0x43, 0x3f, 0x48, 0x42, 0x43, 0x43, 0x43, 0x48, 0x46,
+ 0x5c, 0x51, 0x47, 0x51, 0x48, 0x40, 0x54, 0x66, 0x4e, 0x67, 0x4d, 0x5a,
+ 0x60, 0x57, 0x47, 0x4d, 0x4d, 0x58, 0x53, 0x46, 0x66, 0x7e, 0x56, 0x48,
+ 0x44, 0x4f, 0x49, 0x5c, 0x4a, 0x63, 0x50, 0x4c, 0x49, 0x56, 0x61, 0x50,
+ 0x2c, 0x68, 0x4d, 0x51, 0x46, 0x4e, 0x5b, 0x51, 0x2e, 0x53, 0x54, 0x50,
+ 0x46, 0x58, 0x44, 0x4f, 0x37, 0x48, 0x55, 0x50, 0x49, 0x49, 0x4e, 0x46,
+ 0x43, 0x56, 0x52, 0x4e, 0x50, 0x4b, 0x50, 0x4c, 0x49, 0x40, 0x4d, 0x4f,
+ 0x50, 0x41, 0x44, 0x39, 0x4b, 0x4d, 0x4b, 0x41, 0x51, 0x4d, 0x4c, 0x41,
+ 0x3f, 0x52, 0x4e, 0x4b, 0x49, 0x53, 0x45, 0x43, 0x4d, 0x4f, 0x44, 0x4d,
+ 0x4b, 0x53, 0x50, 0x4e, 0x45, 0x3f, 0x4e, 0x51, 0x50, 0x55, 0x4f, 0x51,
+ 0x4d, 0x3d, 0x58, 0x3f, 0x46, 0x50, 0x50, 0x50, 0x56, 0x42, 0x49, 0x49,
+ 0x50, 0x4f, 0x42, 0x4b, 0x4c, 0x45, 0x52, 0x41, 0x46, 0x43, 0x4c, 0x4a,
+ 0x4c, 0x51, 0x4d, 0x4d, 0x4a, 0x49, 0x54, 0x49, 0x58, 0x53, 0x49, 0x45,
+ 0x47, 0x4c, 0x4c, 0x44, 0x4e, 0x51, 0x4c, 0x4c, 0x47, 0x48, 0x4c, 0x4e,
+ 0x49, 0x54, 0x4c, 0x51, 0x49, 0x48, 0x47, 0x45, 0x42, 0x49, 0x42, 0x51,
+ 0x4e, 0x3f, 0x49, 0x41, 0x50, 0x3e, 0x4d, 0x50, 0x5c, 0x51, 0x4d, 0x56,
+ 0x47, 0x48, 0x58, 0x65, 0x51, 0x6b, 0x56, 0x5b, 0x56, 0x55, 0x46, 0x49,
+ 0x4b, 0x58, 0x59, 0x4a, 0x68, 0x79, 0x53, 0x46, 0x45, 0x4b, 0x53, 0x5d,
+ 0x4b, 0x6f, 0x4e, 0x4f, 0x4c, 0x53, 0x5b, 0x52, 0x30, 0x63, 0x46, 0x57,
+ 0x46, 0x50, 0x4b, 0x48, 0x2e, 0x4c, 0x46, 0x48, 0x44, 0x51, 0x46, 0x4a,
+ 0x35, 0x55, 0x43, 0x4c, 0x43, 0x4d, 0x4e, 0x3e, 0x47, 0x56, 0x50, 0x4d,
+ 0x44, 0x59, 0x4c, 0x51, 0x46, 0x42, 0x4e, 0x43, 0x4c, 0x44, 0x42, 0x3a,
+ 0x40, 0x48, 0x46, 0x44, 0x45, 0x4a, 0x46, 0x3a, 0x53, 0x4c, 0x4d, 0x4c,
+ 0x4a, 0x4f, 0x53, 0x40, 0x4b, 0x48, 0x54, 0x4b, 0x44, 0x59, 0x41, 0x50,
+ 0x4e, 0x50, 0x55, 0x4d, 0x55, 0x41, 0x4a, 0x4f, 0x47, 0x43, 0x4e, 0x50,
+ 0x52, 0x4c, 0x50, 0x4d, 0x47, 0x42, 0x4f, 0x4b, 0x47, 0x43, 0x41, 0x4a,
+ 0x55, 0x3e, 0x50, 0x4b, 0x41, 0x49, 0x47, 0x49, 0x53, 0x4d, 0x48, 0x4b,
+ 0x43, 0x43, 0x51, 0x44, 0x4d, 0x4c, 0x44, 0x50, 0x4d, 0x42, 0x49, 0x4e,
+ 0x50, 0x50, 0x4c, 0x49, 0x49, 0x51, 0x46, 0x43, 0x4a, 0x4e, 0x53, 0x47,
+ 0x43, 0x46, 0x40, 0x49, 0x47, 0x44, 0x44, 0x4d, 0x4b, 0x4b, 0x51, 0x4b,
+ 0x45, 0x49, 0x47, 0x43, 0x56, 0x49, 0x4c, 0x54, 0x50, 0x3c, 0x4c, 0x5e,
+ 0x51, 0x67, 0x4f, 0x57, 0x57, 0x53, 0x3e, 0x4e, 0x4e, 0x5e, 0x4b, 0x48,
+ 0x5a, 0x78, 0x55, 0x4a, 0x3f, 0x4b, 0x4c, 0x5b, 0x53, 0x64, 0x4d, 0x53,
+ 0x49, 0x57, 0x57, 0x58, 0x37, 0x62, 0x4f, 0x56, 0x44, 0x4e, 0x58, 0x4a,
+ 0x30, 0x4f, 0x40, 0x4e, 0x47, 0x58, 0x52, 0x50, 0x35, 0x4d, 0x49, 0x52,
+ 0x4e, 0x42, 0x46, 0x47, 0x44, 0x57, 0x54, 0x43, 0x4e, 0x56, 0x43, 0x49,
+ 0x44, 0x40, 0x44, 0x41, 0x50, 0x49, 0x4b, 0x44, 0x4d, 0x52, 0x49, 0x43,
+ 0x52, 0x54, 0x49, 0x3f, 0x49, 0x42, 0x49, 0x4a, 0x43, 0x3e, 0x50, 0x40,
+ 0x46, 0x4b, 0x50, 0x4b, 0x53, 0x4b, 0x47, 0x52, 0x51, 0x4b, 0x47, 0x3f,
+ 0x46, 0x4b, 0x4c, 0x57, 0x49, 0x47, 0x54, 0x49, 0x50, 0x50, 0x4d, 0x4a,
+ 0x42, 0x4e, 0x51, 0x4c, 0x47, 0x47, 0x42, 0x43, 0x54, 0x43, 0x46, 0x47,
+ 0x4d, 0x43, 0x54, 0x47, 0x43, 0x58, 0x48, 0x45, 0x4b, 0x46, 0x48, 0x3d,
+ 0x47, 0x3f, 0x44, 0x4f, 0x4e, 0x46, 0x41, 0x40, 0x4d, 0x4d, 0x4d, 0x52,
+ 0x54, 0x47, 0x4f, 0x51, 0x4f, 0x45, 0x45, 0x48, 0x4b, 0x4d, 0x44, 0x52,
+ 0x51, 0x4b, 0x48, 0x4f, 0x49, 0x49, 0x46, 0x50, 0x54, 0x42, 0x44, 0x51,
+ 0x58, 0x4e, 0x43, 0x58, 0x55, 0x40, 0x53, 0x5a, 0x51, 0x61, 0x51, 0x60,
+ 0x53, 0x57, 0x45, 0x4f, 0x45, 0x5e, 0x51, 0x42, 0x61, 0x7a, 0x55, 0x47,
+ 0x41, 0x4b, 0x4a, 0x5b, 0x4c, 0x65, 0x4f, 0x55, 0x46, 0x54, 0x65, 0x59,
+ 0x36, 0x61, 0x54, 0x55, 0x48, 0x57, 0x52, 0x4e, 0x24, 0x4b, 0x49, 0x4d,
+ 0x43, 0x57, 0x44, 0x51, 0x3b, 0x4f, 0x45, 0x40, 0x47, 0x4a, 0x43, 0x47,
+ 0x46, 0x58, 0x50, 0x54, 0x4d, 0x50, 0x44, 0x42, 0x4a, 0x46, 0x4b, 0x4d,
+ 0x4f, 0x4f, 0x4d, 0x40, 0x48, 0x4a, 0x53, 0x48, 0x49, 0x48, 0x4d, 0x39,
+ 0x47, 0x4e, 0x44, 0x4c, 0x4b, 0x49, 0x44, 0x42, 0x4a, 0x45, 0x46, 0x46,
+ 0x53, 0x4d, 0x49, 0x4f, 0x4e, 0x48, 0x50, 0x4a, 0x4c, 0x46, 0x56, 0x4b,
+ 0x4b, 0x57, 0x4c, 0x49, 0x4a, 0x4a, 0x43, 0x4e, 0x56, 0x45, 0x50, 0x4c,
+ 0x47, 0x55, 0x48, 0x46, 0x4e, 0x46, 0x45, 0x3f, 0x4a, 0x4c, 0x4c, 0x47,
+ 0x4a, 0x51, 0x4e, 0x50, 0x40, 0x52, 0x45, 0x45, 0x4b, 0x46, 0x4f, 0x44,
+ 0x51, 0x4a, 0x4e, 0x4d, 0x4c, 0x46, 0x42, 0x47, 0x4a, 0x4e, 0x46, 0x42,
+ 0x4b, 0x4f, 0x4b, 0x4e, 0x4e, 0x46, 0x42, 0x50, 0x53, 0x51, 0x4f, 0x54,
+ 0x45, 0x4f, 0x45, 0x42, 0x4c, 0x45, 0x40, 0x48, 0x59, 0x49, 0x49, 0x53,
+ 0x4c, 0x43, 0x4b, 0x57, 0x54, 0x64, 0x4e, 0x5f, 0x5c, 0x59, 0x4b, 0x56,
+ 0x49, 0x5d, 0x4f, 0x4b, 0x62, 0x73, 0x54, 0x45, 0x49, 0x50, 0x48, 0x5a,
+ 0x50, 0x6d, 0x4a, 0x4e, 0x48, 0x55, 0x5d, 0x57, 0x38, 0x68, 0x52, 0x5a,
+ 0x46, 0x56, 0x4c, 0x5a, 0x2e, 0x55, 0x49, 0x4f, 0x4a, 0x57, 0x4f, 0x54,
+ 0x41, 0x53, 0x46, 0x43, 0x45, 0x47, 0x53, 0x4a, 0x42, 0x4f, 0x4d, 0x48,
+ 0x4c, 0x49, 0x47, 0x48, 0x45, 0x49, 0x48, 0x53, 0x48, 0x52, 0x4a, 0x44,
+ 0x4c, 0x49, 0x52, 0x4b, 0x47, 0x51, 0x42, 0x47, 0x49, 0x51, 0x3f, 0x45,
+ 0x47, 0x4e, 0x53, 0x33, 0x55, 0x51, 0x55, 0x48, 0x4b, 0x51, 0x56, 0x47,
+ 0x43, 0x55, 0x47, 0x42, 0x47, 0x4f, 0x47, 0x51, 0x46, 0x55, 0x4a, 0x4b,
+ 0x50, 0x52, 0x4f, 0x43, 0x4b, 0x53, 0x4d, 0x3f, 0x4e, 0x56, 0x50, 0x49,
+ 0x4d, 0x47, 0x51, 0x49, 0x4a, 0x52, 0x44, 0x43, 0x4d, 0x4e, 0x41, 0x51,
+ 0x4c, 0x4d, 0x47, 0x48, 0x4f, 0x40, 0x50, 0x46, 0x43, 0x4d, 0x4e, 0x50,
+ 0x43, 0x47, 0x4e, 0x46, 0x4f, 0x4b, 0x51, 0x4b, 0x4a, 0x57, 0x42, 0x51,
+ 0x4c, 0x54, 0x52, 0x42, 0x4c, 0x42, 0x47, 0x54, 0x4a, 0x4a, 0x47, 0x4a,
+ 0x3f, 0x46, 0x4e, 0x4c, 0x53, 0x50, 0x47, 0x53, 0x49, 0x44, 0x52, 0x5a,
+ 0x4b, 0x65, 0x50, 0x5b, 0x57, 0x59, 0x4a, 0x48, 0x48, 0x5f, 0x55, 0x48,
+ 0x5c, 0x78, 0x55, 0x48, 0x4a, 0x4b, 0x49, 0x4c, 0x46, 0x6b, 0x54, 0x57,
+ 0x55, 0x4b, 0x59, 0x52, 0x38, 0x5b, 0x57, 0x56, 0x4b, 0x4f, 0x48, 0x4e,
+ 0x34, 0x5a, 0x4e, 0x4f, 0x43, 0x4e, 0x4b, 0x4e, 0x36, 0x4d, 0x52, 0x48,
+ 0x4d, 0x4c, 0x4c, 0x49, 0x51, 0x54, 0x45, 0x54, 0x4a, 0x4e, 0x52, 0x41,
+ 0x4c, 0x45, 0x4a, 0x53, 0x55, 0x4b, 0x50, 0x47, 0x4e, 0x4d, 0x43, 0x51,
+ 0x4e, 0x4a, 0x51, 0x46, 0x4e, 0x4d, 0x48, 0x3f, 0x43, 0x52, 0x56, 0x38,
+ 0x52, 0x46, 0x43, 0x49, 0x40, 0x49, 0x53, 0x41, 0x47, 0x41, 0x41, 0x42,
+ 0x4f, 0x4b, 0x46, 0x4b, 0x4a, 0x57, 0x4a, 0x45, 0x4b, 0x46, 0x47, 0x3c,
+ 0x43, 0x46, 0x4f, 0x50, 0x4c, 0x53, 0x4f, 0x41, 0x4a, 0x4a, 0x40, 0x4a,
+ 0x3e, 0x4e, 0x4d, 0x41, 0x4a, 0x42, 0x49, 0x4c, 0x51, 0x46, 0x4f, 0x43,
+ 0x4b, 0x41, 0x50, 0x48, 0x4a, 0x40, 0x52, 0x45, 0x40, 0x40, 0x46, 0x48,
+ 0x48, 0x52, 0x52, 0x41, 0x43, 0x49, 0x49, 0x4c, 0x44, 0x48, 0x50, 0x4a,
+ 0x47, 0x48, 0x4c, 0x42, 0x49, 0x48, 0x52, 0x56, 0x4b, 0x41, 0x4e, 0x47,
+ 0x52, 0x56, 0x4e, 0x56, 0x4b, 0x38, 0x50, 0x55, 0x5a, 0x63, 0x51, 0x5a,
+ 0x54, 0x52, 0x44, 0x45, 0x47, 0x5e, 0x4c, 0x4a, 0x5e, 0x71, 0x56, 0x44,
+ 0x4c, 0x4b, 0x4c, 0x4e, 0x49, 0x69, 0x50, 0x53, 0x4d, 0x5c, 0x59, 0x50,
+ 0x36, 0x5d, 0x46, 0x5b, 0x51, 0x55, 0x55, 0x51, 0x36, 0x5a, 0x53, 0x56,
+ 0x54, 0x4a, 0x55, 0x53, 0x3c, 0x52, 0x4a, 0x45, 0x4c, 0x56, 0x49, 0x46,
+ 0x4f, 0x5b, 0x43, 0x4b, 0x49, 0x4c, 0x4b, 0x41, 0x44, 0x4b, 0x47, 0x4b,
+ 0x4b, 0x54, 0x4a, 0x4c, 0x49, 0x44, 0x46, 0x46, 0x48, 0x49, 0x47, 0x4a,
+ 0x40, 0x4e, 0x47, 0x53, 0x4a, 0x47, 0x4a, 0x3b, 0x48, 0x4b, 0x50, 0x51,
+ 0x50, 0x44, 0x4d, 0x49, 0x42, 0x4b, 0x43, 0x48, 0x4a, 0x43, 0x4d, 0x4d,
+ 0x49, 0x4d, 0x43, 0x4f, 0x50, 0x49, 0x47, 0x48, 0x48, 0x4f, 0x49, 0x41,
+ 0x4c, 0x46, 0x47, 0x3e, 0x51, 0x4d, 0x4e, 0x42, 0x3d, 0x53, 0x4d, 0x3b,
+ 0x53, 0x52, 0x4c, 0x4c, 0x43, 0x46, 0x43, 0x3d, 0x53, 0x48, 0x43, 0x4e,
+ 0x45, 0x52, 0x4d, 0x4a, 0x44, 0x49, 0x47, 0x4c, 0x4e, 0x4c, 0x4a, 0x4e,
+ 0x41, 0x48, 0x4b, 0x44, 0x4d, 0x4a, 0x4d, 0x44, 0x4a, 0x45, 0x4f, 0x52,
+ 0x45, 0x3f, 0x4b, 0x48, 0x43, 0x41, 0x3d, 0x53, 0x53, 0x50, 0x4a, 0x56,
+ 0x4d, 0x3e, 0x55, 0x4e, 0x56, 0x5e, 0x52, 0x52, 0x54, 0x50, 0x42, 0x4a,
+ 0x4d, 0x5f, 0x4f, 0x49, 0x5d, 0x6f, 0x55, 0x4a, 0x47, 0x49, 0x4e, 0x4a,
+ 0x43, 0x6e, 0x4e, 0x4f, 0x52, 0x59, 0x62, 0x4b, 0x3e, 0x5c, 0x4c, 0x4e,
+ 0x45, 0x52, 0x43, 0x4d, 0x3c, 0x58, 0x52, 0x49, 0x48, 0x55, 0x53, 0x4e,
+ 0x3d, 0x4e, 0x4c, 0x4b, 0x4b, 0x50, 0x4a, 0x47, 0x45, 0x62, 0x50, 0x49,
+ 0x48, 0x4b, 0x55, 0x45, 0x46, 0x51, 0x41, 0x55, 0x54, 0x55, 0x50, 0x47,
+ 0x46, 0x4d, 0x46, 0x4b, 0x41, 0x49, 0x4c, 0x40, 0x45, 0x4f, 0x52, 0x54,
+ 0x45, 0x4d, 0x53, 0x3a, 0x4c, 0x55, 0x4e, 0x48, 0x44, 0x45, 0x56, 0x3c,
+ 0x48, 0x46, 0x4b, 0x51, 0x53, 0x43, 0x41, 0x49, 0x4c, 0x52, 0x48, 0x42,
+ 0x48, 0x3f, 0x4c, 0x38, 0x46, 0x50, 0x4a, 0x44, 0x50, 0x54, 0x4e, 0x38,
+ 0x48, 0x42, 0x43, 0x4a, 0x4c, 0x44, 0x47, 0x42, 0x42, 0x46, 0x4a, 0x50,
+ 0x47, 0x4b, 0x43, 0x40, 0x44, 0x46, 0x46, 0x4d, 0x50, 0x4a, 0x4e, 0x51,
+ 0x44, 0x40, 0x50, 0x43, 0x52, 0x4d, 0x42, 0x4c, 0x50, 0x41, 0x4a, 0x4e,
+ 0x45, 0x49, 0x4d, 0x40, 0x46, 0x51, 0x43, 0x4b, 0x48, 0x47, 0x42, 0x55,
+ 0x4a, 0x41, 0x4f, 0x49, 0x4f, 0x4e, 0x47, 0x4c, 0x4a, 0x48, 0x50, 0x4e,
+ 0x50, 0x57, 0x4e, 0x56, 0x56, 0x4e, 0x44, 0x48, 0x4a, 0x5b, 0x55, 0x49,
+ 0x59, 0x67, 0x54, 0x46, 0x4f, 0x41, 0x4d, 0x4e, 0x4a, 0x63, 0x4d, 0x44,
+ 0x53, 0x5b, 0x59, 0x4f, 0x43, 0x55, 0x56, 0x4e, 0x55, 0x4c, 0x4b, 0x54,
+ 0x3c, 0x56, 0x4d, 0x50, 0x4f, 0x4a, 0x5a, 0x47, 0x48, 0x56, 0x4f, 0x4f,
+ 0x50, 0x51, 0x48, 0x4e, 0x4d, 0x50, 0x4e, 0x45, 0x4b, 0x48, 0x4e, 0x44,
+ 0x46, 0x4d, 0x43, 0x46, 0x41, 0x59, 0x53, 0x4b, 0x4a, 0x3e, 0x51, 0x47,
+ 0x43, 0x48, 0x52, 0x3f, 0x43, 0x50, 0x4b, 0x4f, 0x41, 0x48, 0x43, 0x2e,
+ 0x4d, 0x4e, 0x4c, 0x45, 0x45, 0x46, 0x4b, 0x43, 0x46, 0x49, 0x46, 0x4d,
+ 0x47, 0x4e, 0x4d, 0x3c, 0x47, 0x4a, 0x52, 0x4e, 0x41, 0x50, 0x43, 0x3a,
+ 0x50, 0x47, 0x4a, 0x45, 0x52, 0x4a, 0x4c, 0x3f, 0x42, 0x3d, 0x49, 0x48,
+ 0x48, 0x4c, 0x42, 0x3a, 0x40, 0x47, 0x46, 0x4e, 0x44, 0x52, 0x46, 0x44,
+ 0x4a, 0x44, 0x43, 0x49, 0x42, 0x45, 0x3f, 0x50, 0x4c, 0x44, 0x48, 0x43,
+ 0x47, 0x4a, 0x48, 0x48, 0x3e, 0x45, 0x43, 0x48, 0x4a, 0x48, 0x53, 0x4b,
+ 0x50, 0x49, 0x43, 0x4d, 0x53, 0x4f, 0x4b, 0x4b, 0x40, 0x42, 0x50, 0x4d,
+ 0x53, 0x4e, 0x44, 0x4d, 0x45, 0x3d, 0x51, 0x51, 0x4f, 0x59, 0x4b, 0x51,
+ 0x4a, 0x4e, 0x42, 0x40, 0x49, 0x5b, 0x4b, 0x43, 0x53, 0x60, 0x47, 0x49,
+ 0x4a, 0x44, 0x44, 0x48, 0x4b, 0x60, 0x51, 0x3f, 0x4b, 0x5b, 0x4f, 0x4a,
+ 0x4a, 0x50, 0x49, 0x46, 0x55, 0x50, 0x4b, 0x4c, 0x40, 0x4e, 0x51, 0x4f,
+ 0x4b, 0x51, 0x54, 0x50, 0x48, 0x4e, 0x4a, 0x4f, 0x4d, 0x4e, 0x54, 0x4d,
+ 0x41, 0x50, 0x4e, 0x47, 0x47, 0x47, 0x54, 0x3b, 0x51, 0x54, 0x50, 0x49,
+ 0x48, 0x4c, 0x4e, 0x47, 0x3f, 0x3c, 0x4c, 0x43, 0x45, 0x42, 0x45, 0x37,
+ 0x41, 0x52, 0x49, 0x47, 0x4e, 0x4a, 0x4b, 0x37, 0x48, 0x4d, 0x4e, 0x4a,
+ 0x42, 0x56, 0x3d, 0x35, 0x48, 0x42, 0x4b, 0x4a, 0x44, 0x52, 0x40, 0x48,
+ 0x4f, 0x49, 0x4f, 0x4c, 0x4d, 0x43, 0x49, 0x38, 0x4b, 0x42, 0x48, 0x42,
+ 0x45, 0x45, 0x54, 0x3a, 0x47, 0x47, 0x52, 0x45, 0x4a, 0x48, 0x47, 0x39,
+ 0x4d, 0x45, 0x54, 0x4b, 0x4e, 0x4f, 0x4e, 0x38, 0x4a, 0x4b, 0x48, 0x45,
+ 0x4e, 0x43, 0x4e, 0x4e, 0x46, 0x4e, 0x4e, 0x50, 0x46, 0x4c, 0x42, 0x45,
+ 0x4b, 0x46, 0x47, 0x4d, 0x49, 0x3f, 0x4f, 0x50, 0x46, 0x4a, 0x47, 0x4e,
+ 0x4a, 0x3e, 0x50, 0x46, 0x47, 0x40, 0x4f, 0x47, 0x51, 0x4b, 0x43, 0x46,
+ 0x4a, 0x42, 0x55, 0x4d, 0x46, 0x63, 0x49, 0x4e, 0x4f, 0x4f, 0x42, 0x45,
+ 0x50, 0x57, 0x49, 0x3e, 0x57, 0x63, 0x45, 0x4a, 0x49, 0x50, 0x41, 0x4a,
+ 0x48, 0x64, 0x4f, 0x42, 0x47, 0x58, 0x4b, 0x45, 0x43, 0x57, 0x49, 0x58,
+ 0x51, 0x51, 0x47, 0x43, 0x51, 0x4b, 0x4a, 0x45, 0x50, 0x54, 0x4d, 0x4d,
+ 0x3e, 0x4a, 0x50, 0x40, 0x51, 0x4f, 0x52, 0x48, 0x53, 0x49, 0x44, 0x4b,
+ 0x51, 0x4b, 0x50, 0x42, 0x4d, 0x49, 0x4a, 0x46, 0x44, 0x50, 0x47, 0x3f,
+ 0x48, 0x47, 0x41, 0x4a, 0x42, 0x52, 0x4a, 0x33, 0x50, 0x50, 0x54, 0x3f,
+ 0x44, 0x4e, 0x51, 0x3c, 0x4e, 0x51, 0x48, 0x4b, 0x47, 0x49, 0x3f, 0x3d,
+ 0x4e, 0x46, 0x4a, 0x41, 0x40, 0x50, 0x49, 0x40, 0x4a, 0x4b, 0x45, 0x50,
+ 0x4e, 0x4d, 0x4b, 0x39, 0x4e, 0x4b, 0x48, 0x3c, 0x47, 0x44, 0x4c, 0x42,
+ 0x45, 0x50, 0x3e, 0x54, 0x4d, 0x49, 0x48, 0x3c, 0x45, 0x42, 0x55, 0x4a,
+ 0x41, 0x4f, 0x40, 0x3f, 0x47, 0x46, 0x46, 0x44, 0x4f, 0x47, 0x46, 0x44,
+ 0x41, 0x40, 0x44, 0x48, 0x3e, 0x3c, 0x46, 0x3e, 0x4a, 0x45, 0x4c, 0x52,
+ 0x47, 0x42, 0x47, 0x3f, 0x47, 0x4e, 0x4b, 0x53, 0x4a, 0x3d, 0x4d, 0x47,
+ 0x4f, 0x3d, 0x4e, 0x43, 0x4f, 0x46, 0x43, 0x43, 0x46, 0x41, 0x4f, 0x42,
+ 0x46, 0x57, 0x4d, 0x51, 0x49, 0x51, 0x4c, 0x44, 0x51, 0x4f, 0x46, 0x44,
+ 0x54, 0x5d, 0x4f, 0x40, 0x59, 0x46, 0x53, 0x46, 0x48, 0x54, 0x43, 0x45,
+ 0x4d, 0x51, 0x4f, 0x44, 0x44, 0x53, 0x49, 0x4e, 0x48, 0x46, 0x44, 0x4a,
+ 0x4a, 0x42, 0x4c, 0x46, 0x54, 0x4f, 0x52, 0x47, 0x46, 0x44, 0x4c, 0x4d,
+ 0x4c, 0x47, 0x4d, 0x40, 0x55, 0x58, 0x46, 0x46, 0x3f, 0x3e, 0x47, 0x36,
+ 0x3f, 0x4d, 0x4b, 0x4d, 0x4f, 0x4f, 0x48, 0x34, 0x4d, 0x46, 0x46, 0x50,
+ 0x50, 0x4b, 0x47, 0x45, 0x4e, 0x49, 0x50, 0x4f, 0x4a, 0x48, 0x4f, 0x39,
+ 0x53, 0x4c, 0x4b, 0x56, 0x45, 0x4f, 0x55, 0x3a, 0x40, 0x53, 0x43, 0x4b,
+ 0x47, 0x3d, 0x4c, 0x34, 0x4b, 0x4e, 0x4a, 0x4b, 0x4d, 0x49, 0x4e, 0x40,
+ 0x4d, 0x48, 0x40, 0x4a, 0x4a, 0x4b, 0x4a, 0x42, 0x4c, 0x52, 0x43, 0x42,
+ 0x44, 0x3f, 0x4e, 0x42, 0x44, 0x45, 0x40, 0x3d, 0x4b, 0x45, 0x4a, 0x43,
+ 0x4b, 0x4b, 0x4e, 0x46, 0x55, 0x43, 0x44, 0x3f, 0x44, 0x43, 0x4b, 0x4b,
+ 0x45, 0x51, 0x48, 0x49, 0x3d, 0x44, 0x4a, 0x4a, 0x50, 0x50, 0x47, 0x44,
+ 0x4f, 0x3e, 0x3f, 0x43, 0x4c, 0x46, 0x4a, 0x4e, 0x4c, 0x52, 0x48, 0x4e,
+ 0x48, 0x46, 0x45, 0x48, 0x41, 0x4f, 0x51, 0x48, 0x40, 0x4d, 0x4a, 0x4b,
+ 0x4c, 0x51, 0x49, 0x50, 0x4e, 0x4b, 0x4a, 0x42, 0x49, 0x54, 0x4e, 0x43,
+ 0x52, 0x47, 0x4a, 0x41, 0x42, 0x51, 0x48, 0x4a, 0x46, 0x45, 0x4a, 0x43,
+ 0x4e, 0x4f, 0x41, 0x49, 0x4b, 0x42, 0x40, 0x4a, 0x50, 0x41, 0x42, 0x3f,
+ 0x49, 0x4a, 0x40, 0x3e, 0x3f, 0x42, 0x4d, 0x51, 0x4e, 0x4e, 0x47, 0x41,
+ 0x4e, 0x4e, 0x49, 0x4b, 0x41, 0x45, 0x51, 0x40, 0x45, 0x4c, 0x3f, 0x42,
+ 0x4c, 0x45, 0x4d, 0x39, 0x46, 0x52, 0x4a, 0x4e, 0x4c, 0x49, 0x4e, 0x43,
+ 0x43, 0x4c, 0x48, 0x46, 0x48, 0x49, 0x50, 0x3a, 0x3f, 0x49, 0x42, 0x4f,
+ 0x42, 0x4d, 0x4e, 0x3f, 0x51, 0x4b, 0x4e, 0x4b, 0x51, 0x44, 0x43, 0x4a,
+ 0x4a, 0x4c, 0x50, 0x48, 0x45, 0x47, 0x4d, 0x41, 0x47, 0x45, 0x51, 0x41,
+ 0x42, 0x48, 0x4c, 0x39, 0x51, 0x45, 0x46, 0x53, 0x4b, 0x50, 0x46, 0x45,
+ 0x4b, 0x4d, 0x42, 0x4b, 0x3f, 0x45, 0x4b, 0x4e, 0x50, 0x50, 0x47, 0x4a,
+ 0x45, 0x40, 0x4b, 0x43, 0x3f, 0x4a, 0x41, 0x42, 0x51, 0x41, 0x4d, 0x42,
+ 0x53, 0x48, 0x48, 0x49, 0x4b, 0x40, 0x42, 0x3d, 0x4f, 0x53, 0x49, 0x46,
+ 0x46, 0x43, 0x42, 0x44, 0x46, 0x48, 0x3f, 0x46, 0x31, 0x43, 0x4d, 0x4b,
+ 0x48, 0x4d, 0x4c, 0x43, 0x45, 0x53, 0x50, 0x40, 0x4a, 0x48, 0x45, 0x3b,
+ 0x4f, 0x4d, 0x53, 0x4c, 0x44, 0x54, 0x50, 0x66, 0x3f, 0x45, 0x4c, 0x4c,
+ 0x4a, 0x49, 0x49, 0x4a, 0x40, 0x52, 0x3e, 0x4c, 0x49, 0x40, 0x44, 0x49,
+ 0x48, 0x3f, 0x45, 0x5b, 0x49, 0x4b, 0x4c, 0x44, 0x50, 0x4e, 0x4a, 0x4a,
+ 0x49, 0x4e, 0x4f, 0x47, 0x46, 0x4b, 0x44, 0x3b, 0x4e, 0x4b, 0x48, 0x46,
+ 0x45, 0x45, 0x3d, 0x35, 0x4c, 0x49, 0x54, 0x42, 0x51, 0x46, 0x49, 0x2d,
+ 0x43, 0x4a, 0x53, 0x49, 0x49, 0x42, 0x4f, 0x40, 0x4e, 0x50, 0x54, 0x51,
+ 0x4b, 0x45, 0x48, 0x35, 0x4d, 0x41, 0x51, 0x40, 0x41, 0x49, 0x4a, 0x3b,
+ 0x45, 0x50, 0x48, 0x51, 0x51, 0x4d, 0x4c, 0x36, 0x47, 0x4a, 0x44, 0x45,
+ 0x4d, 0x47, 0x43, 0x3a, 0x48, 0x40, 0x42, 0x4f, 0x4f, 0x4f, 0x4f, 0x43,
+ 0x4a, 0x41, 0x4b, 0x53, 0x43, 0x46, 0x4f, 0x39, 0x46, 0x4a, 0x4d, 0x53,
+ 0x41, 0x44, 0x4e, 0x44, 0x3f, 0x47, 0x4c, 0x4d, 0x4d, 0x43, 0x45, 0x3d,
+ 0x43, 0x4b, 0x3e, 0x48, 0x42, 0x4c, 0x47, 0x42, 0x42, 0x50, 0x49, 0x4b,
+ 0x43, 0x4e, 0x44, 0x44, 0x4c, 0x3d, 0x4c, 0x47, 0x4e, 0x42, 0x4b, 0x44,
+ 0x4b, 0x44, 0x3f, 0x49, 0x33, 0x46, 0x4a, 0x4a, 0x42, 0x57, 0x5e, 0x4a,
+ 0x46, 0x4f, 0x55, 0x3c, 0x4a, 0x4b, 0x4c, 0x43, 0x51, 0x59, 0x64, 0x51,
+ 0x45, 0x60, 0x4b, 0x65, 0x46, 0x4a, 0x4e, 0x49, 0x41, 0x4b, 0x50, 0x5c,
+ 0x48, 0x4b, 0x3e, 0x52, 0x4f, 0x2f, 0x4e, 0x4a, 0x45, 0x53, 0x48, 0x59,
+ 0x4c, 0x4e, 0x4a, 0x4d, 0x49, 0x40, 0x52, 0x44, 0x49, 0x46, 0x4e, 0x46,
+ 0x42, 0x4b, 0x4a, 0x4b, 0x4b, 0x4b, 0x4f, 0x52, 0x46, 0x50, 0x4d, 0x3d,
+ 0x46, 0x4b, 0x4b, 0x40, 0x4d, 0x3f, 0x43, 0x33, 0x4e, 0x53, 0x4b, 0x4a,
+ 0x45, 0x48, 0x4c, 0x2e, 0x48, 0x4f, 0x49, 0x42, 0x54, 0x4f, 0x4b, 0x2b,
+ 0x55, 0x4e, 0x43, 0x4d, 0x4d, 0x47, 0x42, 0x3e, 0x48, 0x48, 0x4d, 0x54,
+ 0x52, 0x4f, 0x43, 0x37, 0x4b, 0x42, 0x4b, 0x4e, 0x49, 0x49, 0x4b, 0x2e,
+ 0x45, 0x4e, 0x48, 0x4e, 0x44, 0x49, 0x48, 0x30, 0x4c, 0x4b, 0x3f, 0x42,
+ 0x4f, 0x4f, 0x4e, 0x38, 0x4f, 0x42, 0x54, 0x49, 0x41, 0x42, 0x45, 0x3a,
+ 0x47, 0x43, 0x43, 0x4b, 0x49, 0x40, 0x4d, 0x38, 0x52, 0x4c, 0x3d, 0x4d,
+ 0x43, 0x54, 0x4e, 0x41, 0x4a, 0x47, 0x44, 0x51, 0x47, 0x48, 0x41, 0x47,
+ 0x4d, 0x41, 0x46, 0x4c, 0x4d, 0x46, 0x51, 0x4a, 0x49, 0x46, 0x4a, 0x42,
+ 0x3a, 0x43, 0x4a, 0x4b, 0x43, 0x4c, 0x68, 0x44, 0x4b, 0x52, 0x50, 0x37,
+ 0x4d, 0x4c, 0x57, 0x4c, 0x68, 0x62, 0x64, 0x4a, 0x3e, 0x64, 0x4b, 0x66,
+ 0x48, 0x4d, 0x54, 0x57, 0x4b, 0x52, 0x49, 0x5c, 0x4d, 0x55, 0x51, 0x57,
+ 0x4c, 0x3a, 0x48, 0x43, 0x3b, 0x43, 0x52, 0x5d, 0x45, 0x4e, 0x51, 0x4d,
+ 0x4a, 0x55, 0x4e, 0x4c, 0x44, 0x51, 0x4c, 0x4f, 0x41, 0x4f, 0x4a, 0x43,
+ 0x53, 0x48, 0x47, 0x49, 0x46, 0x52, 0x48, 0x3e, 0x4b, 0x4e, 0x4a, 0x50,
+ 0x4f, 0x47, 0x3e, 0x2e, 0x4b, 0x51, 0x4a, 0x44, 0x4c, 0x49, 0x4f, 0x26,
+ 0x48, 0x4f, 0x44, 0x51, 0x48, 0x3f, 0x4c, 0x30, 0x4e, 0x48, 0x4d, 0x48,
+ 0x48, 0x44, 0x4b, 0x2f, 0x50, 0x41, 0x4d, 0x50, 0x52, 0x42, 0x45, 0x33,
+ 0x4c, 0x48, 0x48, 0x3d, 0x46, 0x41, 0x43, 0x38, 0x45, 0x4f, 0x48, 0x4b,
+ 0x41, 0x49, 0x4c, 0x2f, 0x53, 0x4c, 0x48, 0x4a, 0x47, 0x40, 0x4a, 0x31,
+ 0x52, 0x40, 0x49, 0x4c, 0x3f, 0x48, 0x48, 0x39, 0x48, 0x3f, 0x45, 0x43,
+ 0x40, 0x48, 0x3c, 0x40, 0x4c, 0x48, 0x48, 0x4d, 0x3e, 0x42, 0x4a, 0x3d,
+ 0x4c, 0x45, 0x44, 0x46, 0x44, 0x45, 0x4a, 0x47, 0x52, 0x48, 0x4a, 0x4d,
+ 0x3f, 0x49, 0x4c, 0x4c, 0x48, 0x44, 0x4c, 0x44, 0x3d, 0x41, 0x47, 0x45,
+ 0x43, 0x4a, 0x5a, 0x3f, 0x48, 0x5d, 0x50, 0x35, 0x47, 0x4f, 0x5b, 0x46,
+ 0x6e, 0x50, 0x6d, 0x44, 0x49, 0x6a, 0x53, 0x6b, 0x4b, 0x4b, 0x4f, 0x62,
+ 0x45, 0x57, 0x48, 0x5b, 0x40, 0x4b, 0x4f, 0x63, 0x48, 0x3a, 0x4b, 0x42,
+ 0x43, 0x53, 0x41, 0x5f, 0x54, 0x3e, 0x4d, 0x43, 0x3d, 0x4c, 0x46, 0x46,
+ 0x49, 0x56, 0x4b, 0x45, 0x47, 0x45, 0x4e, 0x4f, 0x4c, 0x4d, 0x4f, 0x47,
+ 0x49, 0x4b, 0x51, 0x33, 0x4b, 0x45, 0x4d, 0x41, 0x51, 0x4a, 0x43, 0x2a,
+ 0x50, 0x4b, 0x4a, 0x4b, 0x4c, 0x52, 0x4c, 0x3b, 0x45, 0x4c, 0x51, 0x44,
+ 0x4c, 0x48, 0x43, 0x35, 0x51, 0x50, 0x48, 0x49, 0x3f, 0x48, 0x3d, 0x3b,
+ 0x52, 0x3f, 0x42, 0x4b, 0x49, 0x49, 0x47, 0x38, 0x4a, 0x4a, 0x41, 0x52,
+ 0x41, 0x3e, 0x4b, 0x2f, 0x46, 0x4d, 0x49, 0x44, 0x46, 0x3b, 0x47, 0x36,
+ 0x46, 0x3f, 0x49, 0x48, 0x47, 0x42, 0x42, 0x35, 0x44, 0x4b, 0x4d, 0x56,
+ 0x50, 0x49, 0x43, 0x42, 0x4b, 0x3e, 0x53, 0x44, 0x4a, 0x43, 0x47, 0x38,
+ 0x4a, 0x45, 0x4d, 0x3f, 0x46, 0x4a, 0x47, 0x3a, 0x4c, 0x3e, 0x47, 0x45,
+ 0x46, 0x4b, 0x45, 0x49, 0x4a, 0x4b, 0x54, 0x49, 0x4a, 0x53, 0x4a, 0x4c,
+ 0x45, 0x48, 0x53, 0x42, 0x4b, 0x47, 0x4e, 0x50, 0x3d, 0x51, 0x60, 0x3e,
+ 0x53, 0x5d, 0x51, 0x30, 0x45, 0x50, 0x59, 0x4e, 0x62, 0x52, 0x68, 0x51,
+ 0x45, 0x6c, 0x4c, 0x64, 0x4d, 0x47, 0x55, 0x61, 0x44, 0x57, 0x44, 0x58,
+ 0x44, 0x4a, 0x53, 0x58, 0x47, 0x31, 0x3f, 0x4c, 0x43, 0x45, 0x48, 0x5e,
+ 0x41, 0x43, 0x3f, 0x43, 0x51, 0x46, 0x48, 0x4b, 0x4d, 0x5b, 0x45, 0x4b,
+ 0x48, 0x46, 0x3f, 0x45, 0x47, 0x45, 0x40, 0x4a, 0x51, 0x51, 0x3d, 0x3f,
+ 0x43, 0x45, 0x4d, 0x4a, 0x47, 0x50, 0x49, 0x32, 0x4c, 0x5a, 0x55, 0x4f,
+ 0x4c, 0x51, 0x43, 0x37, 0x40, 0x59, 0x49, 0x49, 0x4e, 0x4f, 0x47, 0x34,
+ 0x40, 0x4c, 0x4a, 0x41, 0x4a, 0x47, 0x4a, 0x42, 0x4e, 0x4a, 0x48, 0x4e,
+ 0x4e, 0x4e, 0x45, 0x39, 0x4e, 0x45, 0x45, 0x4e, 0x4c, 0x48, 0x4a, 0x35,
+ 0x45, 0x4c, 0x49, 0x4f, 0x51, 0x43, 0x3c, 0x3a, 0x4a, 0x4a, 0x46, 0x48,
+ 0x49, 0x42, 0x4e, 0x2f, 0x42, 0x4e, 0x45, 0x50, 0x51, 0x40, 0x45, 0x32,
+ 0x4a, 0x4d, 0x44, 0x4e, 0x48, 0x48, 0x47, 0x2f, 0x48, 0x4b, 0x49, 0x44,
+ 0x48, 0x4d, 0x46, 0x3b, 0x46, 0x4a, 0x41, 0x4e, 0x4e, 0x47, 0x54, 0x4b,
+ 0x45, 0x49, 0x45, 0x44, 0x45, 0x48, 0x4a, 0x46, 0x55, 0x49, 0x47, 0x49,
+ 0x4b, 0x42, 0x48, 0x4f, 0x3f, 0x52, 0x60, 0x39, 0x4b, 0x5e, 0x55, 0x2e,
+ 0x48, 0x50, 0x59, 0x4f, 0x68, 0x5f, 0x64, 0x4f, 0x3b, 0x71, 0x50, 0x63,
+ 0x4f, 0x50, 0x50, 0x6c, 0x4b, 0x55, 0x47, 0x5b, 0x4c, 0x40, 0x48, 0x59,
+ 0x4f, 0x2e, 0x4b, 0x4c, 0x4e, 0x4e, 0x46, 0x61, 0x50, 0x41, 0x4c, 0x4a,
+ 0x44, 0x3e, 0x3f, 0x47, 0x4b, 0x4f, 0x47, 0x4b, 0x47, 0x3d, 0x41, 0x49,
+ 0x49, 0x3f, 0x4d, 0x44, 0x4a, 0x4d, 0x45, 0x41, 0x4d, 0x43, 0x49, 0x3c,
+ 0x49, 0x57, 0x49, 0x3b, 0x49, 0x59, 0x3f, 0x4f, 0x4e, 0x49, 0x4e, 0x46,
+ 0x52, 0x4e, 0x4c, 0x54, 0x4a, 0x48, 0x48, 0x3a, 0x44, 0x4a, 0x4f, 0x4a,
+ 0x44, 0x4b, 0x43, 0x4d, 0x51, 0x42, 0x53, 0x4d, 0x52, 0x41, 0x4d, 0x43,
+ 0x4e, 0x54, 0x4b, 0x42, 0x4b, 0x3f, 0x53, 0x45, 0x3f, 0x4a, 0x45, 0x50,
+ 0x3f, 0x4c, 0x4f, 0x43, 0x46, 0x42, 0x4b, 0x4d, 0x4c, 0x3b, 0x48, 0x40,
+ 0x4e, 0x4e, 0x49, 0x46, 0x4d, 0x4d, 0x52, 0x40, 0x4e, 0x4f, 0x46, 0x4a,
+ 0x40, 0x4b, 0x4c, 0x40, 0x4f, 0x4a, 0x44, 0x41, 0x46, 0x3c, 0x40, 0x3d,
+ 0x44, 0x48, 0x4a, 0x50, 0x46, 0x53, 0x46, 0x40, 0x44, 0x3e, 0x47, 0x43,
+ 0x48, 0x3d, 0x4e, 0x3e, 0x48, 0x49, 0x4b, 0x49, 0x4c, 0x3e, 0x4c, 0x4a,
+ 0x46, 0x4e, 0x62, 0x3c, 0x59, 0x60, 0x51, 0x29, 0x47, 0x52, 0x59, 0x4c,
+ 0x67, 0x68, 0x68, 0x4e, 0x3b, 0x72, 0x4d, 0x68, 0x44, 0x4f, 0x53, 0x63,
+ 0x47, 0x5a, 0x45, 0x4f, 0x4b, 0x37, 0x43, 0x5b, 0x4b, 0x3d, 0x44, 0x41,
+ 0x4a, 0x4b, 0x3c, 0x64, 0x48, 0x38, 0x42, 0x3f, 0x48, 0x46, 0x4b, 0x46,
+ 0x46, 0x4f, 0x46, 0x46, 0x44, 0x3c, 0x4b, 0x4f, 0x4d, 0x4a, 0x4b, 0x46,
+ 0x4d, 0x4f, 0x4f, 0x3f, 0x3a, 0x4b, 0x55, 0x3c, 0x51, 0x56, 0x4d, 0x42,
+ 0x52, 0x5a, 0x3e, 0x4b, 0x54, 0x57, 0x4e, 0x4d, 0x4e, 0x5b, 0x4e, 0x49,
+ 0x4e, 0x3c, 0x40, 0x41, 0x40, 0x4d, 0x48, 0x42, 0x49, 0x4e, 0x4f, 0x47,
+ 0x47, 0x48, 0x50, 0x49, 0x51, 0x46, 0x44, 0x45, 0x49, 0x46, 0x43, 0x48,
+ 0x48, 0x49, 0x4d, 0x4c, 0x45, 0x4f, 0x4c, 0x45, 0x44, 0x40, 0x49, 0x45,
+ 0x49, 0x51, 0x4b, 0x4b, 0x50, 0x4b, 0x48, 0x3d, 0x4e, 0x52, 0x4a, 0x47,
+ 0x49, 0x41, 0x55, 0x3d, 0x48, 0x4d, 0x49, 0x48, 0x4e, 0x4c, 0x48, 0x3d,
+ 0x3f, 0x4c, 0x4e, 0x53, 0x3e, 0x48, 0x4a, 0x3f, 0x54, 0x4d, 0x54, 0x4b,
+ 0x47, 0x4e, 0x44, 0x48, 0x49, 0x4b, 0x4c, 0x49, 0x4d, 0x42, 0x52, 0x4b,
+ 0x40, 0x3e, 0x54, 0x49, 0x55, 0x45, 0x47, 0x4d, 0x45, 0x5c, 0x60, 0x40,
+ 0x57, 0x60, 0x5b, 0x27, 0x4a, 0x5a, 0x64, 0x53, 0x6a, 0x5a, 0x5f, 0x52,
+ 0x3a, 0x72, 0x4b, 0x5f, 0x45, 0x56, 0x5f, 0x5f, 0x54, 0x5f, 0x39, 0x52,
+ 0x51, 0x3e, 0x3b, 0x5a, 0x44, 0x32, 0x46, 0x50, 0x3a, 0x4f, 0x44, 0x5d,
+ 0x4c, 0x41, 0x39, 0x3f, 0x45, 0x46, 0x3b, 0x43, 0x46, 0x51, 0x3c, 0x4c,
+ 0x4b, 0x43, 0x4b, 0x51, 0x43, 0x48, 0x4d, 0x43, 0x38, 0x46, 0x46, 0x43,
+ 0x44, 0x4a, 0x46, 0x49, 0x48, 0x50, 0x4e, 0x4a, 0x4e, 0x58, 0x4a, 0x49,
+ 0x48, 0x4f, 0x4a, 0x49, 0x41, 0x57, 0x51, 0x50, 0x4b, 0x48, 0x47, 0x4b,
+ 0x53, 0x3d, 0x4b, 0x4c, 0x4b, 0x4b, 0x55, 0x56, 0x45, 0x49, 0x46, 0x4c,
+ 0x45, 0x51, 0x47, 0x50, 0x40, 0x4b, 0x4f, 0x4b, 0x4d, 0x4a, 0x4f, 0x50,
+ 0x49, 0x53, 0x50, 0x46, 0x40, 0x48, 0x4a, 0x4a, 0x49, 0x4a, 0x42, 0x45,
+ 0x4b, 0x45, 0x42, 0x45, 0x4e, 0x4e, 0x44, 0x41, 0x4b, 0x4a, 0x49, 0x3f,
+ 0x41, 0x51, 0x48, 0x4c, 0x40, 0x41, 0x51, 0x42, 0x49, 0x49, 0x48, 0x42,
+ 0x48, 0x4c, 0x4b, 0x3c, 0x49, 0x45, 0x42, 0x49, 0x4c, 0x46, 0x45, 0x43,
+ 0x43, 0x48, 0x48, 0x41, 0x43, 0x42, 0x4c, 0x4b, 0x40, 0x45, 0x44, 0x46,
+ 0x4c, 0x4b, 0x4e, 0x4d, 0x3f, 0x59, 0x55, 0x41, 0x56, 0x5a, 0x51, 0x30,
+ 0x49, 0x5a, 0x63, 0x4d, 0x61, 0x5b, 0x64, 0x55, 0x34, 0x7a, 0x4c, 0x62,
+ 0x3e, 0x5d, 0x56, 0x60, 0x48, 0x61, 0x3f, 0x54, 0x46, 0x40, 0x42, 0x56,
+ 0x52, 0x35, 0x4c, 0x59, 0x45, 0x4c, 0x42, 0x60, 0x49, 0x3f, 0x4c, 0x3c,
+ 0x52, 0x36, 0x46, 0x3d, 0x58, 0x4b, 0x41, 0x48, 0x3e, 0x45, 0x4e, 0x54,
+ 0x4c, 0x56, 0x47, 0x44, 0x39, 0x4a, 0x4a, 0x4a, 0x46, 0x48, 0x4a, 0x48,
+ 0x51, 0x4f, 0x4b, 0x49, 0x45, 0x4b, 0x44, 0x4c, 0x3e, 0x4c, 0x42, 0x59,
+ 0x47, 0x55, 0x47, 0x47, 0x41, 0x44, 0x44, 0x4a, 0x44, 0x4b, 0x44, 0x46,
+ 0x49, 0x5a, 0x48, 0x5d, 0x4f, 0x4a, 0x47, 0x50, 0x48, 0x4e, 0x44, 0x57,
+ 0x49, 0x46, 0x42, 0x4d, 0x3d, 0x4a, 0x4a, 0x58, 0x41, 0x4d, 0x3c, 0x47,
+ 0x42, 0x4e, 0x4d, 0x49, 0x44, 0x4b, 0x4c, 0x4b, 0x53, 0x42, 0x4a, 0x46,
+ 0x4e, 0x56, 0x4b, 0x47, 0x50, 0x43, 0x4f, 0x48, 0x49, 0x50, 0x48, 0x50,
+ 0x42, 0x4c, 0x4e, 0x3c, 0x41, 0x4f, 0x4a, 0x41, 0x44, 0x47, 0x4c, 0x42,
+ 0x51, 0x4f, 0x53, 0x46, 0x4c, 0x4b, 0x48, 0x51, 0x47, 0x4b, 0x4c, 0x4d,
+ 0x4d, 0x49, 0x3d, 0x44, 0x4b, 0x42, 0x43, 0x49, 0x51, 0x47, 0x4c, 0x4b,
+ 0x4a, 0x50, 0x5b, 0x43, 0x5b, 0x68, 0x54, 0x31, 0x4c, 0x5d, 0x5c, 0x54,
+ 0x63, 0x5a, 0x61, 0x54, 0x3d, 0x7a, 0x51, 0x5b, 0x40, 0x59, 0x5a, 0x62,
+ 0x4c, 0x5e, 0x42, 0x58, 0x49, 0x3c, 0x38, 0x50, 0x54, 0x37, 0x42, 0x51,
+ 0x4d, 0x4f, 0x42, 0x68, 0x4a, 0x40, 0x4e, 0x40, 0x3f, 0x3e, 0x3f, 0x40,
+ 0x54, 0x52, 0x3e, 0x43, 0x46, 0x4a, 0x48, 0x51, 0x4e, 0x4d, 0x42, 0x47,
+ 0x3f, 0x51, 0x47, 0x44, 0x3f, 0x4c, 0x46, 0x47, 0x4f, 0x55, 0x4b, 0x4e,
+ 0x4c, 0x51, 0x40, 0x51, 0x47, 0x4a, 0x44, 0x5c, 0x48, 0x54, 0x4b, 0x46,
+ 0x49, 0x4b, 0x53, 0x59, 0x43, 0x3e, 0x45, 0x4e, 0x4f, 0x58, 0x4b, 0x64,
+ 0x41, 0x4b, 0x45, 0x4a, 0x4c, 0x51, 0x47, 0x57, 0x45, 0x46, 0x43, 0x4f,
+ 0x4d, 0x4d, 0x49, 0x58, 0x4b, 0x52, 0x43, 0x4b, 0x45, 0x4c, 0x50, 0x4c,
+ 0x4e, 0x4b, 0x40, 0x4c, 0x44, 0x4e, 0x4c, 0x47, 0x41, 0x55, 0x45, 0x4a,
+ 0x4c, 0x48, 0x46, 0x41, 0x47, 0x52, 0x44, 0x4f, 0x48, 0x49, 0x4b, 0x47,
+ 0x50, 0x4f, 0x42, 0x4a, 0x44, 0x4b, 0x52, 0x43, 0x45, 0x4e, 0x46, 0x49,
+ 0x45, 0x52, 0x51, 0x45, 0x44, 0x41, 0x4c, 0x46, 0x4c, 0x4b, 0x44, 0x4d,
+ 0x4f, 0x48, 0x44, 0x4d, 0x56, 0x48, 0x50, 0x4f, 0x3b, 0x4e, 0x55, 0x43,
+ 0x52, 0x62, 0x57, 0x2c, 0x4d, 0x5e, 0x5e, 0x50, 0x64, 0x5b, 0x6a, 0x55,
+ 0x39, 0x7d, 0x4b, 0x5e, 0x43, 0x54, 0x5d, 0x5c, 0x4d, 0x5c, 0x42, 0x51,
+ 0x4c, 0x3d, 0x46, 0x51, 0x4c, 0x2a, 0x3e, 0x54, 0x47, 0x48, 0x46, 0x64,
+ 0x42, 0x3d, 0x47, 0x3f, 0x42, 0x45, 0x49, 0x3b, 0x59, 0x50, 0x4c, 0x46,
+ 0x4d, 0x44, 0x47, 0x4d, 0x4a, 0x50, 0x41, 0x48, 0x43, 0x50, 0x3e, 0x44,
+ 0x4b, 0x53, 0x48, 0x49, 0x51, 0x51, 0x4d, 0x57, 0x49, 0x4f, 0x53, 0x50,
+ 0x46, 0x4f, 0x41, 0x5d, 0x47, 0x46, 0x49, 0x51, 0x45, 0x41, 0x4a, 0x56,
+ 0x4f, 0x4e, 0x4d, 0x4a, 0x3e, 0x55, 0x47, 0x65, 0x48, 0x51, 0x4d, 0x4e,
+ 0x46, 0x43, 0x48, 0x5b, 0x48, 0x4f, 0x4f, 0x48, 0x4b, 0x4d, 0x4e, 0x5c,
+ 0x4f, 0x4c, 0x54, 0x48, 0x4a, 0x4d, 0x4e, 0x4e, 0x44, 0x48, 0x43, 0x52,
+ 0x41, 0x52, 0x48, 0x4f, 0x46, 0x4f, 0x51, 0x41, 0x44, 0x45, 0x41, 0x4b,
+ 0x43, 0x4e, 0x4e, 0x42, 0x48, 0x41, 0x45, 0x43, 0x44, 0x43, 0x4c, 0x4c,
+ 0x51, 0x54, 0x4c, 0x32, 0x46, 0x52, 0x4e, 0x49, 0x40, 0x4d, 0x43, 0x4f,
+ 0x4a, 0x4d, 0x4d, 0x49, 0x46, 0x4c, 0x41, 0x4d, 0x41, 0x3a, 0x50, 0x4c,
+ 0x5a, 0x4e, 0x49, 0x53, 0x4d, 0x53, 0x53, 0x3d, 0x52, 0x64, 0x55, 0x2a,
+ 0x47, 0x5d, 0x61, 0x51, 0x5b, 0x5d, 0x66, 0x52, 0x3f, 0xfd, 0x55, 0x5a,
+ 0x4b, 0x54, 0x5b, 0x60, 0x49, 0x5d, 0x43, 0x57, 0x47, 0x41, 0x45, 0x5e,
+ 0x4c, 0x28, 0x3e, 0x40, 0x49, 0x4e, 0x40, 0x69, 0x4a, 0x44, 0x45, 0x43,
+ 0x45, 0x3d, 0x39, 0x40, 0x4c, 0x53, 0x4b, 0x3d, 0x4e, 0x43, 0x48, 0x55,
+ 0x4d, 0x50, 0x4d, 0x49, 0x4f, 0x48, 0x3e, 0x46, 0x47, 0x56, 0x40, 0x48,
+ 0x46, 0x53, 0x50, 0x5d, 0x43, 0x54, 0x49, 0x47, 0x49, 0x4c, 0x48, 0x5d,
+ 0x49, 0x51, 0x50, 0x3d, 0x41, 0x47, 0x48, 0x64, 0x4b, 0x44, 0x49, 0x41,
+ 0x54, 0x48, 0x3d, 0x6b, 0x4c, 0x5a, 0x48, 0x4e, 0x40, 0x4c, 0x52, 0x5f,
+ 0x54, 0x4a, 0x3f, 0x48, 0x43, 0x43, 0x44, 0x66, 0x49, 0x47, 0x43, 0x46,
+ 0x47, 0x54, 0x42, 0x54, 0x4b, 0x4e, 0x49, 0x49, 0x49, 0x4b, 0x52, 0x4f,
+ 0x43, 0x46, 0x4b, 0x49, 0x54, 0x4b, 0x40, 0x48, 0x47, 0x4a, 0x46, 0x47,
+ 0x44, 0x47, 0x4c, 0x37, 0x3f, 0x49, 0x45, 0x44, 0x50, 0x49, 0x44, 0x36,
+ 0x4d, 0x40, 0x45, 0x49, 0x53, 0x55, 0x44, 0x42, 0x47, 0x48, 0x46, 0x40,
+ 0x4f, 0x4c, 0x41, 0x42, 0x52, 0x3a, 0x43, 0x46, 0x55, 0x51, 0x4e, 0x4f,
+ 0x48, 0x51, 0x55, 0x48, 0x52, 0x66, 0x4e, 0x33, 0x49, 0x5b, 0x5f, 0x4b,
+ 0x5f, 0x5b, 0x66, 0x52, 0x41, 0x7c, 0x4a, 0x59, 0x47, 0x59, 0x58, 0x67,
+ 0x49, 0x5e, 0x44, 0x57, 0x49, 0x4c, 0x43, 0x56, 0x41, 0x27, 0x4c, 0x44,
+ 0x51, 0x44, 0x42, 0x65, 0x49, 0x44, 0x40, 0x3d, 0x4d, 0x3e, 0x4c, 0x3c,
+ 0x4f, 0x4b, 0x45, 0x44, 0x4d, 0x48, 0x47, 0x54, 0x4d, 0x4e, 0x44, 0x42,
+ 0x47, 0x44, 0x3d, 0x49, 0x4e, 0x50, 0x49, 0x45, 0x58, 0x4a, 0x54, 0x5c,
+ 0x41, 0x49, 0x4f, 0x42, 0x44, 0x4f, 0x4a, 0x62, 0x48, 0x50, 0x48, 0x43,
+ 0x51, 0x53, 0x47, 0x6c, 0x40, 0x46, 0x3d, 0x46, 0x4a, 0x50, 0x43, 0x69,
+ 0x49, 0x4f, 0x4a, 0x4c, 0x49, 0x46, 0x43, 0x6a, 0x48, 0x50, 0x49, 0x48,
+ 0x48, 0x51, 0x4b, 0x65, 0x42, 0x4b, 0x4d, 0x48, 0x44, 0x4e, 0x49, 0x60,
+ 0x44, 0x52, 0x42, 0x42, 0x47, 0x48, 0x4b, 0x51, 0x50, 0x4b, 0x3c, 0x4d,
+ 0x4c, 0x44, 0x48, 0x55, 0x51, 0x4c, 0x55, 0x4e, 0x52, 0x4c, 0x4b, 0x39,
+ 0x48, 0x42, 0x49, 0x49, 0x49, 0x50, 0x49, 0x32, 0x4e, 0x4b, 0x45, 0x4f,
+ 0x42, 0x4b, 0x47, 0x50, 0x48, 0x45, 0x54, 0x49, 0x4c, 0x46, 0x40, 0x46,
+ 0x43, 0x3d, 0x51, 0x44, 0x53, 0x4f, 0x54, 0x55, 0x43, 0x4f, 0x5b, 0x47,
+ 0x53, 0x6c, 0x57, 0x2e, 0x50, 0x55, 0x5a, 0x4d, 0x57, 0x5d, 0x70, 0x50,
+ 0x3f, 0x79, 0x4a, 0x5a, 0x4c, 0x58, 0x59, 0x63, 0x45, 0x69, 0x48, 0x58,
+ 0x42, 0x4b, 0x43, 0x5c, 0x46, 0x28, 0x48, 0x49, 0x4c, 0x3f, 0x45, 0x58,
+ 0x45, 0x44, 0x47, 0x40, 0x4c, 0x42, 0x3e, 0x37, 0x45, 0x54, 0x48, 0x3b,
+ 0x4e, 0x48, 0x43, 0x4a, 0x50, 0x4a, 0x49, 0x46, 0x4c, 0x54, 0x3f, 0x4b,
+ 0x4e, 0x56, 0x48, 0x49, 0x49, 0x4c, 0x51, 0x5f, 0x4d, 0x4b, 0x43, 0x4d,
+ 0x47, 0x51, 0x43, 0x59, 0x45, 0x4e, 0x4f, 0x45, 0x44, 0x54, 0x44, 0x6d,
+ 0x47, 0x51, 0x43, 0x4e, 0x4c, 0x4f, 0x43, 0x6d, 0x48, 0x53, 0x4b, 0x47,
+ 0x49, 0x48, 0x46, 0x6a, 0x51, 0x4c, 0x4d, 0x45, 0x4e, 0x47, 0x46, 0x62,
+ 0x4a, 0x54, 0x51, 0x4c, 0x47, 0x4d, 0x4a, 0x61, 0x3d, 0x50, 0x4c, 0x4c,
+ 0x45, 0x3f, 0x3e, 0x54, 0x3d, 0x53, 0x48, 0x47, 0x52, 0x4b, 0x47, 0x51,
+ 0x4f, 0x45, 0x4b, 0x4a, 0x4c, 0x46, 0x44, 0x37, 0x42, 0x50, 0x49, 0x4f,
+ 0x51, 0x41, 0x44, 0x38, 0x54, 0x40, 0x51, 0x52, 0x3e, 0x43, 0x44, 0x47,
+ 0x49, 0x4b, 0x4b, 0x46, 0x53, 0x54, 0x55, 0x4b, 0x4a, 0x37, 0x43, 0x4a,
+ 0x51, 0x47, 0x51, 0x54, 0x43, 0x46, 0x56, 0x3d, 0x54, 0x66, 0x4f, 0x30,
+ 0x45, 0x52, 0x5a, 0x43, 0x5c, 0x65, 0x5d, 0x52, 0x32, 0x77, 0x53, 0x5f,
+ 0x4a, 0x5a, 0x4f, 0x5e, 0x4e, 0x61, 0x4b, 0x5b, 0x4a, 0x53, 0x3e, 0x61,
+ 0x47, 0x24, 0x3e, 0x48, 0x4d, 0x43, 0x40, 0x53, 0x4e, 0x41, 0x43, 0x3d,
+ 0x50, 0x49, 0x41, 0x3a, 0x4e, 0x4b, 0x48, 0x49, 0x48, 0x49, 0x46, 0x50,
+ 0x4f, 0x4b, 0x47, 0x4b, 0x48, 0x52, 0x3e, 0x4d, 0x4d, 0x59, 0x4c, 0x3e,
+ 0x52, 0x49, 0x4f, 0x5e, 0x54, 0x59, 0x47, 0x4d, 0x40, 0x4c, 0x4b, 0x64,
+ 0x42, 0x4c, 0x53, 0x46, 0x4e, 0x50, 0x46, 0x6a, 0x41, 0x59, 0x44, 0x4b,
+ 0x4f, 0x44, 0x52, 0x6c, 0x54, 0x4e, 0x46, 0x48, 0x42, 0x3d, 0x44, 0x67,
+ 0x44, 0x4f, 0x47, 0x54, 0x4c, 0x4f, 0x43, 0x61, 0x4c, 0x54, 0x4f, 0x43,
+ 0x49, 0x40, 0x4a, 0x5f, 0x4a, 0x52, 0x47, 0x43, 0x4c, 0x43, 0x49, 0x53,
+ 0x4c, 0x4b, 0x43, 0x3d, 0x4e, 0x45, 0x49, 0x50, 0x44, 0x53, 0x4f, 0x48,
+ 0x4b, 0x46, 0x44, 0x3c, 0x50, 0x42, 0x43, 0x40, 0x47, 0x43, 0x42, 0x34,
+ 0x47, 0x42, 0x3f, 0x4a, 0x48, 0x42, 0x48, 0x4c, 0x42, 0x4c, 0x4e, 0x47,
+ 0x48, 0x47, 0x51, 0x51, 0x4d, 0x3d, 0x3e, 0x4b, 0x54, 0x4c, 0x4c, 0x59,
+ 0x4f, 0x50, 0x57, 0x3c, 0x54, 0x62, 0x54, 0x35, 0x3d, 0x5a, 0x5b, 0x47,
+ 0x59, 0x63, 0x66, 0x4d, 0x3c, 0x79, 0x50, 0x5f, 0x45, 0x58, 0x4e, 0x5d,
+ 0x48, 0x61, 0x43, 0x54, 0x47, 0x54, 0x4d, 0x54, 0x4b, 0x25, 0x41, 0x44,
+ 0x4c, 0x4a, 0x3b, 0x52, 0x47, 0x3c, 0x45, 0x3c, 0x53, 0x44, 0x44, 0x40,
+ 0x50, 0x4c, 0x45, 0x3a, 0x4c, 0x51, 0x44, 0x49, 0x4d, 0x52, 0x4d, 0x4b,
+ 0x45, 0x52, 0x3d, 0x50, 0x4a, 0x58, 0x4a, 0x47, 0x4d, 0x47, 0x4e, 0x52,
+ 0x4f, 0x4d, 0x4f, 0x49, 0x52, 0x52, 0x4c, 0x5e, 0x47, 0x4d, 0x46, 0x4d,
+ 0x4c, 0x48, 0x50, 0x70, 0x41, 0x4a, 0x48, 0x3d, 0x45, 0x48, 0x45, 0x74,
+ 0x47, 0x4c, 0x43, 0x4f, 0x4a, 0x4a, 0x40, 0x68, 0x52, 0x49, 0x3e, 0x3e,
+ 0x4e, 0x4b, 0x4b, 0x69, 0x42, 0x4f, 0x45, 0x47, 0x3f, 0x45, 0x46, 0x56,
+ 0x45, 0x4a, 0x47, 0x44, 0x52, 0x4b, 0x53, 0x4e, 0x4e, 0x46, 0x45, 0x40,
+ 0x47, 0x4b, 0x53, 0x52, 0x53, 0x51, 0x4f, 0x46, 0x42, 0x43, 0x50, 0x3e,
+ 0x48, 0x4e, 0x41, 0x53, 0x4d, 0x48, 0x48, 0x33, 0x40, 0x43, 0x4b, 0x42,
+ 0x52, 0x4c, 0x42, 0x4e, 0x41, 0x4e, 0x4f, 0x50, 0x43, 0x49, 0x4d, 0x47,
+ 0x4a, 0x3a, 0x3f, 0x51, 0x51, 0x44, 0x4e, 0x54, 0x40, 0x55, 0x59, 0x3c,
+ 0x57, 0x67, 0x4e, 0x2e, 0x4c, 0x5b, 0x5b, 0x51, 0x58, 0x63, 0x62, 0x52,
+ 0x3c, 0x72, 0x51, 0x5a, 0x4e, 0x53, 0x4a, 0x5c, 0x51, 0x69, 0x42, 0x51,
+ 0x48, 0x54, 0x48, 0x57, 0x3e, 0x37, 0x3f, 0x4d, 0x4d, 0x4a, 0x35, 0x57,
+ 0x4e, 0x40, 0x45, 0x4a, 0x45, 0x4e, 0x49, 0x40, 0x49, 0x53, 0x51, 0x44,
+ 0x4a, 0x50, 0x4b, 0x4b, 0x50, 0x4f, 0x3e, 0x44, 0x45, 0x44, 0x4c, 0x51,
+ 0x47, 0x51, 0x46, 0x42, 0x48, 0x50, 0x49, 0x4d, 0x43, 0x54, 0x52, 0x4d,
+ 0x4e, 0x4f, 0x3f, 0x63, 0x54, 0x57, 0x41, 0x44, 0x4e, 0x50, 0x4e, 0x66,
+ 0x41, 0x53, 0x4b, 0x4d, 0x4e, 0x4f, 0x43, 0x6d, 0x4e, 0x51, 0x49, 0x4f,
+ 0x49, 0x4a, 0x4a, 0x6c, 0x4b, 0x4f, 0x3d, 0x47, 0x4d, 0x51, 0x3c, 0x66,
+ 0x4b, 0x56, 0x3e, 0x4c, 0x41, 0x46, 0x45, 0x68, 0x47, 0x4b, 0x4a, 0x54,
+ 0x53, 0x48, 0x51, 0x59, 0x45, 0x43, 0x50, 0x45, 0x4f, 0x45, 0x42, 0x55,
+ 0x48, 0x52, 0x4c, 0x46, 0x52, 0x49, 0x47, 0x3d, 0x55, 0x48, 0x52, 0x52,
+ 0x40, 0x4e, 0x47, 0x31, 0x45, 0x4f, 0x42, 0x4a, 0x4e, 0x50, 0x42, 0x4a,
+ 0x49, 0x57, 0x46, 0x4b, 0x45, 0x4e, 0x4d, 0x46, 0x47, 0x43, 0x50, 0x4e,
+ 0x4f, 0x4c, 0x53, 0x55, 0x45, 0x51, 0x5b, 0x3a, 0x52, 0x64, 0x54, 0x2d,
+ 0x42, 0x59, 0x59, 0x45, 0x59, 0x67, 0x69, 0x53, 0x3f, 0x78, 0x50, 0x60,
+ 0x4c, 0x4c, 0x5b, 0x53, 0x45, 0x63, 0x49, 0x63, 0x51, 0x4c, 0x41, 0x4e,
+ 0x4b, 0x37, 0x45, 0x4e, 0x48, 0x4c, 0x39, 0x55, 0x44, 0x37, 0x3c, 0x49,
+ 0x44, 0x56, 0x3e, 0x40, 0x4d, 0x45, 0x4c, 0x43, 0x42, 0x41, 0x40, 0x42,
+ 0x57, 0x4f, 0x43, 0x3f, 0x52, 0x53, 0x51, 0x4b, 0x4b, 0x55, 0x46, 0x40,
+ 0x49, 0x45, 0x40, 0x4f, 0x47, 0x58, 0x4b, 0x53, 0x4e, 0x52, 0x54, 0x5e,
+ 0x4b, 0x51, 0x50, 0x44, 0x50, 0x4b, 0x4f, 0x70, 0x49, 0x4f, 0x4c, 0x50,
+ 0x45, 0x56, 0x4b, 0x6b, 0x49, 0x52, 0x4a, 0x3f, 0x44, 0x4b, 0x48, 0x72,
+ 0x4c, 0x47, 0x4e, 0x43, 0x46, 0x4c, 0x4f, 0x61, 0x4a, 0x52, 0x52, 0x46,
+ 0x4a, 0x4d, 0x46, 0x65, 0x48, 0x4e, 0x4d, 0x4e, 0x46, 0x4e, 0x53, 0x59,
+ 0x43, 0x49, 0x43, 0x47, 0x45, 0x47, 0x53, 0x50, 0x3e, 0x4d, 0x41, 0x46,
+ 0x4c, 0x4a, 0x4c, 0x35, 0x3f, 0x4f, 0x50, 0x48, 0x47, 0x4d, 0x4c, 0x32,
+ 0x45, 0x53, 0x43, 0x4d, 0x4e, 0x4a, 0x3e, 0x4b, 0x55, 0x4f, 0x53, 0x4c,
+ 0x4a, 0x4d, 0x48, 0x53, 0x4f, 0x3a, 0x47, 0x4b, 0x4e, 0x4e, 0x51, 0x59,
+ 0x41, 0x50, 0x57, 0x38, 0x5d, 0x63, 0x59, 0x2b, 0x45, 0x53, 0x5a, 0x4e,
+ 0x5c, 0x60, 0x5e, 0x4c, 0x41, 0x6f, 0x53, 0x5c, 0x48, 0x53, 0x56, 0x54,
+ 0x4b, 0x62, 0x46, 0x63, 0x47, 0x4e, 0x40, 0x51, 0x43, 0x36, 0x44, 0x42,
+ 0x46, 0x51, 0x41, 0x54, 0x4e, 0x36, 0x40, 0x4b, 0x55, 0x49, 0x40, 0x3f,
+ 0x4b, 0x42, 0x4a, 0x4a, 0x48, 0x47, 0x40, 0x43, 0x4d, 0x4f, 0x55, 0x3f,
+ 0x53, 0x42, 0x4d, 0x56, 0x49, 0x51, 0x4f, 0x41, 0x3b, 0x48, 0x43, 0x4e,
+ 0x4b, 0x5c, 0x4f, 0x45, 0x4a, 0x4c, 0x46, 0x66, 0x43, 0x45, 0x46, 0x48,
+ 0x4f, 0x4e, 0x40, 0x71, 0x4b, 0x4e, 0x3e, 0x42, 0x4d, 0x52, 0x42, 0x71,
+ 0x4c, 0x54, 0x4f, 0x3f, 0x4c, 0x43, 0x4a, 0x73, 0x48, 0x48, 0x4c, 0x4b,
+ 0x4c, 0x4d, 0x40, 0x72, 0x3e, 0x51, 0x49, 0x48, 0x52, 0x53, 0x45, 0x65,
+ 0x52, 0x4e, 0x4f, 0x44, 0x4c, 0x43, 0x4a, 0x5e, 0x3e, 0x56, 0x46, 0x55,
+ 0x55, 0x43, 0x49, 0x51, 0x4f, 0x52, 0x49, 0x4d, 0x46, 0x47, 0x49, 0x3e,
+ 0x51, 0x49, 0x41, 0x53, 0x42, 0x47, 0x46, 0x3b, 0x4d, 0x4e, 0x48, 0x44,
+ 0x42, 0x48, 0x4c, 0x47, 0x42, 0x4e, 0x4a, 0x3e, 0x44, 0x54, 0x4a, 0x4d,
+ 0x49, 0x41, 0x41, 0x53, 0x52, 0x4c, 0x4c, 0x56, 0x49, 0x4a, 0x5a, 0x3f,
+ 0x5b, 0x5c, 0x59, 0x2f, 0x49, 0x52, 0x5a, 0x4e, 0x5a, 0x61, 0x67, 0x4c,
+ 0x41, 0x6f, 0x5a, 0x5a, 0x40, 0x5a, 0x54, 0x4e, 0x49, 0x66, 0x45, 0x5a,
+ 0x4a, 0x45, 0x44, 0x4b, 0x44, 0x36, 0x41, 0x4c, 0x45, 0x44, 0x3d, 0x51,
+ 0x3f, 0x35, 0x3c, 0x46, 0x53, 0x5c, 0x3f, 0x3e, 0x50, 0x43, 0x46, 0x4b,
+ 0x40, 0x54, 0x41, 0x47, 0x4b, 0x51, 0x41, 0x46, 0x4a, 0x4d, 0x51, 0x52,
+ 0x43, 0x58, 0x45, 0x46, 0x4e, 0x46, 0x4a, 0x4b, 0x44, 0x54, 0x4c, 0x4c,
+ 0x43, 0x59, 0x48, 0x61, 0x4e, 0x4f, 0x4d, 0x4d, 0x4a, 0x52, 0x4c, 0x6e,
+ 0x49, 0x57, 0x48, 0x4d, 0x46, 0x46, 0x4d, 0x72, 0x4a, 0x4e, 0x47, 0x44,
+ 0x49, 0x4f, 0x48, 0x73, 0x42, 0x40, 0x4d, 0x44, 0x4d, 0x57, 0x3e, 0x69,
+ 0x50, 0x52, 0x4c, 0x55, 0x46, 0x4c, 0x44, 0x5f, 0x4b, 0x4d, 0x55, 0x4c,
+ 0x48, 0x49, 0x4a, 0x5e, 0x47, 0x4b, 0x45, 0x53, 0x55, 0x53, 0x4d, 0x53,
+ 0x47, 0x5c, 0x45, 0x4e, 0x4e, 0x52, 0x4c, 0x39, 0x4b, 0x4c, 0x49, 0x46,
+ 0x4a, 0x4e, 0x4b, 0x33, 0x46, 0x47, 0x52, 0x41, 0x49, 0x4b, 0x4c, 0x48,
+ 0x51, 0x53, 0x44, 0x4c, 0x4a, 0x45, 0x46, 0x49, 0x49, 0x4b, 0x50, 0x47,
+ 0x4d, 0x4b, 0x4c, 0x4f, 0x44, 0x45, 0x58, 0x3c, 0x56, 0x5a, 0x56, 0x23,
+ 0x4f, 0x4d, 0x5c, 0x4e, 0x59, 0x5a, 0x65, 0x43, 0x45, 0x66, 0x54, 0x5f,
+ 0x45, 0x5e, 0x54, 0x4f, 0x48, 0x5f, 0x44, 0x59, 0x48, 0x46, 0x47, 0x49,
+ 0x4d, 0x3c, 0x49, 0x54, 0x3e, 0x48, 0x43, 0x5b, 0x4a, 0x35, 0x41, 0x43,
+ 0x4b, 0x55, 0x43, 0x38, 0x46, 0x42, 0x4a, 0x4e, 0x54, 0x4b, 0x4d, 0x46,
+ 0x43, 0x4e, 0x44, 0x47, 0x56, 0x4c, 0x51, 0x57, 0x41, 0x4d, 0x43, 0x41,
+ 0x51, 0x47, 0x41, 0x51, 0x51, 0x4f, 0x46, 0x50, 0x52, 0x4e, 0x4d, 0x60,
+ 0x41, 0x49, 0x46, 0x50, 0x48, 0x56, 0x42, 0x6d, 0x40, 0x45, 0x44, 0x55,
+ 0x40, 0x4e, 0x40, 0x7c, 0x47, 0x5a, 0x44, 0x44, 0x45, 0x56, 0x55, 0x71,
+ 0x47, 0x4b, 0x4b, 0x45, 0x4f, 0x54, 0x4c, 0x73, 0x48, 0x55, 0x44, 0x4d,
+ 0x4a, 0x47, 0x49, 0x5e, 0x4d, 0x52, 0x4e, 0x4c, 0x48, 0x52, 0x48, 0x58,
+ 0x4c, 0x5a, 0x49, 0x4b, 0x53, 0x46, 0x4d, 0x4b, 0x48, 0x53, 0x41, 0x49,
+ 0x4a, 0x56, 0x51, 0x3a, 0x4c, 0x4e, 0x4f, 0x51, 0x4c, 0x59, 0x47, 0x45,
+ 0x4f, 0x50, 0x4a, 0x4f, 0x4d, 0x3f, 0x44, 0x4e, 0x42, 0x4a, 0x4a, 0x43,
+ 0x46, 0x4e, 0x4c, 0x4f, 0x47, 0x47, 0x4c, 0x4b, 0x52, 0x50, 0x50, 0x4b,
+ 0x42, 0x45, 0x54, 0x44, 0x54, 0x59, 0x4c, 0x2b, 0x4d, 0x4c, 0x55, 0x4e,
+ 0x5c, 0x5b, 0x5a, 0x42, 0x47, 0x5e, 0x56, 0x59, 0x47, 0x65, 0x55, 0x4c,
+ 0x4c, 0x59, 0x42, 0x5a, 0x4e, 0x46, 0x4e, 0x4b, 0x53, 0x46, 0x49, 0x56,
+ 0x48, 0x58, 0x4b, 0x4f, 0x45, 0x38, 0x40, 0x44, 0x49, 0x51, 0x4a, 0x3b,
+ 0x53, 0x40, 0x40, 0x48, 0x51, 0x49, 0x44, 0x46, 0x52, 0x4b, 0x4e, 0x45,
+ 0x48, 0x5a, 0x4e, 0x57, 0x44, 0x53, 0x49, 0x40, 0x4c, 0x47, 0x41, 0x4f,
+ 0x49, 0x55, 0x46, 0x50, 0x57, 0x5b, 0x48, 0x66, 0x50, 0x49, 0x51, 0x55,
+ 0x55, 0x4f, 0x47, 0x72, 0x49, 0x4f, 0x41, 0x4c, 0x49, 0x42, 0x48, 0x75,
+ 0x4a, 0x55, 0x45, 0x4a, 0x41, 0x51, 0x41, 0x70, 0x47, 0x49, 0x42, 0x52,
+ 0x4f, 0x47, 0x46, 0x63, 0x4f, 0x53, 0x46, 0x4f, 0x49, 0x53, 0x52, 0x63,
+ 0x4c, 0x59, 0x46, 0x41, 0x49, 0x51, 0x3e, 0x53, 0x45, 0x52, 0x51, 0x40,
+ 0x4f, 0x4c, 0x41, 0x4c, 0x47, 0x4a, 0x46, 0x47, 0x53, 0x47, 0x48, 0x39,
+ 0x53, 0x4b, 0x46, 0x4b, 0x50, 0x4c, 0x41, 0x40, 0x48, 0x4e, 0x49, 0x4e,
+ 0x44, 0x53, 0x44, 0x4e, 0x53, 0x49, 0x49, 0x4e, 0x46, 0x3f, 0x45, 0x42,
+ 0x4c, 0x47, 0x42, 0x4e, 0x49, 0x4a, 0x49, 0x44, 0x51, 0x48, 0x57, 0x4c,
+ 0x4d, 0x60, 0x4e, 0x2d, 0x46, 0x4d, 0x58, 0x53, 0x5c, 0x56, 0x5e, 0x41,
+ 0x3e, 0x66, 0x53, 0x5b, 0x49, 0x59, 0x5a, 0x55, 0x4e, 0x59, 0x46, 0x4a,
+ 0x44, 0x42, 0x45, 0x3d, 0x4d, 0x45, 0x44, 0x4f, 0x4d, 0x53, 0x42, 0x5a,
+ 0x43, 0x3c, 0x48, 0x4f, 0x44, 0x59, 0x3f, 0x33, 0x45, 0x48, 0x43, 0x45,
+ 0x4d, 0x56, 0x48, 0x44, 0x3e, 0x48, 0x46, 0x4d, 0x44, 0x53, 0x46, 0x4e,
+ 0x45, 0x52, 0x40, 0x46, 0x4c, 0x50, 0x4e, 0x4b, 0x4d, 0x46, 0x48, 0x46,
+ 0x50, 0x52, 0x4e, 0x57, 0x3f, 0x4a, 0x49, 0x50, 0x53, 0x4e, 0x41, 0x66,
+ 0x49, 0x4f, 0x40, 0x4b, 0x50, 0x4c, 0x4a, 0x70, 0x42, 0x51, 0x41, 0x4c,
+ 0x50, 0x4f, 0x46, 0x60, 0x45, 0x47, 0x54, 0x4c, 0x49, 0x59, 0x52, 0x61,
+ 0x4a, 0x53, 0x52, 0x4f, 0x4b, 0x4c, 0x46, 0x56, 0x4b, 0x54, 0x4f, 0x47,
+ 0x53, 0x49, 0x4f, 0x50, 0x4a, 0x54, 0x45, 0x4e, 0x47, 0x48, 0x47, 0x42,
+ 0x49, 0x44, 0x46, 0x46, 0x55, 0x4c, 0x4f, 0x36, 0x4c, 0x49, 0x3f, 0x4e,
+ 0x45, 0x4b, 0x4b, 0x36, 0x48, 0x4f, 0x4b, 0x50, 0x45, 0x47, 0x49, 0x3f,
+ 0x50, 0x4b, 0x52, 0x48, 0x4c, 0x41, 0x49, 0x43, 0x4e, 0x3c, 0x43, 0x45,
+ 0x3e, 0x45, 0x48, 0x44, 0x4d, 0x48, 0x56, 0x47, 0x4b, 0x54, 0x52, 0x2b,
+ 0x4d, 0x4e, 0x57, 0x4f, 0x57, 0x4f, 0x56, 0x43, 0x48, 0x5f, 0x4c, 0x51,
+ 0x4d, 0x58, 0x4f, 0x4e, 0x50, 0x50, 0x48, 0x4a, 0x4d, 0x3f, 0x47, 0x40,
+ 0x4b, 0x4a, 0x4e, 0x4b, 0x4a, 0x58, 0x42, 0x49, 0x3f, 0x42, 0x3d, 0x4d,
+ 0x46, 0x53, 0x45, 0x3e, 0x4e, 0x49, 0x4f, 0x4a, 0x47, 0x46, 0x40, 0x3e,
+ 0x4c, 0x4d, 0x4d, 0x45, 0x4a, 0x56, 0x40, 0x4a, 0x47, 0x57, 0x4f, 0x48,
+ 0x4f, 0x48, 0x47, 0x49, 0x4e, 0x52, 0x50, 0x48, 0x42, 0x52, 0x43, 0x5a,
+ 0x49, 0x42, 0x4f, 0x4f, 0x51, 0x51, 0x50, 0x5c, 0x4b, 0x43, 0x4b, 0x48,
+ 0x50, 0x51, 0x4b, 0x6d, 0x53, 0x4e, 0x44, 0x4c, 0x4c, 0x51, 0x46, 0x5b,
+ 0x44, 0x48, 0x4d, 0x4c, 0x46, 0x4f, 0x54, 0x54, 0x4e, 0x54, 0x42, 0x4e,
+ 0x4c, 0x49, 0x49, 0x58, 0x49, 0x53, 0x53, 0x4a, 0x4e, 0x4b, 0x47, 0x53,
+ 0x43, 0x55, 0x46, 0x51, 0x3d, 0x3d, 0x4c, 0x47, 0x4e, 0x51, 0x47, 0x48,
+ 0x4b, 0x4c, 0x42, 0x3b, 0x43, 0x4f, 0x44, 0x4d, 0x54, 0x4b, 0x4a, 0x47,
+ 0x4c, 0x42, 0x4b, 0x43, 0x41, 0x4e, 0x4d, 0x50, 0x45, 0x46, 0x41, 0x4a,
+ 0x49, 0x49, 0x54, 0x47, 0x4c, 0x4b, 0x50, 0x4e, 0x3f, 0x43, 0x40, 0x41,
+ 0x44, 0x54, 0x51, 0x47, 0x4c, 0x4b, 0x4f, 0x34, 0x4d, 0x4c, 0x4f, 0x49,
+ 0x56, 0x4e, 0x4b, 0x3e, 0x48, 0x53, 0x4e, 0x56, 0x49, 0x4e, 0x4c, 0x40,
+ 0x55, 0x4a, 0x46, 0x4f, 0x48, 0x4a, 0x55, 0x41, 0x55, 0x3d, 0x47, 0x51,
+ 0x50, 0x51, 0x45, 0x51, 0x4b, 0x4e, 0x4a, 0x4f, 0x4b, 0x45, 0x42, 0x3c,
+ 0x4e, 0x46, 0x47, 0x49, 0x4a, 0x4c, 0x48, 0x41, 0x4f, 0x4a, 0x44, 0x45,
+ 0x4e, 0x4e, 0x43, 0x41, 0x4c, 0x47, 0x48, 0x49, 0x4c, 0x48, 0x4f, 0x4a,
+ 0x4f, 0x4a, 0x4b, 0x45, 0x42, 0x40, 0x52, 0x55, 0x4f, 0x49, 0x44, 0x54,
+ 0x49, 0x48, 0x51, 0x4d, 0x44, 0x4a, 0x4d, 0x49, 0x4e, 0x4e, 0x51, 0x5d,
+ 0x42, 0x4d, 0x49, 0x3f, 0x48, 0x58, 0x40, 0x5e, 0x48, 0x4f, 0x49, 0x53,
+ 0x45, 0x47, 0x4f, 0x53, 0x4d, 0x4f, 0x4d, 0x4d, 0x46, 0x55, 0x43, 0x51,
+ 0x4f, 0x51, 0x4a, 0x4e, 0x49, 0x42, 0x49, 0x50, 0x47, 0x4d, 0x42, 0x47,
+ 0x46, 0x50, 0x55, 0x47, 0x4d, 0x47, 0x3e, 0x51, 0x4d, 0x43, 0x44, 0x39,
+ 0x4e, 0x4b, 0x41, 0x48, 0x52, 0x53, 0x4d, 0x39, 0x4d, 0x51, 0x4c, 0x46,
+ 0x4e, 0x47, 0x49, 0x41, 0x45, 0x4a, 0x4a, 0x45, 0x50, 0x4a, 0x40, 0x48,
+ 0x43, 0x47, 0x44, 0x50, 0x4d, 0x47, 0x4a, 0x47, 0x45, 0x57, 0x41, 0x34,
+ 0x51, 0x40, 0x45, 0x44, 0x3c, 0x47, 0x46, 0x47, 0x44, 0x48, 0x42, 0x40,
+ 0x37, 0x53, 0x4a, 0x43, 0x49, 0x4b, 0x43, 0x44, 0x4f, 0x4f, 0x48, 0x48,
+ 0x53, 0x49, 0x4b, 0x48, 0x4e, 0x4c, 0x42, 0x45, 0x4c, 0x4a, 0x4a, 0x46,
+ 0x47, 0x57, 0x3e, 0x46, 0x46, 0x45, 0x4a, 0x43, 0x46, 0x49, 0x43, 0x52,
+ 0x3e, 0x48, 0x4a, 0x4b, 0x47, 0x47, 0x48, 0x4a, 0x4b, 0x4b, 0x4e, 0x44,
+ 0x42, 0x44, 0x50, 0x41, 0x49, 0x49, 0x4d, 0x4b, 0x44, 0x46, 0x4a, 0x52,
+ 0x4d, 0x47, 0x49, 0x4b, 0x4d, 0x49, 0x41, 0x48, 0x4b, 0x3f, 0x45, 0x4f,
+ 0x51, 0x41, 0x55, 0x42, 0x49, 0x4b, 0x4b, 0x51, 0x4f, 0x4f, 0x42, 0x4e,
+ 0x4e, 0x4a, 0x52, 0x41, 0x4f, 0x42, 0x48, 0x3d, 0x4a, 0x44, 0x50, 0x4b,
+ 0x49, 0x45, 0x51, 0x46, 0x51, 0x44, 0x4d, 0x47, 0x4a, 0x4a, 0x4d, 0x49,
+ 0x4d, 0x48, 0x4d, 0x4f, 0x4d, 0x44, 0x48, 0x4e, 0x4a, 0x4b, 0x40, 0x4f,
+ 0x47, 0x3a, 0x41, 0x47, 0x4a, 0x4a, 0x4a, 0x48, 0x42, 0x41, 0x4d, 0x56,
+ 0x3f, 0x52, 0x4d, 0x4c, 0x44, 0x48, 0x47, 0x4e, 0x51, 0x4c, 0x49, 0x47,
+ 0x44, 0x4c, 0x4b, 0x47, 0x48, 0x46, 0x47, 0x4f, 0x43, 0x41, 0x3e, 0x47,
+ 0x53, 0x4a, 0x46, 0x42, 0x46, 0x61, 0x43, 0x30, 0x4e, 0x52, 0x43, 0x45,
+ 0x32, 0x4a, 0x45, 0x48, 0x51, 0x3e, 0x44, 0x3b, 0x3a, 0x63, 0x4c, 0x46,
+ 0x4c, 0x49, 0x3d, 0x41, 0x52, 0x53, 0x43, 0x43, 0x45, 0x3d, 0x48, 0x40,
+ 0x4b, 0x4a, 0x49, 0x48, 0x4d, 0x49, 0x4b, 0x4c, 0x3f, 0x4e, 0x4b, 0x47,
+ 0x45, 0x4d, 0x3f, 0x4d, 0x43, 0x50, 0x48, 0x4b, 0x54, 0x3e, 0x44, 0x4e,
+ 0x3e, 0x4c, 0x43, 0x4b, 0x4c, 0x4b, 0x3e, 0x49, 0x50, 0x52, 0x4a, 0x4a,
+ 0x50, 0x50, 0x43, 0x4e, 0x49, 0x48, 0x51, 0x50, 0x47, 0x3d, 0x45, 0x4b,
+ 0x47, 0x46, 0x4d, 0x4c, 0x45, 0x4d, 0x4a, 0x4d, 0x42, 0x4d, 0x47, 0x4f,
+ 0x40, 0x43, 0x46, 0x51, 0x47, 0x4b, 0x43, 0x49, 0x49, 0x50, 0x4b, 0x4b,
+ 0x46, 0x4a, 0x4c, 0x48, 0x49, 0x47, 0x4b, 0x56, 0x55, 0x4f, 0x49, 0x4f,
+ 0x4f, 0x4e, 0x4b, 0x49, 0x4a, 0x4a, 0x49, 0x47, 0x44, 0x4b, 0x47, 0x50,
+ 0x46, 0x4c, 0x46, 0x4c, 0x4b, 0x4e, 0x49, 0x57, 0x4d, 0x3e, 0x46, 0x47,
+ 0x50, 0x45, 0x4f, 0x52, 0x3e, 0x4d, 0x49, 0x4a, 0x40, 0x49, 0x4f, 0x5c,
+ 0x3e, 0x4a, 0x47, 0x45, 0x47, 0x41, 0x44, 0x3f, 0x4b, 0x4a, 0x52, 0x43,
+ 0x41, 0x43, 0x43, 0x47, 0x55, 0x49, 0x42, 0x4c, 0x58, 0x4b, 0x42, 0x48,
+ 0x4b, 0x5a, 0x36, 0x33, 0x53, 0x57, 0x4d, 0x4a, 0x37, 0x4c, 0x3e, 0x48,
+ 0x43, 0x46, 0x39, 0x3c, 0x34, 0x65, 0x47, 0x3d, 0x47, 0x42, 0x3c, 0x3e,
+ 0x45, 0x5b, 0x44, 0x3e, 0x45, 0x43, 0x46, 0x43, 0x59, 0x4e, 0x48, 0x46,
+ 0x43, 0x3f, 0x46, 0x47, 0x4e, 0x53, 0x50, 0x4b, 0x4a, 0x3f, 0x4a, 0x54,
+ 0x4c, 0x4a, 0x43, 0x50, 0x4c, 0x42, 0x4d, 0x55, 0x4d, 0x51, 0x51, 0x46,
+ 0x49, 0x41, 0x50, 0x44, 0x4a, 0x4b, 0x4b, 0x43, 0x4b, 0x4e, 0x47, 0x4b,
+ 0x3e, 0x4e, 0x44, 0x4d, 0x49, 0x41, 0x49, 0x44, 0x50, 0x4d, 0x45, 0x4e,
+ 0x4b, 0x50, 0x45, 0x4c, 0x46, 0x4a, 0x46, 0x42, 0x50, 0x45, 0x48, 0x53,
+ 0x4d, 0x44, 0x42, 0x50, 0x4c, 0x49, 0x45, 0x55, 0x4d, 0x42, 0x43, 0x41,
+ 0x4c, 0x41, 0x4e, 0x4d, 0x42, 0x4e, 0x3f, 0x44, 0x4d, 0x4c, 0x4b, 0x4a,
+ 0x47, 0x47, 0x4e, 0x54, 0x43, 0x40, 0x41, 0x55, 0x49, 0x49, 0x4e, 0x49,
+ 0x52, 0x4e, 0x46, 0x58, 0x4b, 0x3d, 0x4a, 0x44, 0x4e, 0x47, 0x53, 0x58,
+ 0x47, 0x42, 0x52, 0x46, 0x49, 0x4b, 0x47, 0x5a, 0x4c, 0x46, 0x46, 0x49,
+ 0x4b, 0x4d, 0x3d, 0x48, 0x40, 0x54, 0x48, 0x4c, 0x4c, 0x44, 0x4c, 0x46,
+ 0x47, 0x4b, 0x4d, 0x44, 0x5a, 0x4a, 0x3e, 0x46, 0x48, 0x53, 0x39, 0x30,
+ 0x51, 0x60, 0x4d, 0x47, 0x35, 0x4f, 0x45, 0x45, 0x4a, 0x4b, 0x42, 0x3f,
+ 0x38, 0x6c, 0x3d, 0x40, 0x44, 0x48, 0x3a, 0x3b, 0x46, 0x5e, 0x45, 0x3b,
+ 0x47, 0x47, 0x45, 0x42, 0x53, 0x55, 0x44, 0x45, 0x46, 0x43, 0x48, 0x48,
+ 0x52, 0x5d, 0x3e, 0x41, 0x53, 0x42, 0x48, 0x55, 0x49, 0x4d, 0x4a, 0x46,
+ 0x52, 0x46, 0x51, 0x48, 0x44, 0x46, 0x48, 0x41, 0x49, 0x49, 0x49, 0x49,
+ 0x41, 0x4d, 0x40, 0x4f, 0x45, 0x46, 0x45, 0x3f, 0x53, 0x40, 0x46, 0x43,
+ 0x47, 0x4d, 0x50, 0x4c, 0x55, 0x48, 0x45, 0x47, 0x4f, 0x46, 0x42, 0x4d,
+ 0x41, 0x48, 0x46, 0x4e, 0x42, 0x48, 0x48, 0x45, 0x41, 0x45, 0x48, 0x4a,
+ 0x40, 0x49, 0x43, 0x4b, 0x48, 0x4a, 0x4c, 0x45, 0x4b, 0x48, 0x48, 0x4f,
+ 0x40, 0x4b, 0x4a, 0x44, 0x50, 0x4a, 0x43, 0x50, 0x4c, 0x44, 0x46, 0x4c,
+ 0x42, 0x44, 0x4e, 0x55, 0x47, 0x49, 0x48, 0x47, 0x52, 0x4e, 0x44, 0x59,
+ 0x4e, 0x44, 0x4a, 0x48, 0x49, 0x4a, 0x42, 0x4e, 0x3e, 0x39, 0x51, 0x45,
+ 0x4d, 0x49, 0x4f, 0x54, 0x51, 0x4b, 0x50, 0x44, 0x53, 0x4f, 0x4d, 0x48,
+ 0x42, 0x45, 0x4e, 0x40, 0x4a, 0x48, 0x43, 0x48, 0x52, 0x54, 0x4d, 0x49,
+ 0x5f, 0x53, 0x46, 0x4e, 0x3f, 0x5a, 0x36, 0x31, 0x52, 0x60, 0x4b, 0x4a,
+ 0x32, 0x51, 0x40, 0x44, 0x46, 0x52, 0x44, 0x41, 0x3a, 0x6e, 0x41, 0x3e,
+ 0x47, 0x3e, 0x3a, 0x2a, 0x44, 0x5a, 0x40, 0x3c, 0x4d, 0x48, 0x46, 0x3b,
+ 0x5e, 0x58, 0x4d, 0x47, 0x51, 0x3a, 0x4b, 0x48, 0x5b, 0x5a, 0x54, 0x43,
+ 0x50, 0x4c, 0x54, 0x54, 0x49, 0x47, 0x4f, 0x48, 0x50, 0x40, 0x4f, 0x4a,
+ 0x42, 0x42, 0x3c, 0x41, 0x43, 0x4e, 0x53, 0x49, 0x4b, 0x4d, 0x49, 0x41,
+ 0x4c, 0x3e, 0x40, 0x49, 0x40, 0x44, 0x49, 0x4f, 0x50, 0x4a, 0x42, 0x3a,
+ 0x49, 0x4b, 0x47, 0x50, 0x49, 0x41, 0x52, 0x46, 0x3d, 0x44, 0x46, 0x43,
+ 0x4b, 0x4b, 0x4d, 0x4b, 0x4e, 0x40, 0x45, 0x43, 0x48, 0x44, 0x55, 0x51,
+ 0x4a, 0x46, 0x4e, 0x40, 0x53, 0x4a, 0x45, 0x41, 0x48, 0x48, 0x45, 0x4e,
+ 0x4a, 0x48, 0x40, 0x4c, 0x54, 0x44, 0x42, 0x4d, 0x49, 0x43, 0x45, 0x4c,
+ 0x43, 0x4f, 0x46, 0x3f, 0x46, 0x4f, 0x4b, 0x59, 0x46, 0x49, 0x54, 0x47,
+ 0x49, 0x46, 0x45, 0x53, 0x4a, 0x49, 0x54, 0x45, 0x41, 0x45, 0x4c, 0x5e,
+ 0x50, 0x3d, 0x4d, 0x49, 0x55, 0x4b, 0x49, 0x47, 0x4c, 0x4f, 0x43, 0x3d,
+ 0x41, 0x4b, 0x43, 0x46, 0x4f, 0x4a, 0x4c, 0x54, 0x5e, 0x4e, 0x40, 0x4d,
+ 0x3d, 0x59, 0x40, 0x28, 0x54, 0x5f, 0x4d, 0x4b, 0x36, 0x51, 0x3a, 0x47,
+ 0x4a, 0x55, 0x42, 0x43, 0x3b, 0x72, 0x3b, 0x3d, 0x51, 0x42, 0x3f, 0x2d,
+ 0x4b, 0x5a, 0x48, 0x44, 0x49, 0x49, 0x3d, 0x39, 0x56, 0x55, 0x46, 0x46,
+ 0x4b, 0x43, 0x40, 0x4a, 0x52, 0x56, 0x4d, 0x45, 0x4b, 0x48, 0x40, 0x5a,
+ 0x4e, 0x3a, 0x53, 0x48, 0x4c, 0x44, 0x49, 0x4e, 0x42, 0x47, 0x46, 0x40,
+ 0x51, 0x42, 0x50, 0x4b, 0x43, 0x53, 0x44, 0x44, 0x46, 0x4c, 0x4c, 0x3c,
+ 0x42, 0x45, 0x42, 0x45, 0x44, 0x4b, 0x52, 0x3d, 0x47, 0x4b, 0x4c, 0x4e,
+ 0x52, 0x4a, 0x4e, 0x41, 0x3f, 0x46, 0x43, 0x54, 0x44, 0x53, 0x4e, 0x48,
+ 0x40, 0x41, 0x4f, 0x45, 0x43, 0x3c, 0x52, 0x49, 0x40, 0x44, 0x4a, 0x3f,
+ 0x4d, 0x4c, 0x4f, 0x47, 0x44, 0x47, 0x55, 0x47, 0x50, 0x4d, 0x4a, 0x4c,
+ 0x50, 0x48, 0x47, 0x55, 0x4b, 0x4a, 0x52, 0x49, 0x3d, 0x3f, 0x4f, 0x51,
+ 0x48, 0x4e, 0x42, 0x4e, 0x42, 0x48, 0x4e, 0x49, 0x4a, 0x50, 0x45, 0x54,
+ 0x41, 0x43, 0x45, 0x4d, 0x48, 0x48, 0x48, 0x51, 0x53, 0x3e, 0x55, 0x44,
+ 0x52, 0x56, 0x44, 0x4d, 0x4e, 0x48, 0x4b, 0x43, 0x48, 0x53, 0x48, 0x44,
+ 0x49, 0x45, 0x4e, 0x50, 0x5d, 0x4a, 0x45, 0x4c, 0x45, 0x55, 0x43, 0x2e,
+ 0x59, 0x60, 0x4e, 0x4d, 0x32, 0x53, 0x3e, 0x3f, 0x40, 0x63, 0x41, 0x48,
+ 0x38, 0x73, 0x38, 0x46, 0x50, 0x3e, 0x3c, 0x23, 0x48, 0x61, 0x45, 0x3c,
+ 0x41, 0x41, 0x36, 0x3b, 0x58, 0x56, 0x4a, 0x40, 0x4f, 0x44, 0x45, 0x4c,
+ 0x5a, 0x56, 0x47, 0x3f, 0x4d, 0x4b, 0x46, 0x5d, 0x52, 0x47, 0x45, 0x4c,
+ 0x4a, 0x52, 0x4f, 0x4f, 0x4f, 0x43, 0x4f, 0x47, 0x43, 0x46, 0x3c, 0x4c,
+ 0x46, 0x55, 0x40, 0x53, 0x43, 0x3e, 0x42, 0x35, 0x51, 0x41, 0x42, 0x3f,
+ 0x45, 0x3d, 0x41, 0x31, 0x4e, 0x47, 0x48, 0x42, 0x41, 0x45, 0x43, 0x38,
+ 0x42, 0x40, 0x4a, 0x47, 0x4e, 0x43, 0x40, 0x43, 0x48, 0x49, 0x45, 0x4f,
+ 0x44, 0x42, 0x4d, 0x42, 0x42, 0x3f, 0x46, 0x52, 0x3c, 0x3c, 0x47, 0x43,
+ 0x46, 0x47, 0x45, 0x40, 0x4c, 0x44, 0x43, 0x4a, 0x4b, 0x4d, 0x4e, 0x46,
+ 0x51, 0x45, 0x47, 0x4b, 0x45, 0x50, 0x40, 0x42, 0x4c, 0x4c, 0x4c, 0x4f,
+ 0x44, 0x3c, 0x49, 0x3c, 0x3f, 0x45, 0x3f, 0x5c, 0x42, 0x3e, 0x4b, 0x4e,
+ 0x50, 0x45, 0x42, 0x5c, 0x4c, 0x48, 0x50, 0x52, 0x50, 0x47, 0x4b, 0x44,
+ 0x3d, 0x50, 0x55, 0x4c, 0x48, 0x3f, 0x4b, 0x44, 0x4a, 0x51, 0x42, 0x4c,
+ 0x60, 0x51, 0x41, 0x4b, 0x46, 0x5c, 0x42, 0x2c, 0x55, 0x61, 0x50, 0x52,
+ 0x37, 0x5a, 0x3f, 0x43, 0x43, 0x58, 0x3a, 0x4d, 0x3e, 0x72, 0x35, 0x3f,
+ 0x58, 0x41, 0x40, 0x1f, 0x55, 0x63, 0x3f, 0x49, 0x41, 0x3e, 0x35, 0x41,
+ 0x65, 0x54, 0x42, 0x45, 0x45, 0x3c, 0x44, 0x45, 0x59, 0x5a, 0x4d, 0x41,
+ 0x51, 0x46, 0x49, 0x59, 0x4c, 0x41, 0x42, 0x44, 0x4a, 0x45, 0x3f, 0x4a,
+ 0x4a, 0x44, 0x48, 0x48, 0x52, 0x40, 0x4a, 0x4a, 0x4d, 0x54, 0x44, 0x48,
+ 0x54, 0x46, 0x49, 0x3b, 0x42, 0x4a, 0x4e, 0x46, 0x4a, 0x45, 0x4f, 0x30,
+ 0x46, 0x41, 0x47, 0x46, 0x4b, 0x47, 0x46, 0x38, 0x4c, 0x3a, 0x4b, 0x46,
+ 0x52, 0x48, 0x4f, 0x3e, 0x48, 0x4a, 0x48, 0x4b, 0x44, 0x45, 0x4a, 0x46,
+ 0x3f, 0x4f, 0x40, 0x44, 0x43, 0x43, 0x4b, 0x39, 0x46, 0x43, 0x49, 0x49,
+ 0x49, 0x4a, 0x44, 0x48, 0x4c, 0x41, 0x4d, 0x52, 0x4c, 0x4a, 0x46, 0x3d,
+ 0x41, 0x4b, 0x41, 0x48, 0x45, 0x3b, 0x51, 0x54, 0x4a, 0x39, 0x4d, 0x41,
+ 0x54, 0x46, 0x4c, 0x53, 0x48, 0x3e, 0x4a, 0x3d, 0x41, 0x52, 0x54, 0x63,
+ 0x44, 0x4d, 0x4a, 0x43, 0x52, 0x4b, 0x52, 0x52, 0x4e, 0x41, 0x48, 0x42,
+ 0x48, 0x4d, 0x49, 0x45, 0x51, 0x48, 0x3e, 0x47, 0x5a, 0x52, 0x4a, 0x4e,
+ 0x3e, 0x59, 0x3c, 0x2e, 0x5c, 0x5b, 0x4c, 0x56, 0x30, 0x59, 0x3a, 0x48,
+ 0x3d, 0x5c, 0x44, 0x49, 0x40, 0x7c, 0x3a, 0x48, 0x54, 0x40, 0x41, 0x28,
+ 0x4d, 0x64, 0x46, 0x47, 0x49, 0x40, 0x30, 0x3a, 0x5f, 0x5b, 0x42, 0x37,
+ 0x49, 0x45, 0x40, 0x43, 0x5b, 0x54, 0x48, 0x4d, 0x4a, 0x47, 0x51, 0x58,
+ 0x4b, 0x3c, 0x4d, 0x46, 0x4b, 0x52, 0x4c, 0x58, 0x53, 0x46, 0x42, 0x45,
+ 0x4c, 0x4a, 0x4d, 0x4e, 0x52, 0x4d, 0x46, 0x44, 0x46, 0x3f, 0x46, 0x34,
+ 0x4f, 0x42, 0x44, 0x46, 0x44, 0x50, 0x47, 0x30, 0x44, 0x3c, 0x42, 0x46,
+ 0x4f, 0x4a, 0x52, 0x30, 0x55, 0x4f, 0x45, 0x4a, 0x48, 0x4c, 0x4e, 0x35,
+ 0x4e, 0x3c, 0x45, 0x4a, 0x45, 0x4a, 0x44, 0x3c, 0x4e, 0x4a, 0x51, 0x44,
+ 0x49, 0x40, 0x4a, 0x40, 0x41, 0x44, 0x4f, 0x4c, 0x43, 0x45, 0x4b, 0x43,
+ 0x3e, 0x3e, 0x4c, 0x44, 0x48, 0x48, 0x42, 0x42, 0x4d, 0x43, 0x50, 0x4d,
+ 0x49, 0x3c, 0x45, 0x4f, 0x4c, 0x46, 0x4b, 0x48, 0x4d, 0x4d, 0x49, 0x55,
+ 0x49, 0x3b, 0x40, 0x44, 0x4a, 0x4b, 0x4e, 0x5e, 0x43, 0x47, 0x45, 0x43,
+ 0x4d, 0x4d, 0x49, 0x46, 0x4a, 0x44, 0x4e, 0x3e, 0x52, 0x41, 0x47, 0x47,
+ 0x4a, 0x50, 0x48, 0x43, 0x5d, 0x4f, 0x49, 0x48, 0x43, 0x4f, 0x45, 0x3e,
+ 0x5a, 0x69, 0x4d, 0x5a, 0x3a, 0x5d, 0x3a, 0x48, 0x42, 0x55, 0x3e, 0x48,
+ 0x48, 0x7b, 0x37, 0x40, 0x57, 0x45, 0x48, 0x24, 0x50, 0x61, 0x4c, 0x4a,
+ 0x44, 0x41, 0x34, 0x38, 0x65, 0x5b, 0x4f, 0x3c, 0x4d, 0x3a, 0x4a, 0x4c,
+ 0x66, 0x55, 0x50, 0x47, 0x4d, 0x46, 0x47, 0x58, 0x4c, 0x48, 0x48, 0x48,
+ 0x4e, 0x59, 0x4f, 0x4b, 0x45, 0x45, 0x4b, 0x54, 0x46, 0x51, 0x4f, 0x44,
+ 0x42, 0x55, 0x48, 0x44, 0x48, 0x41, 0x53, 0x2e, 0x4d, 0x45, 0x44, 0x54,
+ 0x4a, 0x44, 0x53, 0x34, 0x4c, 0x46, 0x47, 0x3f, 0x4c, 0x4b, 0x47, 0x36,
+ 0x47, 0x41, 0x43, 0x40, 0x51, 0x46, 0x45, 0x33, 0x46, 0x3e, 0x47, 0x50,
+ 0x3f, 0x48, 0x48, 0x37, 0x41, 0x41, 0x42, 0x3e, 0x45, 0x3d, 0x49, 0x3e,
+ 0x4f, 0x42, 0x49, 0x4a, 0x46, 0x46, 0x48, 0x44, 0x49, 0x45, 0x46, 0x4a,
+ 0x4a, 0x47, 0x48, 0x43, 0x44, 0x45, 0x3f, 0x4c, 0x4c, 0x49, 0x4d, 0x51,
+ 0x4a, 0x4a, 0x49, 0x4c, 0x42, 0x4d, 0x4b, 0x4b, 0x4a, 0x42, 0x47, 0x4d,
+ 0x3e, 0x4b, 0x47, 0x5c, 0x49, 0x3d, 0x4e, 0x41, 0x44, 0x49, 0x3e, 0x3e,
+ 0x4b, 0x47, 0x4e, 0x45, 0x44, 0x4a, 0x4d, 0x4a, 0x4f, 0x46, 0x45, 0x52,
+ 0x60, 0x53, 0x49, 0x50, 0x3d, 0x4f, 0x43, 0x3d, 0x52, 0x64, 0x52, 0x58,
+ 0x39, 0x5f, 0x36, 0x4c, 0x45, 0x57, 0x42, 0x4b, 0x3f, 0x80, 0x34, 0x47,
+ 0x58, 0x41, 0x45, 0x1b, 0x4b, 0x5e, 0x4c, 0x40, 0x44, 0x42, 0x39, 0x3a,
+ 0x5e, 0x5b, 0x4b, 0x3a, 0x4b, 0x3f, 0x45, 0x3e, 0x69, 0x57, 0x4b, 0x45,
+ 0x4b, 0x3f, 0x45, 0x55, 0x49, 0x49, 0x48, 0x47, 0x41, 0x4f, 0x42, 0x53,
+ 0x49, 0x40, 0x42, 0x3e, 0x49, 0x47, 0x53, 0x47, 0x45, 0x51, 0x4a, 0x44,
+ 0x44, 0x45, 0x4e, 0x2a, 0x45, 0x42, 0x4a, 0x4b, 0x46, 0x4d, 0x41, 0x30,
+ 0x3d, 0x43, 0x3f, 0x48, 0x49, 0x44, 0x4d, 0x2e, 0x48, 0x4a, 0x4c, 0x51,
+ 0x50, 0x46, 0x3e, 0x2c, 0x4d, 0x3f, 0x47, 0x46, 0x3c, 0x40, 0x4c, 0x38,
+ 0x4f, 0x46, 0x47, 0x53, 0x3b, 0x3c, 0x4e, 0x3e, 0x49, 0x40, 0x43, 0x4c,
+ 0x4d, 0x48, 0x45, 0x3c, 0x4d, 0x4c, 0x4d, 0x45, 0x3f, 0x49, 0x4a, 0x43,
+ 0x4d, 0x41, 0x4b, 0x50, 0x4e, 0x46, 0x50, 0x44, 0x49, 0x44, 0x4e, 0x42,
+ 0x4a, 0x43, 0x4c, 0x4c, 0x49, 0x49, 0x44, 0x4e, 0x4b, 0x3f, 0x4b, 0x5d,
+ 0x41, 0x49, 0x4b, 0x46, 0x4e, 0x48, 0x45, 0x51, 0x4d, 0x45, 0x46, 0x45,
+ 0x4b, 0x4e, 0x3c, 0x4d, 0x3d, 0x41, 0x47, 0x47, 0x64, 0x54, 0x41, 0x55,
+ 0x47, 0x56, 0x44, 0x3b, 0x53, 0x66, 0x4f, 0x5e, 0x40, 0x5d, 0x38, 0x4a,
+ 0x41, 0x59, 0x42, 0x48, 0x47, 0xff, 0x36, 0x49, 0x59, 0x41, 0x43, 0x1d,
+ 0x4d, 0x5e, 0x44, 0x44, 0x50, 0x3f, 0x39, 0x40, 0x68, 0x5e, 0x4a, 0x41,
+ 0x52, 0x41, 0x43, 0x41, 0x68, 0x51, 0x45, 0x48, 0x4c, 0x46, 0x4a, 0x5e,
+ 0x4e, 0x40, 0x4d, 0x41, 0x41, 0x5c, 0x3f, 0x4e, 0x4c, 0x37, 0x48, 0x40,
+ 0x46, 0x47, 0x4f, 0x43, 0x53, 0x52, 0x3d, 0x44, 0x47, 0x44, 0x3d, 0x34,
+ 0x44, 0x42, 0x4a, 0x43, 0x4d, 0x3f, 0x53, 0x2e, 0x42, 0x47, 0x43, 0x4d,
+ 0x45, 0x45, 0x47, 0x31, 0x4d, 0x39, 0x41, 0x4a, 0x4a, 0x4d, 0x4b, 0x35,
+ 0x47, 0x4e, 0x4c, 0x40, 0x4a, 0x44, 0x44, 0x36, 0x3e, 0x49, 0x3f, 0x45,
+ 0x46, 0x43, 0x4e, 0x3c, 0x4d, 0x47, 0x4c, 0x48, 0x4a, 0x4b, 0x48, 0x39,
+ 0x46, 0x50, 0x4a, 0x4f, 0x46, 0x41, 0x44, 0x4a, 0x41, 0x4f, 0x4c, 0x4e,
+ 0x55, 0x46, 0x43, 0x46, 0x4a, 0x48, 0x4e, 0x46, 0x42, 0x40, 0x4f, 0x56,
+ 0x4c, 0x45, 0x4b, 0x46, 0x4a, 0x47, 0x42, 0x5e, 0x49, 0x4e, 0x46, 0x43,
+ 0x4e, 0x42, 0x45, 0x48, 0x47, 0x48, 0x4f, 0x45, 0x47, 0x51, 0x4b, 0x4c,
+ 0x51, 0x39, 0x4d, 0x48, 0x60, 0x57, 0x49, 0x52, 0x3d, 0x57, 0x46, 0x3d,
+ 0x53, 0x68, 0x4b, 0x60, 0x40, 0x5a, 0x41, 0x4b, 0x46, 0x56, 0x46, 0x4c,
+ 0x49, 0x7e, 0x2f, 0x48, 0x51, 0x42, 0x40, 0x20, 0x4b, 0x62, 0x4d, 0x41,
+ 0x4f, 0x43, 0x3d, 0x35, 0x63, 0x63, 0x46, 0x3e, 0x4e, 0x47, 0x40, 0x40,
+ 0x60, 0x52, 0x4c, 0x46, 0x49, 0x48, 0x4f, 0x56, 0x51, 0x47, 0x52, 0x4e,
+ 0x4b, 0x59, 0x55, 0x4f, 0x48, 0x3d, 0x48, 0x4a, 0x4d, 0x50, 0x47, 0x47,
+ 0x51, 0x52, 0x4d, 0x51, 0x45, 0x45, 0x47, 0x2d, 0x4d, 0x41, 0x43, 0x49,
+ 0x4d, 0x40, 0x4a, 0x2f, 0x4f, 0x43, 0x46, 0x4a, 0x3e, 0x4a, 0x4a, 0x2b,
+ 0x49, 0x4c, 0x4c, 0x3e, 0x41, 0x4c, 0x4a, 0x2b, 0x40, 0x44, 0x46, 0x4a,
+ 0x40, 0x44, 0x42, 0x38, 0x52, 0x42, 0x46, 0x51, 0x53, 0x4e, 0x45, 0x31,
+ 0x45, 0x47, 0x4f, 0x46, 0x49, 0x43, 0x45, 0x3b, 0x4b, 0x4b, 0x4b, 0x4c,
+ 0x43, 0x4a, 0x4c, 0x43, 0x4e, 0x40, 0x52, 0x44, 0x48, 0x49, 0x47, 0x4b,
+ 0x4e, 0x3d, 0x4e, 0x44, 0x48, 0x4d, 0x4f, 0x4f, 0x50, 0x36, 0x47, 0x41,
+ 0x4a, 0x44, 0x45, 0x56, 0x4f, 0x4c, 0x50, 0x4b, 0x45, 0x3e, 0x45, 0x4e,
+ 0x45, 0x45, 0x43, 0x40, 0x47, 0x4e, 0x45, 0x3e, 0x4a, 0x3f, 0x49, 0x50,
+ 0x62, 0x55, 0x48, 0x56, 0x3e, 0x57, 0x4f, 0x3b, 0x55, 0x6c, 0x50, 0x5c,
+ 0x3d, 0x54, 0x3d, 0x46, 0x43, 0x59, 0x3e, 0x51, 0x4d, 0x7b, 0x33, 0x47,
+ 0x52, 0x43, 0x3f, 0x25, 0x4a, 0x6f, 0x49, 0x3e, 0x50, 0x40, 0x41, 0x30,
+ 0x5e, 0x5c, 0x4a, 0x43, 0x4d, 0x42, 0x46, 0x3b, 0x63, 0x53, 0x4f, 0x43,
+ 0x58, 0x48, 0x4b, 0x59, 0x50, 0x4e, 0x4b, 0x51, 0x4a, 0x55, 0x44, 0x46,
+ 0x4c, 0x3d, 0x4c, 0x52, 0x44, 0x52, 0x4c, 0x41, 0x4f, 0x44, 0x4a, 0x47,
+ 0x4e, 0x48, 0x49, 0x2e, 0x3e, 0x45, 0x4c, 0x48, 0x41, 0x47, 0x4d, 0x2e,
+ 0x40, 0x4b, 0x4c, 0x42, 0x4d, 0x40, 0x4e, 0x2e, 0x43, 0x45, 0x4b, 0x43,
+ 0x3e, 0x49, 0x55, 0x35, 0x43, 0x42, 0x42, 0x40, 0x4e, 0x46, 0x44, 0x37,
+ 0x49, 0x41, 0x3f, 0x52, 0x47, 0x4b, 0x43, 0x33, 0x4b, 0x47, 0x4b, 0x4c,
+ 0x4d, 0x4b, 0x3f, 0x42, 0x44, 0x40, 0x49, 0x41, 0x42, 0x49, 0x4b, 0x46,
+ 0x4e, 0x4e, 0x47, 0x4e, 0x48, 0x48, 0x4b, 0x46, 0x51, 0x4b, 0x46, 0x4d,
+ 0x47, 0x4f, 0x3e, 0x51, 0x46, 0x4e, 0x46, 0x4b, 0x47, 0x48, 0x4e, 0x55,
+ 0x4c, 0x3d, 0x47, 0x51, 0x42, 0x45, 0x4f, 0x42, 0x52, 0x50, 0x44, 0x4c,
+ 0x44, 0x44, 0x43, 0x4d, 0x40, 0x42, 0x4d, 0x4b, 0x5d, 0x4e, 0x47, 0x54,
+ 0x47, 0x51, 0x43, 0x39, 0x58, 0x66, 0x4e, 0x5a, 0x41, 0x52, 0x36, 0x47,
+ 0x45, 0x5f, 0x34, 0x50, 0x46, 0x79, 0x30, 0x48, 0x50, 0x45, 0x32, 0x22,
+ 0x54, 0x64, 0x49, 0x46, 0x45, 0x3c, 0x42, 0x36, 0x65, 0x5c, 0x48, 0x3a,
+ 0x4d, 0x4b, 0x47, 0x3e, 0x63, 0x56, 0x4a, 0x48, 0x51, 0x42, 0x4f, 0x5e,
+ 0x4c, 0x44, 0x4b, 0x4c, 0x3d, 0x5a, 0x43, 0x4d, 0x42, 0x40, 0x4f, 0x4d,
+ 0x3f, 0x3e, 0x46, 0x40, 0x49, 0x42, 0x49, 0x40, 0x49, 0x4c, 0x4a, 0x2e,
+ 0x4b, 0x3f, 0x53, 0x4b, 0x48, 0x49, 0x3e, 0x34, 0x47, 0x4a, 0x4b, 0x46,
+ 0x3b, 0x49, 0x46, 0x34, 0x4b, 0x48, 0x4c, 0x49, 0x49, 0x43, 0x4f, 0x2e,
+ 0x44, 0x46, 0x48, 0x50, 0x46, 0x4e, 0x4a, 0x37, 0x4b, 0x4c, 0x4a, 0x50,
+ 0x45, 0x4a, 0x48, 0x3b, 0x48, 0x44, 0x48, 0x4a, 0x41, 0x44, 0x52, 0x3f,
+ 0x4c, 0x46, 0x4a, 0x45, 0x46, 0x49, 0x49, 0x36, 0x53, 0x3e, 0x48, 0x47,
+ 0x3f, 0x42, 0x41, 0x4c, 0x42, 0x4a, 0x52, 0x46, 0x49, 0x3f, 0x48, 0x5a,
+ 0x43, 0x42, 0x3d, 0x43, 0x4f, 0x44, 0x43, 0x65, 0x41, 0x41, 0x44, 0x4b,
+ 0x50, 0x44, 0x53, 0x49, 0x41, 0x45, 0x4a, 0x4d, 0x40, 0x45, 0x4a, 0x4e,
+ 0x50, 0x40, 0x51, 0x40, 0x5e, 0x50, 0x43, 0x5c, 0x47, 0x5a, 0x44, 0x4c,
+ 0x54, 0x64, 0x4f, 0x63, 0x39, 0x58, 0x3c, 0x4a, 0x42, 0x5e, 0x3c, 0x4a,
+ 0x48, 0x7b, 0x34, 0x4c, 0x4f, 0x44, 0x30, 0x24, 0x50, 0x65, 0x47, 0x39,
+ 0x46, 0x3e, 0x3f, 0x33, 0x65, 0x5a, 0x44, 0x38, 0x50, 0x47, 0x4b, 0x3e,
+ 0x5b, 0x53, 0x4a, 0x4d, 0x51, 0x40, 0x47, 0x59, 0x51, 0x42, 0x4f, 0x50,
+ 0x45, 0x57, 0x46, 0x50, 0x3f, 0x3c, 0x4c, 0x4f, 0x46, 0x41, 0x4a, 0x3e,
+ 0x4d, 0x45, 0x51, 0x48, 0x4e, 0x44, 0x4e, 0x35, 0x44, 0x3f, 0x44, 0x48,
+ 0x3c, 0x4c, 0x49, 0x2c, 0x4a, 0x46, 0x48, 0x44, 0x4b, 0x42, 0x4b, 0x2f,
+ 0x4e, 0x50, 0x4c, 0x4d, 0x44, 0x46, 0x3f, 0x39, 0x4d, 0x47, 0x45, 0x41,
+ 0x42, 0x47, 0x4a, 0x3a, 0x40, 0x3e, 0x4a, 0x51, 0x3f, 0x47, 0x44, 0x37,
+ 0x47, 0x4e, 0x47, 0x52, 0x45, 0x42, 0x4a, 0x3d, 0x43, 0x4d, 0x4d, 0x47,
+ 0x48, 0x43, 0x44, 0x44, 0x47, 0x4e, 0x52, 0x4b, 0x4e, 0x50, 0x42, 0x47,
+ 0x4b, 0x4b, 0x4e, 0x4c, 0x4e, 0x47, 0x50, 0x56, 0x46, 0x47, 0x4d, 0x49,
+ 0x4d, 0x46, 0x49, 0x5f, 0x49, 0x42, 0x4d, 0x44, 0x40, 0x4b, 0x52, 0x45,
+ 0x46, 0x4a, 0x4b, 0x49, 0x47, 0x4b, 0x42, 0x45, 0x42, 0x44, 0x46, 0x4c,
+ 0x62, 0x4a, 0x44, 0x53, 0x43, 0x5a, 0x48, 0x49, 0x59, 0x68, 0x46, 0x61,
+ 0x40, 0x5a, 0x3a, 0x4d, 0x45, 0x5e, 0x33, 0x4f, 0x4e, 0x74, 0x3e, 0x3e,
+ 0x5a, 0x4b, 0x34, 0x31, 0x52, 0x6c, 0x44, 0x39, 0x4c, 0x3b, 0x39, 0x3a,
+ 0x63, 0x65, 0x4b, 0x40, 0x50, 0x4d, 0x53, 0x4a, 0x69, 0x56, 0x54, 0x45,
+ 0x4c, 0x4c, 0x50, 0x5b, 0x4d, 0x4f, 0x3d, 0x4b, 0x44, 0x47, 0x43, 0x47,
+ 0x49, 0x3c, 0x49, 0x41, 0x41, 0x3f, 0x47, 0x43, 0x48, 0x47, 0x4c, 0x43,
+ 0x4a, 0x40, 0x4d, 0x32, 0x4b, 0x4d, 0x44, 0x48, 0x46, 0x44, 0x50, 0x2f,
+ 0x4e, 0x49, 0x53, 0x4b, 0x52, 0x47, 0x4b, 0x2b, 0x48, 0x4b, 0x4a, 0x4c,
+ 0x4d, 0x4c, 0x43, 0x37, 0x48, 0x3c, 0x4b, 0x42, 0x51, 0x3f, 0x45, 0x3c,
+ 0x49, 0x40, 0x42, 0x43, 0x4d, 0x4c, 0x3f, 0x3f, 0x4d, 0x43, 0x45, 0x42,
+ 0x48, 0x42, 0x48, 0x39, 0x51, 0x4e, 0x46, 0x4f, 0x3e, 0x4c, 0x45, 0x3e,
+ 0x3f, 0x3f, 0x43, 0x41, 0x4b, 0x4b, 0x43, 0x4d, 0x44, 0x3b, 0x48, 0x45,
+ 0x3c, 0x4a, 0x48, 0x5b, 0x3c, 0x4b, 0x4c, 0x44, 0x46, 0x3e, 0x45, 0x57,
+ 0x43, 0x42, 0x51, 0x4a, 0x46, 0x47, 0x43, 0x49, 0x42, 0x43, 0x50, 0x4e,
+ 0x4e, 0x44, 0x41, 0x4e, 0x4e, 0x41, 0x48, 0x47, 0x5c, 0x53, 0x44, 0x54,
+ 0x44, 0x5b, 0x45, 0x46, 0x55, 0x67, 0x4d, 0x5d, 0x40, 0x5a, 0x43, 0x4b,
+ 0x43, 0x60, 0x3c, 0x4b, 0x41, 0x79, 0x41, 0x41, 0x58, 0x48, 0x40, 0x3b,
+ 0x4f, 0x6c, 0x46, 0x3f, 0x53, 0x3a, 0x3d, 0x36, 0x5a, 0x57, 0x44, 0x41,
+ 0x4c, 0x47, 0x4e, 0x48, 0x62, 0x60, 0x4a, 0x46, 0x51, 0x3e, 0x52, 0x5f,
+ 0x4b, 0x46, 0x48, 0x4c, 0x4c, 0x55, 0x43, 0x46, 0x49, 0x3e, 0x41, 0x40,
+ 0x4d, 0x47, 0x46, 0x3b, 0x51, 0x3a, 0x4a, 0x45, 0x50, 0x47, 0x51, 0x38,
+ 0x44, 0x41, 0x40, 0x4b, 0x4d, 0x44, 0x4d, 0x28, 0x47, 0x3e, 0x44, 0x40,
+ 0x49, 0x49, 0x40, 0x3c, 0x44, 0x4c, 0x48, 0x51, 0x46, 0x3e, 0x47, 0x2a,
+ 0x41, 0x44, 0x49, 0x4c, 0x4e, 0x4e, 0x42, 0x3c, 0x49, 0x42, 0x43, 0x45,
+ 0x4e, 0x4d, 0x50, 0x39, 0x42, 0x43, 0x48, 0x41, 0x3f, 0x40, 0x4e, 0x3a,
+ 0x44, 0x3d, 0x49, 0x4d, 0x47, 0x45, 0x4b, 0x42, 0x4c, 0x4d, 0x3f, 0x3f,
+ 0x4e, 0x4d, 0x4d, 0x4d, 0x4d, 0x45, 0x47, 0x43, 0x4c, 0x46, 0x47, 0x57,
+ 0x4b, 0x42, 0x4d, 0x46, 0x4b, 0x4b, 0x43, 0x58, 0x48, 0x49, 0x4d, 0x47,
+ 0x43, 0x49, 0x4b, 0x48, 0x46, 0x4f, 0x4f, 0x42, 0x4a, 0x43, 0x49, 0x4e,
+ 0x4a, 0x47, 0x4c, 0x48, 0x5a, 0x57, 0x4a, 0x58, 0x49, 0x4f, 0x45, 0x47,
+ 0x63, 0x66, 0x4d, 0x5e, 0x4b, 0x51, 0x45, 0x4a, 0x43, 0x5d, 0x33, 0x4b,
+ 0x4e, 0x70, 0x42, 0x39, 0x57, 0x4a, 0x40, 0x3a, 0x51, 0x68, 0x45, 0x45,
+ 0x4c, 0x44, 0x3a, 0x3a, 0x4f, 0x62, 0x49, 0x45, 0x53, 0x4c, 0x4e, 0x41,
+ 0x63, 0x5e, 0x44, 0x44, 0x47, 0x43, 0x47, 0x59, 0x4c, 0x4b, 0x4c, 0x49,
+ 0x3e, 0x43, 0x4c, 0x46, 0x4c, 0x38, 0x47, 0x46, 0x46, 0x47, 0x40, 0x44,
+ 0x51, 0x3e, 0x40, 0x47, 0x3f, 0x45, 0x48, 0x2a, 0x42, 0x3e, 0x43, 0x46,
+ 0x50, 0x4c, 0x4a, 0x2c, 0x49, 0x4b, 0x48, 0x48, 0x40, 0x4a, 0x4a, 0x37,
+ 0x4e, 0x42, 0x4f, 0x4c, 0x41, 0x43, 0x45, 0x38, 0x4e, 0x3d, 0x41, 0x47,
+ 0x42, 0x42, 0x43, 0x3b, 0x4a, 0x40, 0x48, 0x4a, 0x53, 0x44, 0x4d, 0x35,
+ 0x51, 0x3c, 0x4e, 0x4e, 0x3e, 0x3f, 0x4b, 0x3c, 0x3e, 0x47, 0x41, 0x48,
+ 0x40, 0x46, 0x4e, 0x44, 0x49, 0x42, 0x49, 0x44, 0x4b, 0x46, 0x46, 0x43,
+ 0x4c, 0x4b, 0x49, 0x4d, 0x3d, 0x47, 0x43, 0x5c, 0x4a, 0x42, 0x47, 0x4e,
+ 0x47, 0x40, 0x4c, 0x55, 0x3f, 0x45, 0x46, 0x49, 0x46, 0x48, 0x49, 0x4d,
+ 0x4c, 0x41, 0x49, 0x40, 0x4a, 0x44, 0x42, 0x49, 0x52, 0x41, 0x49, 0x4a,
+ 0x5c, 0x53, 0x47, 0x58, 0x49, 0x55, 0x4a, 0x4a, 0x62, 0x61, 0x4b, 0x57,
+ 0x3c, 0x50, 0x42, 0x4c, 0x49, 0x5f, 0x3f, 0x4a, 0x42, 0x70, 0x40, 0x40,
+ 0x4f, 0x46, 0x43, 0x43, 0x4d, 0x6c, 0x41, 0x3e, 0x4e, 0x49, 0x43, 0x38,
+ 0x50, 0x57, 0x43, 0x39, 0x4a, 0x4f, 0x51, 0x3e, 0x5c, 0x57, 0x46, 0x49,
+ 0x41, 0x40, 0x42, 0x4f, 0x4c, 0x45, 0x46, 0x4a, 0x4c, 0x4b, 0x43, 0x42,
+ 0x4c, 0x3c, 0x47, 0x47, 0x4f, 0x44, 0x45, 0x3a, 0x4d, 0x3d, 0x4d, 0x3f,
+ 0x46, 0x4f, 0x41, 0x37, 0x46, 0x45, 0x54, 0x47, 0x4e, 0x46, 0x47, 0x23,
+ 0x48, 0x4e, 0x4a, 0x47, 0x45, 0x45, 0x4e, 0x33, 0x49, 0x4a, 0x4d, 0x4e,
+ 0x49, 0x46, 0x49, 0x36, 0x48, 0x44, 0x53, 0x44, 0x4a, 0x45, 0x4a, 0x37,
+ 0x45, 0x36, 0x4b, 0x4e, 0x50, 0x3f, 0x49, 0x38, 0x40, 0x43, 0x46, 0x4c,
+ 0x43, 0x46, 0x4a, 0x3f, 0x45, 0x3d, 0x44, 0x47, 0x44, 0x42, 0x4a, 0x45,
+ 0x47, 0x43, 0x4d, 0x4d, 0x44, 0x44, 0x4f, 0x4a, 0x4a, 0x41, 0x50, 0x50,
+ 0x4b, 0x44, 0x54, 0x5c, 0x4b, 0x3a, 0x46, 0x4a, 0x4a, 0x43, 0x48, 0x5c,
+ 0x4b, 0x43, 0x47, 0x3d, 0x3e, 0x54, 0x42, 0x47, 0x42, 0x4f, 0x4b, 0x4b,
+ 0x46, 0x46, 0x46, 0x42, 0x42, 0x4b, 0x48, 0x45, 0x51, 0x4e, 0x49, 0x4d,
+ 0x43, 0x56, 0x45, 0x40, 0x5a, 0x58, 0x4c, 0x55, 0x40, 0x4b, 0x4c, 0x51,
+ 0x42, 0x59, 0x43, 0x46, 0x46, 0x69, 0x43, 0x3c, 0x54, 0x47, 0x3d, 0x41,
+ 0x52, 0x64, 0x44, 0x38, 0x4f, 0x49, 0x3a, 0x3a, 0x55, 0x54, 0x45, 0x3e,
+ 0x49, 0x44, 0x4e, 0x3f, 0x57, 0x50, 0x47, 0x43, 0x45, 0x48, 0x53, 0x5b,
+ 0x53, 0x4d, 0x48, 0x4e, 0x48, 0x3a, 0x3e, 0x46, 0x42, 0x36, 0x50, 0x4d,
+ 0x49, 0x4b, 0x4b, 0x45, 0x4c, 0x44, 0x50, 0x47, 0x3e, 0x49, 0x50, 0x37,
+ 0x4c, 0x4b, 0x4a, 0x54, 0x4e, 0x43, 0x40, 0x25, 0x46, 0x42, 0x52, 0x3d,
+ 0x44, 0x45, 0x51, 0x2e, 0x4a, 0x3d, 0x46, 0x46, 0x4c, 0x42, 0x48, 0x34,
+ 0x44, 0x44, 0x44, 0x4c, 0x4f, 0x4b, 0x42, 0x3d, 0x45, 0x40, 0x47, 0x49,
+ 0x43, 0x41, 0x3e, 0x39, 0x47, 0x4b, 0x50, 0x4a, 0x46, 0x47, 0x4e, 0x3b,
+ 0x4e, 0x3e, 0x49, 0x4a, 0x50, 0x40, 0x43, 0x49, 0x48, 0x3c, 0x4f, 0x45,
+ 0x4a, 0x41, 0x42, 0x48, 0x4b, 0x46, 0x4a, 0x50, 0x40, 0x49, 0x44, 0x54,
+ 0x45, 0x45, 0x4a, 0x4b, 0x51, 0x51, 0x48, 0x53, 0x50, 0x3f, 0x50, 0x46,
+ 0x44, 0x45, 0x51, 0x43, 0x4f, 0x3e, 0x41, 0x41, 0x46, 0x45, 0x45, 0x4c,
+ 0x54, 0x3c, 0x4a, 0x4c, 0x5a, 0x4f, 0x46, 0x4b, 0x47, 0x4a, 0x43, 0x4c,
+ 0x56, 0x5a, 0x4a, 0x53, 0x4c, 0x49, 0x46, 0x4c, 0x45, 0x59, 0x40, 0x4b,
+ 0x48, 0x60, 0x3d, 0x42, 0x52, 0x3f, 0x42, 0x3d, 0x52, 0x5f, 0x46, 0x42,
+ 0x4b, 0x4e, 0x4a, 0x3d, 0x52, 0x55, 0x53, 0x37, 0x47, 0x3e, 0x4a, 0x42,
+ 0x51, 0x54, 0x48, 0x48, 0x4b, 0x48, 0x3e, 0x52, 0x41, 0x4e, 0x4c, 0x4f,
+ 0x43, 0x3b, 0x4b, 0x4b, 0x4c, 0x40, 0x48, 0x49, 0x4d, 0x3a, 0x45, 0x3c,
+ 0x53, 0x44, 0x48, 0x4d, 0x4b, 0x49, 0x46, 0x3c, 0x4d, 0x40, 0x51, 0x3f,
+ 0x4c, 0x45, 0x44, 0x2f, 0x49, 0x51, 0x3f, 0x4d, 0x3e, 0x4e, 0x3c, 0x30,
+ 0x3d, 0x48, 0x4f, 0x3f, 0x45, 0x45, 0x46, 0x3b, 0x4c, 0x46, 0x4d, 0x50,
+ 0x4c, 0x3d, 0x41, 0x37, 0x3e, 0x3e, 0x4f, 0x4b, 0x4d, 0x4f, 0x45, 0x45,
+ 0x4a, 0x47, 0x4a, 0x44, 0x43, 0x46, 0x51, 0x41, 0x4e, 0x39, 0x44, 0x4a,
+ 0x4e, 0x49, 0x4a, 0x42, 0x49, 0x4b, 0x4e, 0x48, 0x49, 0x4a, 0x45, 0x4a,
+ 0x45, 0x41, 0x4a, 0x4b, 0x42, 0x41, 0x48, 0x4a, 0x44, 0x3a, 0x46, 0x49,
+ 0x54, 0x45, 0x44, 0x60, 0x4a, 0x4e, 0x45, 0x4a, 0x4a, 0x45, 0x4b, 0x49,
+ 0x42, 0x44, 0x46, 0x50, 0x4b, 0x4b, 0x4e, 0x45, 0x48, 0x3e, 0x55, 0x42,
+ 0x51, 0x49, 0x49, 0x44, 0x4e, 0x54, 0x53, 0x49, 0x4c, 0x63, 0x48, 0x5a,
+ 0x50, 0x4b, 0x45, 0x49, 0x43, 0x57, 0x4c, 0x3f, 0x4d, 0x67, 0x3f, 0x47,
+ 0x53, 0x49, 0x43, 0x44, 0x49, 0x61, 0x50, 0x47, 0x49, 0x49, 0x4a, 0x42,
+ 0x4a, 0x51, 0x46, 0x43, 0x3f, 0x34, 0x40, 0x3a, 0x45, 0x54, 0x4c, 0x55,
+ 0x40, 0x3c, 0x4a, 0x4d, 0x3e, 0x4d, 0x48, 0x51, 0x4c, 0x3e, 0x4c, 0x4f,
+ 0x50, 0x47, 0x4d, 0x49, 0x4d, 0x4e, 0x45, 0x43, 0x41, 0x41, 0x40, 0x47,
+ 0x43, 0x4a, 0x4a, 0x3c, 0x4c, 0x3d, 0x4e, 0x43, 0x41, 0x42, 0x4a, 0x30,
+ 0x45, 0x4c, 0x45, 0x55, 0x46, 0x39, 0x43, 0x39, 0x45, 0x47, 0x48, 0x53,
+ 0x4a, 0x48, 0x43, 0x38, 0x4f, 0x51, 0x4d, 0x4c, 0x41, 0x46, 0x40, 0x3d,
+ 0x43, 0x4b, 0x40, 0x46, 0x47, 0x50, 0x4a, 0x43, 0x50, 0x4e, 0x45, 0x4f,
+ 0x4d, 0x44, 0x4d, 0x3f, 0x4e, 0x48, 0x4a, 0x49, 0x44, 0x3d, 0x4a, 0x44,
+ 0x40, 0x45, 0x49, 0x40, 0x4a, 0x44, 0x4f, 0x4a, 0x43, 0x4a, 0x4e, 0x52,
+ 0x4d, 0x50, 0x48, 0x4c, 0x43, 0x45, 0x4d, 0x54, 0x4a, 0x49, 0x4c, 0x58,
+ 0x4c, 0x48, 0x4c, 0x44, 0x4b, 0x4e, 0x52, 0x44, 0x49, 0x44, 0x47, 0x4e,
+ 0x4b, 0x45, 0x49, 0x3e, 0x4c, 0x3b, 0x53, 0x3f, 0x51, 0x41, 0x3f, 0x44,
+ 0x43, 0x4a, 0x4b, 0x43, 0x53, 0x57, 0x50, 0x53, 0x4f, 0x4b, 0x48, 0x51,
+ 0x47, 0x49, 0x46, 0x4d, 0x4d, 0x5e, 0x44, 0x46, 0x56, 0x3d, 0x3c, 0x3e,
+ 0x47, 0x55, 0x54, 0x46, 0x42, 0x49, 0x4f, 0x43, 0x48, 0x54, 0x51, 0x40,
+ 0x44, 0x44, 0x47, 0x45, 0x4b, 0x59, 0x4d, 0x47, 0x40, 0x39, 0x48, 0x54,
+ 0x43, 0x45, 0x44, 0x42, 0x4c, 0x3c, 0x4d, 0x42, 0x4b, 0x45, 0x42, 0x48,
+ 0x51, 0x44, 0x45, 0x3f, 0x3d, 0x49, 0x4b, 0x4a, 0x41, 0x43, 0x4f, 0x3f,
+ 0x51, 0x4b, 0x44, 0x46, 0x46, 0x44, 0x53, 0x3d, 0x47, 0x47, 0x43, 0x4b,
+ 0x41, 0x43, 0x3c, 0x3b, 0x49, 0x47, 0x47, 0x49, 0x4b, 0x3d, 0x43, 0x43,
+ 0x4b, 0x47, 0x45, 0x4e, 0x42, 0x4a, 0x4c, 0x3e, 0x51, 0x3e, 0x46, 0x44,
+ 0x46, 0x43, 0x42, 0x42, 0x47, 0x4d, 0x51, 0x4b, 0x49, 0x44, 0x4d, 0x40,
+ 0x50, 0x43, 0x41, 0x4c, 0x42, 0x49, 0x49, 0x4c, 0x42, 0x50, 0x48, 0x3f,
+ 0x46, 0x42, 0x48, 0x57, 0x49, 0x4d, 0x47, 0x4e, 0x48, 0x4b, 0x46, 0x50,
+ 0x47, 0x45, 0x52, 0x45, 0x4b, 0x48, 0x40, 0x5b, 0x4e, 0x43, 0x51, 0x48,
+ 0x48, 0x4a, 0x4a, 0x4a, 0x52, 0x51, 0x4c, 0x4b, 0x42, 0x55, 0x4d, 0x46,
+ 0x50, 0x40, 0x4a, 0x50, 0x51, 0x3e, 0x42, 0x4c, 0x43, 0x46, 0x4d, 0x46,
+ 0x46, 0x4d, 0x4d, 0x52, 0x4e, 0x44, 0x45, 0x47, 0x49, 0x4c, 0x41, 0x44,
+ 0x4d, 0x54, 0x4c, 0x4a, 0x54, 0x3e, 0x44, 0x43, 0x53, 0x55, 0x4b, 0x4a,
+ 0x47, 0x47, 0x4f, 0x46, 0x4f, 0x4b, 0x51, 0x3f, 0x41, 0x4c, 0x43, 0x46,
+ 0x55, 0x51, 0x40, 0x4b, 0x4f, 0x40, 0x47, 0x50, 0x4e, 0x4a, 0x46, 0x4e,
+ 0x42, 0x4d, 0x48, 0x49, 0x48, 0x4a, 0x4a, 0x43, 0x49, 0x48, 0x44, 0x3b,
+ 0x51, 0x46, 0x3d, 0x43, 0x47, 0x4a, 0x4f, 0x42, 0x4a, 0x50, 0x4f, 0x41,
+ 0x45, 0x45, 0x43, 0x3c, 0x4c, 0x4c, 0x46, 0x4b, 0x3e, 0x44, 0x4b, 0x3a,
+ 0x45, 0x50, 0x42, 0x48, 0x46, 0x47, 0x44, 0x3a, 0x53, 0x46, 0x4e, 0x4f,
+ 0x43, 0x40, 0x46, 0x48, 0x4e, 0x45, 0x3f, 0x47, 0x48, 0x3f, 0x44, 0x4f,
+ 0x44, 0x47, 0x4e, 0x47, 0x47, 0x49, 0x42, 0x43, 0x3f, 0x49, 0x4a, 0x53,
+ 0x53, 0x4a, 0x4e, 0x4a, 0x49, 0x4d, 0x49, 0x41, 0x48, 0x4d, 0x4d, 0x4e,
+ 0x4b, 0x45, 0x4d, 0x4a, 0x46, 0x4a, 0x46, 0x51, 0x4b, 0x47, 0x49, 0x45,
+ 0x49, 0x49, 0x4b, 0x5c, 0x48, 0x42, 0x51, 0x4c, 0x41, 0x3f, 0x4c, 0x42,
+ 0x4f, 0x45, 0x4b, 0x4a, 0x52, 0x48, 0x53, 0x4f, 0x40, 0x47, 0x41, 0x47,
+ 0x68, 0xfb, 0xff, 0xff, 0x4c, 0xfc, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xe8, 0x03, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00,
+ 0x58, 0x01, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00,
+ 0x38, 0x02, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0xa0, 0x01, 0x00, 0x00,
+ 0x14, 0x03, 0x00, 0x00, 0xfe, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03,
+ 0x10, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+ 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x19, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x52, 0x65, 0x6c, 0x75, 0x00, 0x00, 0x00, 0x00,
+ 0xcc, 0xfc, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x17, 0xbf, 0xd2, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x58, 0xec, 0xd1, 0x43,
+ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6e, 0xfd, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
+ 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x08, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x43, 0x6f, 0x6e, 0x76,
+ 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x34, 0xff, 0xff, 0xff,
+ 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0xf5, 0xf7, 0x84, 0x3a, 0xc2, 0xfd, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
+ 0x1c, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68,
+ 0x61, 0x70, 0x65, 0x5f, 0x31, 0x00, 0x00, 0x00, 0x94, 0xfd, 0xff, 0xff,
+ 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x43,
+ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0xfe, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
+ 0x10, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x4d, 0x61, 0x74, 0x4d,
+ 0x75, 0x6c, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x0c, 0x00, 0x0c, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0xc5, 0x01, 0x2a, 0x3b, 0x96, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03,
+ 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+ 0x44, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x0a, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+ 0x25, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f,
+ 0x71, 0x75, 0x61, 0x6e, 0x74, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75,
+ 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61,
+ 0x78, 0x56, 0x61, 0x72, 0x73, 0x00, 0x00, 0x00, 0x84, 0xfe, 0xff, 0xff,
+ 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xf5, 0xf7, 0x84, 0x3a,
+ 0x01, 0x00, 0x00, 0x00, 0x6e, 0x88, 0xae, 0x3d, 0x01, 0x00, 0x00, 0x00,
+ 0xd4, 0x97, 0x30, 0xbe, 0x26, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03,
+ 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x1c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x61, 0x64, 0x64, 0x5f,
+ 0x31, 0x00, 0x00, 0x00, 0xec, 0xfe, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00,
+ 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x2f, 0xad, 0x18, 0x40, 0x01, 0x00, 0x00, 0x00,
+ 0x02, 0x38, 0xa2, 0x43, 0x01, 0x00, 0x00, 0x00, 0x02, 0xf1, 0x8d, 0xc3,
+ 0x8e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x5f, 0x73,
+ 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, 0x5c, 0xff, 0xff, 0xff,
+ 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3b,
+ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x3f, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00,
+ 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x30, 0x11, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00,
+ 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e,
+ 0x74, 0x5f, 0x31, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e,
+ 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56,
+ 0x61, 0x72, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73,
+ 0x65, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00,
+ 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00,
+ 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x31, 0x83, 0xce, 0x3a, 0x01, 0x00, 0x00, 0x00,
+ 0x4d, 0x97, 0x92, 0x3e, 0x01, 0x00, 0x00, 0x00, 0x84, 0x75, 0xec, 0xbd,
+ 0x03, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09,
+ 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00,
+ 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x14, 0x00, 0x1c, 0x00,
+ 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x18, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08,
+ 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+ 0x28, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x18, 0x00,
+ 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x14, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
+ 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00,
+ 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+ 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00,
+ 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+ 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, 0x00, 0x19, 0x06, 0x00,
+ 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x09, 0x06, 0x00,
+ 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04};
+const int g_tiny_conv_model_data_len = 19800;
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h
new file mode 100644
index 0000000000..2953cc852d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is a standard TensorFlow Lite model file that has been converted into a
+// C data array, so it can be easily compiled into a binary for devices that
+// don't have a file system. It was created using the command:
+// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc
+
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_
+
+extern const unsigned char g_tiny_conv_model_data[];
+extern const int g_tiny_conv_model_data_len;
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/BUILD b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD
new file mode 100644
index 0000000000..a012f950e6
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD
@@ -0,0 +1,107 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load(
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl",
+ "tflite_micro_cc_test",
+)
+
+cc_library(
+ name = "micro_ops",
+ srcs = [
+ "depthwise_conv.cc",
+ "fully_connected.cc",
+ "softmax.cc",
+ ],
+ hdrs = [
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "//tensorflow/contrib/lite/kernels:padding",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_library(
+ name = "all_ops_resolver",
+ srcs = [
+ "all_ops_resolver.cc",
+ ],
+ hdrs = [
+ "all_ops_resolver.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":micro_ops",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ ],
+)
+
+cc_library(
+ name = "test_utils",
+ srcs = [
+ ],
+ hdrs = [
+ "test_utils.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "depthwise_conv_test",
+ srcs = [
+ "depthwise_conv_test.cc",
+ ],
+ deps = [
+ ":all_ops_resolver",
+ ":test_utils",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "fully_connected_test",
+ srcs = [
+ "fully_connected_test.cc",
+ ],
+ deps = [
+ ":all_ops_resolver",
+ ":test_utils",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "softmax_test",
+ srcs = [
+ "softmax_test.cc",
+ ],
+ deps = [
+ ":all_ops_resolver",
+ ":test_utils",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc
new file mode 100644
index 0000000000..bd0a37badb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Micro_Register_DEPTHWISE_CONV_2D() {
+ return Register_DEPTHWISE_CONV_2D();
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED();
+TfLiteRegistration* Micro_Register_FULLY_CONNECTED() {
+ return Register_FULLY_CONNECTED();
+}
+
+TfLiteRegistration* Register_SOFTMAX();
+TfLiteRegistration* Micro_Register_SOFTMAX() { return Register_SOFTMAX(); }
+
+AllOpsResolver::AllOpsResolver() {
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
+ Micro_Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Micro_Register_FULLY_CONNECTED(),
+ /* min_version */ 1,
+ /* max_version */ 2);
+ AddBuiltin(BuiltinOperator_SOFTMAX, Micro_Register_SOFTMAX());
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h
new file mode 100644
index 0000000000..f836064a3f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+
+class AllOpsResolver : public MicroMutableOpResolver {
+ public:
+ AllOpsResolver();
+
+ private:
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc
new file mode 100644
index 0000000000..4f17263181
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace depthwise_conv {
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+struct OpData {
+ TfLitePaddingValues padding;
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, int width,
+ int height, int filter_width, int filter_height,
+ int out_width, int out_height,
+ const TfLiteType data_type, OpData* data) {
+ data->padding.height = ComputePadding(params->stride_height, 1, height,
+ filter_height, out_height);
+ data->padding.width =
+ ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ const TfLiteTensor* bias =
+ GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ tflite::reference_ops::DepthwiseConv(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output));
+}
+
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ const int32_t input_offset = -input->params.zero_point;
+ const int32_t filter_offset = -filter->params.zero_point;
+ const int32_t output_offset = output->params.zero_point;
+
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = -data->output_shift;
+
+ tflite::reference_ops::DepthwiseConv(
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ const TfLiteTensor* bias =
+ (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
+
+ const TfLiteType data_type = input->type;
+ int width = SizeOfDimension(input, 2);
+ int height = SizeOfDimension(input, 1);
+ int filter_width = SizeOfDimension(filter, 2);
+ int filter_height = SizeOfDimension(filter, 1);
+ int out_width = ComputeOutSize(params->padding, width, filter_width,
+ params->stride_width);
+ int out_height = ComputeOutSize(params->padding, height, filter_height,
+ params->stride_height);
+ OpData local_data_object;
+ OpData* data = &local_data_object;
+ TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
+ filter_width, filter_height, out_width,
+ out_height, data_type, data));
+
+ // TODO(aselle): Consider whether float conv and quantized conv should be
+ // separate ops to avoid dispatch overhead here.
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ EvalFloat(context, node, params, data, input, filter, bias, output);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantized(context, node, params, data, input, filter, bias, output);
+ break;
+ default:
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace depthwise_conv
+
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D() {
+ static TfLiteRegistration r = {depthwise_conv::Init, depthwise_conv::Free,
+ depthwise_conv::Prepare, depthwise_conv::Eval};
+ return &r;
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc
new file mode 100644
index 0000000000..169899c471
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc
@@ -0,0 +1,406 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestDepthwiseConvFloat(std::initializer_list<int> input_dims_data,
+ std::initializer_list<float> input_data,
+ std::initializer_list<int> filter_dims_data,
+ std::initializer_list<float> filter_data,
+ std::initializer_list<int> bias_dims_data,
+ std::initializer_list<float> bias_data,
+ std::initializer_list<float> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ TfLiteFusedActivation activation,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateFloatTensor(input_data, input_dims, "input_tensor"),
+ CreateFloatTensor(filter_data, filter_dims, "filter_tensor"),
+ CreateFloatTensor(bias_data, bias_dims, "bias_tensor"),
+ CreateFloatTensor(output_data, output_dims, "output_tensor"),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ int input_depth = input_dims->data[3];
+ int output_depth = filter_dims->data[3];
+ int depth_mul = output_depth / input_depth;
+ TfLiteDepthwiseConvParams builtin_data = {
+ kTfLitePaddingValid, 1, 1, depth_mul, activation,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+ 1e-5f);
+ }
+}
+
+void TestDepthwiseConvQuantized(
+ std::initializer_list<int> input_dims_data,
+ std::initializer_list<uint8_t> input_data, float input_min, float input_max,
+ std::initializer_list<int> filter_dims_data,
+ std::initializer_list<uint8_t> filter_data, float filter_min,
+ float filter_max, std::initializer_list<int> bias_dims_data,
+ std::initializer_list<int32_t> bias_data, float bias_min, float bias_max,
+ std::initializer_list<uint8_t> expected_output_data,
+ std::initializer_list<int> output_dims_data, float output_min,
+ float output_max, TfLiteFusedActivation activation, uint8_t* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+ input_max),
+ CreateQuantizedTensor(filter_data, filter_dims, "filter_tensor",
+ filter_min, filter_max),
+ CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min,
+ bias_max),
+ CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+ output_min, output_max),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ int input_depth = input_dims->data[3];
+ int output_depth = filter_dims->data[3];
+ int depth_mul = output_depth / input_depth;
+ TfLiteDepthwiseConvParams builtin_data = {
+ kTfLitePaddingValid, 1, 1, depth_mul, activation,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+ const int output_dims_count = 8;
+ float output_data[output_dims_count];
+ tflite::testing::TestDepthwiseConvFloat( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ 1, 2, 7, 8, // Input values.
+ 3, 4, 9, 10, //
+ 5, 6, 11, 12, //
+ },
+ {4, 1, 2, 2, 4}, // Filters shape.
+ {
+ 1, 2, 3, 4, // Filters values.
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ },
+ {1, 4}, // Bias shape.
+ {
+ 1, 2, 3, 4, // Bias values.
+ },
+ {
+ 71, -34, 99, -20, // Expected results.
+ 91, -26, 127, -4, //
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float filter_min = -63.5f;
+ const float filter_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 8;
+ uint8_t output_data[output_dims_count];
+
+ tflite::testing::TestDepthwiseConvQuantized( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max),
+ F2Q(2, input_min, input_max),
+ F2Q(7, input_min, input_max),
+ F2Q(8, input_min, input_max),
+ F2Q(3, input_min, input_max),
+ F2Q(4, input_min, input_max),
+ F2Q(9, input_min, input_max),
+ F2Q(10, input_min, input_max),
+ F2Q(5, input_min, input_max),
+ F2Q(6, input_min, input_max),
+ F2Q(11, input_min, input_max),
+ F2Q(12, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {4, 1, 2, 2, 4}, // Filter shape.
+ {
+ // Filter values.
+ F2Q(1, filter_min, filter_max),
+ F2Q(2, filter_min, filter_max),
+ F2Q(3, filter_min, filter_max),
+ F2Q(4, filter_min, filter_max),
+ F2Q(-9, filter_min, filter_max),
+ F2Q(10, filter_min, filter_max),
+ F2Q(-11, filter_min, filter_max),
+ F2Q(12, filter_min, filter_max),
+ F2Q(5, filter_min, filter_max),
+ F2Q(6, filter_min, filter_max),
+ F2Q(7, filter_min, filter_max),
+ F2Q(8, filter_min, filter_max),
+ F2Q(13, filter_min, filter_max),
+ F2Q(-14, filter_min, filter_max),
+ F2Q(15, filter_min, filter_max),
+ F2Q(-16, filter_min, filter_max),
+ },
+ filter_min, filter_max, // Filter quantization range.
+ {1, 4}, // Bias shape.
+ {
+ // Bias values.
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ F2Q32(4, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(71, output_min, output_max),
+ F2Q(-34, output_min, output_max),
+ F2Q(99, output_min, output_max),
+ F2Q(-20, output_min, output_max),
+ F2Q(91, output_min, output_max),
+ F2Q(-26, output_min, output_max),
+ F2Q(127, output_min, output_max),
+ F2Q(-4, output_min, output_max),
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestRelu) {
+ const int output_dims_count = 8;
+ float output_data[output_dims_count];
+ tflite::testing::TestDepthwiseConvFloat( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ 1, 2, 7, 8, // Input values.
+ 3, 4, 9, 10, //
+ 5, 6, 11, 12, //
+ },
+ {4, 1, 2, 2, 4}, // Filters shape.
+ {
+ 1, 2, 3, 4, // Filters values.
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ },
+ {1, 4}, // Bias shape.
+ {
+ 1, 2, 3, 4, // Bias values.
+ },
+ {
+ 71, 0, 99, 0, // Expected results.
+ 91, 0, 127, 0, //
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestReluQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float filter_min = -63.5f;
+ const float filter_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 8;
+ uint8_t output_data[output_dims_count];
+
+ tflite::testing::TestDepthwiseConvQuantized( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max),
+ F2Q(2, input_min, input_max),
+ F2Q(7, input_min, input_max),
+ F2Q(8, input_min, input_max),
+ F2Q(3, input_min, input_max),
+ F2Q(4, input_min, input_max),
+ F2Q(9, input_min, input_max),
+ F2Q(10, input_min, input_max),
+ F2Q(5, input_min, input_max),
+ F2Q(6, input_min, input_max),
+ F2Q(11, input_min, input_max),
+ F2Q(12, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {4, 1, 2, 2, 4}, // Filter shape.
+ {
+ // Filter values.
+ F2Q(1, filter_min, filter_max),
+ F2Q(2, filter_min, filter_max),
+ F2Q(3, filter_min, filter_max),
+ F2Q(4, filter_min, filter_max),
+ F2Q(-9, filter_min, filter_max),
+ F2Q(10, filter_min, filter_max),
+ F2Q(-11, filter_min, filter_max),
+ F2Q(12, filter_min, filter_max),
+ F2Q(5, filter_min, filter_max),
+ F2Q(6, filter_min, filter_max),
+ F2Q(7, filter_min, filter_max),
+ F2Q(8, filter_min, filter_max),
+ F2Q(13, filter_min, filter_max),
+ F2Q(-14, filter_min, filter_max),
+ F2Q(15, filter_min, filter_max),
+ F2Q(-16, filter_min, filter_max),
+ },
+ filter_min, filter_max, // Filter quantization range.
+ {1, 4}, // Bias shape.
+ {
+ // Bias values.
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ F2Q32(4, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(71, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(99, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(91, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(127, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc
new file mode 100644
index 0000000000..1e9e54cafb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace fully_connected {
+namespace {
+
+struct OpData {
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+ // The index of the temporary tensor where the quantized inputs are cached.
+ int input_quantized_index;
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus CalculateOpData(TfLiteContext* context,
+ TfLiteFullyConnectedParams* params,
+ TfLiteType data_type, const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output,
+ OpData* data) {
+ TfLiteStatus status = kTfLiteOk;
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
+ TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
+ context, params->activation, output, &data->output_activation_min,
+ &data->output_activation_max));
+ }
+ return status;
+}
+
+} // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter, const TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ const int32_t input_offset = -input->params.zero_point;
+ const int32_t filter_offset = -filter->params.zero_point;
+ const int32_t output_offset = output->params.zero_point;
+
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+
+#define TF_LITE_FULLY_CONNECTED(output_data_type) \
+ reference_ops::FullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<output_data_type>(output), \
+ nullptr)
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ tflite::reference_ops::FullyConnected(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output));
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TfLiteType data_type = input->type;
+ OpData local_data_object;
+ OpData* data = &local_data_object;
+ TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
+ filter, bias, output, data));
+
+ switch (filter->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, params, data, input, filter, bias,
+ output);
+ case kTfLiteUInt8:
+ return EvalQuantized(context, node, params, data, input, filter, bias,
+ output);
+
+ default:
+ context->ReportError(context, "Type %d not currently supported.",
+ filter->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace fully_connected
+
+TfLiteRegistration* Register_FULLY_CONNECTED() {
+ static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
+ fully_connected::Prepare,
+ fully_connected::Eval};
+ return &r;
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc
new file mode 100644
index 0000000000..b42bf4c3bc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc
@@ -0,0 +1,643 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestFullyConnectedFloat(std::initializer_list<int> input_dims_data,
+ std::initializer_list<float> input_data,
+ std::initializer_list<int> weights_dims_data,
+ std::initializer_list<float> weights_data,
+ std::initializer_list<int> bias_dims_data,
+ std::initializer_list<float> bias_data,
+ std::initializer_list<float> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ TfLiteFusedActivation activation,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateFloatTensor(input_data, input_dims, "input_tensor"),
+ CreateFloatTensor(weights_data, weights_dims, "weights_tensor"),
+ CreateFloatTensor(bias_data, bias_dims, "bias_tensor"),
+ CreateFloatTensor(output_data, output_dims, "output_tensor"),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteFullyConnectedParams builtin_data = {
+ activation,
+ kTfLiteFullyConnectedWeightsFormatDefault,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+ 1e-5f);
+ }
+}
+
+void TestFullyConnectedQuantized(
+ std::initializer_list<int> input_dims_data,
+ std::initializer_list<uint8_t> input_data, float input_min, float input_max,
+ std::initializer_list<int> weights_dims_data,
+ std::initializer_list<uint8_t> weights_data, float weights_min,
+ float weights_max, std::initializer_list<int> bias_dims_data,
+ std::initializer_list<int32_t> bias_data, float bias_min, float bias_max,
+ std::initializer_list<uint8_t> expected_output_data,
+ std::initializer_list<int> output_dims_data, float output_min,
+ float output_max, TfLiteFusedActivation activation, uint8_t* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+ input_max),
+ CreateQuantizedTensor(weights_data, weights_dims, "weights_tensor",
+ weights_min, weights_max),
+ CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min,
+ bias_max),
+ CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+ output_min, output_max),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteFullyConnectedParams builtin_data = {
+ activation,
+ kTfLiteFullyConnectedWeightsFormatDefault,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {2, 2, 10}, // Input shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ },
+ {2, 3, 10}, // Weights shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ },
+ {1, 3}, // Bias shape.
+ {
+ 1, 2, 3, // Bias values.
+ },
+ {
+ 24, 25, 26, 58, 59, 60, // Expected results.
+ },
+ {2, 2, 3}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest2) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {2, 2, 2}, // Input shape.
+ {
+ 1, 2, // b = 0
+ 2, 1, // b = 1
+ },
+ {2, 1, 2}, // Weights shape.
+ {
+ 2, 4, // u = 0
+ },
+ {1, 1}, // Bias shape.
+ {
+ 1, // Bias values.
+ },
+ {
+ 11, 9, // Expected results.
+ },
+ {2, 2, 1}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestRelu) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {2, 2, 10}, // Input shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ },
+ {2, 3, 10}, // Weights shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ },
+ {1, 3}, // Bias shape.
+ {
+ 1, -2, 3, // Bias values.
+ },
+ {
+ 24, 0, 26, 58, 0, 60, // Expected results.
+ },
+ {2, 2, 3}, // Output shape.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float weights_min = -63.5f;
+ const float weights_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {2, 2, 10}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedRelu) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float weights_min = -63.5f;
+ const float weights_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {2, 2, 10}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(-1, weights_min, weights_max), F2Q(-2, weights_min, weights_max),
+ F2Q(-3, weights_min, weights_max), F2Q(-4, weights_min, weights_max),
+ F2Q(-5, weights_min, weights_max), F2Q(-6, weights_min, weights_max),
+ F2Q(-7, weights_min, weights_max), F2Q(-8, weights_min, weights_max),
+ F2Q(-9, weights_min, weights_max), F2Q(-10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(0, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedOutputMultiplierGreaterThan1) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -127.0f;
+ const float input_max = 128.0f;
+ const float weights_min = -127.0f;
+ const float weights_max = 128.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 256.0f * (1 << 24);
+ const float output_min = -63.5f;
+ const float output_max = 64.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {2, 2, 10}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest4DInput) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {4, 1, 1, 5, 1}, // Input shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ },
+ {2, 3, 10}, // Weights shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ },
+ {1, 3}, // Bias shape.
+ {
+ 1, 2, 3, // Bias values.
+ },
+ {
+ 24, 25, 26, 58, 59, 60, // Expected results.
+ },
+ {2, 2, 3}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest4DInputQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float weights_min = -63.5f;
+ const float weights_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {4, 1, 1, 5, 1}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedOutputMultiplierGreaterThan1) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -127.0f;
+ const float input_max = 128.0f;
+ const float weights_min = -127.0f;
+ const float weights_max = 128.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 256.0f * (1 << 24);
+ const float output_min = -63.5f;
+ const float output_max = 64.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {4, 1, 1, 5, 1}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc
new file mode 100644
index 0000000000..a4019a067c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc
@@ -0,0 +1,213 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace activations {
+namespace {
+
+struct OpData {
+ int32_t input_multiplier = 0;
+ int input_left_shift = 0;
+ int32_t input_range_radius = 0;
+ int diff_min = 0;
+};
+
+TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+ const TfLiteTensor* input,
+ TfLiteTensor* output,
+ const TfLiteSoftmaxParams* params,
+ OpData* data) {
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+ static const int kScaledDiffIntegerBits = 5;
+
+ tflite::PreprocessSoftmaxScaling(
+ params->beta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift);
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+// Takes a 1D tensor and performs softmax along it.
+void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int input_size = input->dims->data[0];
+ tflite::reference_ops::Softmax(input->data.f, input_size, 1, params->beta,
+ output->data.f);
+}
+
+// Takes a 2D tensor and perform softmax along the last dimension.
+void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ tflite::reference_ops::Softmax(input->data.f, input_size, batch_size,
+ params->beta, output->data.f);
+}
+
+void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 1D
+ // tensor is 4D in a special way. We will convert a (Y) shape into a (1,
+ // 1, 1, Y) shape.
+ const int input_size = input->dims->data[0];
+ const int32_t shape_data[4] = {1, 1, 1, input_size};
+ RuntimeShape shape(4, shape_data);
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ tflite::reference_ops::Softmax(op_params, shape,
+ GetTensorData<uint8_t>(input), shape,
+ GetTensorData<uint8_t>(output));
+}
+
+void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 2D
+ // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
+ // 1, 1, Y) shape.
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int32_t shape_data[4] = {batch_size, 1, 1, input_size};
+ RuntimeShape shape(4, shape_data);
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ tflite::reference_ops::Softmax(op_params, shape,
+ GetTensorData<uint8_t>(input), shape,
+ GetTensorData<uint8_t>(output));
+}
+
+// Takes a 4D tensor and perform softmax along the forth dimension.
+void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ tflite::reference_ops::Softmax(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+}
+
+void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ tflite::reference_ops::Softmax(
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+
+ OpData local_data_object;
+ OpData* data = &local_data_object;
+ TF_LITE_ENSURE_STATUS(
+ CalculateSoftmaxOpData(context, input, output, params, data));
+
+ // TODO(ahentz): consider an implementation that works for many (all?)
+ // dimensions.
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 2) {
+ Softmax2DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ context->ReportError(
+ context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
+ NumDimensions(input));
+ return kTfLiteError;
+ }
+ case kTfLiteUInt8: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 2) {
+ Softmax2DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ context->ReportError(
+ context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
+ NumDimensions(input));
+ return kTfLiteError;
+ }
+ default:
+ context->ReportError(
+ context, "Only float32 and uint8_t supported currently, got %d.",
+ input->type);
+ return kTfLiteError;
+ }
+}
+} // namespace activations
+
+TfLiteRegistration* Register_SOFTMAX() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::SoftmaxPrepare,
+ activations::SoftmaxEval};
+ return &r;
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc
new file mode 100644
index 0000000000..694456d8ac
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestSoftmaxFloat(std::initializer_list<int> input_dims_data,
+ std::initializer_list<float> input_data,
+ std::initializer_list<float> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 1;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateFloatTensor(input_data, input_dims, "input_tensor"),
+ CreateFloatTensor(output_data, output_dims, "output_tensor"),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteSoftmaxParams builtin_data = {1.0f};
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+ int inputs_array_data[] = {1, 0};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 1};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+ 1e-5f);
+ }
+}
+
+void TestSoftmaxQuantized(std::initializer_list<int> input_dims_data,
+ std::initializer_list<uint8_t> input_data,
+ float input_min, float input_max,
+ std::initializer_list<uint8_t> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ float output_min, float output_max,
+ uint8_t* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 1;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+ input_max),
+ CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+ output_min, output_max),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteSoftmaxParams builtin_data = {1.0f};
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+
+ int inputs_array_data[] = {1, 0};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 1};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+ const int output_dims_count = 10;
+ float output_data[output_dims_count];
+ tflite::testing::TestSoftmaxFloat( //
+ {2, 2, 5}, // Input shape.
+ {
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1
+ },
+ {
+ // Expected results.
+ 0.011656231,
+ 0.031684921,
+ 0.086128544,
+ 0.234121657,
+ 0.636408647,
+ 0.636408647,
+ 0.234121657,
+ 0.086128544,
+ 0.031684921,
+ 0.011656231,
+ },
+ {2, 2, 5}, // Output shape.
+ output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantized) {
+ using tflite::testing::F2Q;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float output_min = 0.0f;
+ const float output_max = (255.0f / 256.0f);
+ const int output_dims_count = 5;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestSoftmaxQuantized( //
+ {2, 1, 5}, // Input shape.
+ {
+ F2Q(1.0, input_min, input_max),
+ F2Q(2.0, input_min, input_max),
+ F2Q(3.0, input_min, input_max),
+ F2Q(4.0, input_min, input_max),
+ F2Q(5.0, input_min, input_max),
+ },
+ input_min, input_max, // Input quantized range.
+ {
+ // Expected results.
+ F2Q(0.011656231, output_min, output_max),
+ F2Q(0.031684921, output_min, output_max),
+ F2Q(0.086128544, output_min, output_max),
+ F2Q(0.234121657, output_min, output_max),
+ F2Q(0.636408647, output_min, output_max),
+ },
+ {2, 1, 5}, // Output shape.
+ output_min, output_max, // Output quantized range.
+ output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h
new file mode 100644
index 0000000000..789a48ece8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h
@@ -0,0 +1,170 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
+
+#include <cstdarg>
+#include <initializer_list>
+#include <limits>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+
+// Returns the total number of elements implied by the given shape: the
+// product of all dimensions (1 for a rank-0 shape).
+inline int ElementCount(const TfLiteIntArray& dims) {
+ int count = 1;
+ for (int d = 0; d < dims.size; ++d) count *= dims.data[d];
+ return count;
+}
+
+// Wrapper to forward kernel errors to the interpreter's error reporter.
+// Installed as TfLiteContext::ReportError; requires the context's impl_
+// field to point at an ErrorReporter (PopulateContext sets this up).
+inline void ReportOpError(struct TfLiteContext* context, const char* format,
+ ...) {
+ ErrorReporter* error_reporter = static_cast<ErrorReporter*>(context->impl_);
+ va_list args;
+ va_start(args, format);
+ // Dispatches to the ErrorReporter::Report(format, va_list) overload.
+ error_reporter->Report(format, args);
+ va_end(args);
+}
+
+// Returns the quantization scale (real-value units per quantized step) for
+// a [min, max] range mapped onto the full integer range of T.
+template <typename T>
+inline float ScaleFromMinMax(const float min, const float max) {
+ // The * 1.0 promotes to double so the subtraction cannot overflow for
+ // 32-bit T; the division therefore also happens in double precision.
+ const double quantized_range = (std::numeric_limits<T>::max() * 1.0) -
+ std::numeric_limits<T>::min();
+ return (max - min) / quantized_range;
+}
+
+// Derives the quantization zero point from a min and max range.
+// The +0.5f rounds to nearest; assumes min <= 0 so -min is non-negative and
+// the zero point lands inside T's range -- TODO confirm no caller passes an
+// all-positive range.
+template <typename T>
+inline int ZeroPointFromMinMax(const float min, const float max) {
+ return static_cast<int>((-min / ScaleFromMinMax<T>(min, max)) + 0.5f);
+}
+
+// Converts a float value into an unsigned eight-bit quantized value using
+// the given [min, max] real range, rounding to nearest and saturating at
+// the ends of the uint8_t range.
+inline uint8_t F2Q(const float value, const float min, const float max) {
+ int32_t result = ZeroPointFromMinMax<uint8_t>(min, max) +
+ (value / ScaleFromMinMax<uint8_t>(min, max)) + 0.5f;
+ if (result < 0) {
+ result = 0;
+ }
+ // Fixed: saturate at 255. The previous bound of 256 could be returned,
+ // and 256 wraps to 0 in the uint8_t return value, mapping out-of-range
+ // high inputs to the quantized minimum.
+ if (result > 255) {
+ result = 255;
+ }
+ return result;
+}
+
+// Converts a float value into a signed thirty-two-bit quantized value.
+// Fixed: the return type is now int32_t; the previous uint8_t return
+// silently truncated the 32-bit quantized value to its low eight bits,
+// contradicting both the comment and the cast below.
+// NOTE(review): the formula subtracts the zero point before dividing by the
+// scale, which is the opposite convention from F2Q above -- confirm this is
+// what the int32 (bias) quantization path expects.
+inline int32_t F2Q32(const float value, const float min, const float max) {
+ return static_cast<int32_t>((value - ZeroPointFromMinMax<int32_t>(min, max)) /
+ ScaleFromMinMax<int32_t>(min, max));
+}
+
+// Fills in a TfLiteContext for kernel unit tests: installs the tensor list,
+// routes ReportError to the shared micro_test reporter via impl_, and nulls
+// out every interpreter callback that micro kernels must not rely on.
+inline void PopulateContext(TfLiteTensor* tensors, int tensors_size,
+ TfLiteContext* context) {
+ context->tensors_size = tensors_size;
+ context->tensors = tensors;
+ // ReportOpError (above) recovers the reporter from impl_.
+ context->impl_ = static_cast<void*>(micro_test::reporter);
+ context->GetExecutionPlan = nullptr;
+ context->ResizeTensor = nullptr;
+ context->ReportError = ReportOpError;
+ context->AddTensors = nullptr;
+ context->GetNodeAndRegistration = nullptr;
+ context->ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context->recommended_num_threads = 1;
+ context->GetExternalContext = nullptr;
+ context->SetExternalContext = nullptr;
+}
+
+// Reinterprets a raw int buffer as a TfLiteIntArray. Relies on
+// TfLiteIntArray's layout being {int size; int data[];} with no padding, so
+// the buffer's first element is read as the size -- TODO confirm this holds
+// on every target ABI.
+inline TfLiteIntArray* IntArrayFromInts(const int* int_array) {
+ return const_cast<TfLiteIntArray*>(
+ reinterpret_cast<const TfLiteIntArray*>(int_array));
+}
+
+// Convenience wrapper for brace-initialized shapes, e.g. {2, 2, 5} means
+// rank 2 with dims 2x5. The result aliases the initializer_list's backing
+// storage, so it is only valid while that list object is alive.
+inline TfLiteIntArray* IntArrayFromInitializer(
+ std::initializer_list<int> int_initializer) {
+ return IntArrayFromInts(int_initializer.begin());
+}
+
+// Wraps a caller-owned float buffer in an unquantized TfLiteTensor. The
+// aggregate initializer below must list fields in TfLiteTensor declaration
+// order (type, data, dims, params, allocation_type, bytes, allocation,
+// name) -- presumably matching c_api_internal.h; revisit if that struct
+// changes.
+inline TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims,
+ const char* name) {
+ const size_t bytes = ElementCount(*dims) * sizeof(float);
+ return {
+ kTfLiteFloat32, {const_cast<int*>(reinterpret_cast<const int*>(data))},
+ dims, {},
+ kTfLiteMemNone, bytes,
+ nullptr, name};
+}
+
+// Overload taking an initializer_list; the tensor aliases the list's
+// storage, so it must not outlive the enclosing test expression/scope.
+inline TfLiteTensor CreateFloatTensor(std::initializer_list<float> data,
+ TfLiteIntArray* dims, const char* name) {
+ return CreateFloatTensor(data.begin(), dims, name);
+}
+
+// Wraps a caller-owned uint8 buffer in a quantized TfLiteTensor whose
+// scale/zero-point are derived from the [min, max] real range.
+inline TfLiteTensor CreateQuantizedTensor(const uint8_t* data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ const size_t bytes = ElementCount(*dims) * sizeof(uint8_t);
+ const TfLiteQuantizationParams q_params = {
+ ScaleFromMinMax<uint8_t>(min, max),
+ ZeroPointFromMinMax<uint8_t>(min, max)};
+ return {
+ kTfLiteUInt8, {const_cast<int*>(reinterpret_cast<const int*>(data))},
+ dims, q_params,
+ kTfLiteMemNone, bytes,
+ nullptr, name};
+}
+
+// Overload taking an initializer_list; the tensor aliases the list's
+// storage, so it must not outlive the enclosing test expression/scope.
+inline TfLiteTensor CreateQuantizedTensor(std::initializer_list<uint8_t> data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ return CreateQuantizedTensor(data.begin(), dims, name, min, max);
+}
+
+// Wraps a caller-owned int32 buffer (e.g. quantized bias values) in a
+// TfLiteTensor whose scale/zero-point are derived from [min, max].
+inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ const size_t bytes = ElementCount(*dims) * sizeof(int32_t);
+ const TfLiteQuantizationParams q_params = {
+ ScaleFromMinMax<int32_t>(min, max),
+ ZeroPointFromMinMax<int32_t>(min, max)};
+ // Fixed: the type tag must be kTfLiteInt32 to match the int32_t buffer;
+ // the previous kTfLiteUInt8 mislabeled the tensor and disagreed with the
+ // sizeof(int32_t)-based `bytes` above.
+ return {
+ kTfLiteInt32, {const_cast<int*>(reinterpret_cast<const int*>(data))},
+ dims, q_params,
+ kTfLiteMemNone, bytes,
+ nullptr, name};
+}
+
+// Overload taking an initializer_list; the tensor aliases the list's
+// storage, so it must not outlive the enclosing test expression/scope.
+inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
+}
+
+} // namespace testing
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc
new file mode 100644
index 0000000000..99dd883661
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+
+#ifdef TF_LITE_MCU_DEBUG_LOG
+#include <debug_log.h>
+#else // TF_LITE_MCU_DEBUG_LOG
+#include <cstdint>
+#include <cstdio>
+// Reference DebugLog* hooks for hosted (non-MCU) builds: write to stderr.
+void DebugLog(const char* s) { fprintf(stderr, "%s", s); }
+void DebugLogInt32(int32_t i) { fprintf(stderr, "%d", i); }
+// Fixed: %u for unsigned values; %d misprints values above INT32_MAX.
+void DebugLogUInt32(uint32_t i) { fprintf(stderr, "%u", i); }
+// Fixed: %08x zero-pads to eight hex digits; the previous "%8x" space-padded
+// (e.g. "0x      ff"), which was presumably not intended for hex dumps.
+void DebugLogHex(uint32_t i) { fprintf(stderr, "0x%08x", i); }
+void DebugLogFloat(float i) { fprintf(stderr, "%f", i); }
+#endif // TF_LITE_MCU_DEBUG_LOG
+
+namespace tflite {
+namespace {
+// Minimal printf replacement for targets with only a DebugLog(string) hook:
+// supports just %d and %s, buffering literal characters in a small stack
+// cache that is flushed whenever a specifier is reached or the cache fills.
+// Any other %-sequence (including %%) is silently dropped -- the '%' itself
+// is never emitted. A trailing newline is always appended.
+void DebugLogPrintf(const char* format, va_list args) {
+ const int output_cache_size = 64;
+ // +1 leaves room for the NUL terminator when the cache is exactly full.
+ char output_cache[output_cache_size + 1];
+ int output_cache_index = 0;
+ const char* current = format;
+ while (*current != 0) {
+ if (*current == '%') {
+ const char next = *(current + 1);
+ if ((next == 'd') || (next == 's')) {
+ // Step past the specifier letter ('%' is consumed by the increment
+ // at the bottom of the loop).
+ current += 1;
+ // Flush buffered literal text before emitting the argument.
+ if (output_cache_index > 0) {
+ output_cache[output_cache_index] = 0;
+ DebugLog(output_cache);
+ output_cache_index = 0;
+ }
+ if (next == 'd') {
+ DebugLogInt32(va_arg(args, int));
+ } else if (next == 's') {
+ DebugLog(va_arg(args, char*));
+ }
+ }
+ } else {
+ output_cache[output_cache_index] = *current;
+ output_cache_index += 1;
+ }
+ // Flush when full so the literal write above can never overrun.
+ if (output_cache_index >= output_cache_size) {
+ output_cache[output_cache_index] = 0;
+ DebugLog(output_cache);
+ output_cache_index = 0;
+ }
+ current += 1;
+ }
+ // Flush any trailing literals, then terminate the message with a newline.
+ if (output_cache_index > 0) {
+ output_cache[output_cache_index] = 0;
+ DebugLog(output_cache);
+ output_cache_index = 0;
+ }
+ DebugLog("\n");
+}
+} // namespace
+
+// ErrorReporter entry point: formats through the restricted DebugLogPrintf
+// above (only %d and %s are honored). Always returns 0.
+int MicroErrorReporter::Report(const char* format, va_list args) {
+ DebugLogPrintf(format, args);
+ return 0;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h
new file mode 100644
index 0000000000..33e54f7990
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+namespace tflite {
+
+// ErrorReporter implementation for microcontrollers that logs through the
+// platform's DebugLog() hooks instead of stdio.
+class MicroErrorReporter : public ErrorReporter {
+ public:
+ // Non-virtual on purpose: TF_LITE_REMOVE_VIRTUAL_DELETE below presumably
+ // hides operator delete so instances never touch a heap (see
+ // compatibility.h) -- do not delete through a base pointer.
+ ~MicroErrorReporter() {}
+ int Report(const char* format, va_list args) override;
+
+ private:
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc
index 86250e6692..ef3c32050c 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,9 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
-#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
+// Smoke test: exercises MicroErrorReporter with a well-formed format, then
+// deliberately malformed ones (a trailing '%' and stray '%'s must not
+// crash or read bogus varargs). The final line prints "ALL TESTS PASSED",
+// presumably the sentinel the micro test harness scans for -- confirm
+// against testing/micro_test.h.
+int main(int argc, char** argv) {
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = &micro_error_reporter;
+ error_reporter->Report("Number: %d", 42);
+ error_reporter->Report("Badly-formed format string %");
+ error_reporter->Report("Another % badly-formed %% format string");
+ error_reporter->Report("~~~%s~~~", "ALL TESTS PASSED");
+}
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc
new file mode 100644
index 0000000000..0f38991bb0
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc
@@ -0,0 +1,310 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+namespace tflite {
+namespace {
+// Fixed-size arena handed to ParseOpData for builtin option structs;
+// returns nullptr when a request exceeds the arena instead of falling back
+// to heap allocation.
+const int kStackDataAllocatorSize = 128;
+class StackDataAllocator : public BuiltinDataAllocator {
+ public:
+ void* Allocate(size_t size) override {
+ if (size > kStackDataAllocatorSize) {
+ return nullptr;
+ } else {
+ // Every call hands back the same buffer, so at most one allocation
+ // may be live at a time (true for the per-op usage in Invoke()).
+ return data_;
+ }
+ }
+ void Deallocate(void* data) override {
+ // Do nothing.
+ }
+
+ private:
+ uint8_t data_[kStackDataAllocatorSize];
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+// Returns a human-readable name for an op: the custom name for custom ops,
+// otherwise the schema enum name for the builtin code.
+const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
+ const bool is_custom = (registration->builtin_code == BuiltinOperator_CUSTOM);
+ return is_custom ? registration->custom_name
+ : EnumNameBuiltinOperator(
+ BuiltinOperator(registration->builtin_code));
+}
+
+// TfLiteContext::ReportError trampoline: context->impl_ holds the owning
+// MicroInterpreter (set at the end of the constructor), whose error
+// reporter does the actual formatting.
+void ReportOpError(struct TfLiteContext* context, const char* format, ...) {
+ MicroInterpreter* interpreter =
+ static_cast<MicroInterpreter*>(context->impl_);
+ va_list args;
+ va_start(args, format);
+ interpreter->error_reporter()->Report(format, args);
+ va_end(args);
+}
+
+} // namespace
+
+// Builds interpreter state for |model|: verifies the model has exactly one
+// subgraph, then allocates every tensor through |tensor_allocator| with a
+// lifetime window (producing op index, last consuming op index) derived
+// from a liveness walk over the ops. Any failure is recorded in
+// initialization_status_ and aborts construction early.
+MicroInterpreter::MicroInterpreter(const Model* model,
+ const OpResolver& op_resolver,
+ SimpleTensorAllocator* tensor_allocator,
+ ErrorReporter* error_reporter)
+ : model_(model),
+ op_resolver_(op_resolver),
+ tensor_allocator_(tensor_allocator),
+ error_reporter_(error_reporter),
+ initialization_status_(kTfLiteOk) {
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
+ model->buffers();
+ auto* subgraphs = model->subgraphs();
+ if (subgraphs->size() != 1) {
+ error_reporter->Report("Only 1 subgraph is currently supported.\n");
+ initialization_status_ = kTfLiteError;
+ return;
+ }
+ subgraph_ = (*subgraphs)[0];
+ tensors_ = subgraph_->tensors();
+ operators_ = subgraph_->operators();
+
+ context_.tensors_size = tensors_->Length();
+ context_.tensors =
+ reinterpret_cast<TfLiteTensor*>(tensor_allocator_->AllocateMemory(
+ sizeof(TfLiteTensor) * context_.tensors_size));
+ // Model inputs stay live for the whole invocation: ops [0, Length()].
+ for (int i = 0; i < subgraph_->inputs()->Length(); ++i) {
+ const int tensor_index = subgraph_->inputs()->Get(i);
+ const auto* tensor = tensors_->Get(tensor_index);
+ initialization_status_ = tensor_allocator_->AllocateTensor(
+ *tensor, 0, operators_->Length(), buffers, error_reporter,
+ &context_.tensors[tensor_index]);
+ if (initialization_status_ != kTfLiteOk) {
+ return;
+ }
+ }
+
+ // Scratch liveness tables: for each tensor, the op that produced it and
+ // the last op that reads it (-1 means never).
+ int* first_created = reinterpret_cast<int*>(
+ tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length()));
+ int* last_used = reinterpret_cast<int*>(
+ tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length()));
+ for (int i = 0; i < tensors_->Length(); ++i) {
+ first_created[i] = -1;
+ last_used[i] = -1;
+ }
+
+ // Walk the ops in reverse so last_used is already known when each
+ // producing op is visited; allocate every non-variable output with a
+ // lifetime spanning producer to last consumer (end of model if unread).
+ for (int i = (operators_->Length() - 1); i >= 0; --i) {
+ const auto* op = operators_->Get(i);
+ for (int n = 0; n < op->inputs()->Length(); ++n) {
+ const int tensor_index = op->inputs()->Get(n);
+ if ((last_used[tensor_index] == -1) || (last_used[tensor_index] < i)) {
+ last_used[tensor_index] = i;
+ }
+ }
+ for (int n = 0; n < op->outputs()->Length(); ++n) {
+ const int tensor_index = op->outputs()->Get(n);
+ const int create_before = i;
+ int destroy_after = last_used[tensor_index];
+ if (destroy_after == -1) {
+ destroy_after = operators_->Length();
+ }
+ const auto* tensor = tensors_->Get(tensor_index);
+ if (!tensor->is_variable()) {
+ initialization_status_ = tensor_allocator_->AllocateTensor(
+ *tensor, create_before, destroy_after, buffers, error_reporter,
+ &context_.tensors[tensor_index]);
+ if (initialization_status_ != kTfLiteOk) {
+ return;
+ }
+ first_created[tensor_index] = i;
+ }
+ }
+ }
+
+ // Variable tensors and read-only tensors (consumed but never produced by
+ // any op, i.e. weights/constants) get whole-invocation lifetimes.
+ for (int i = 0; i < tensors_->Length(); ++i) {
+ const auto* tensor = tensors_->Get(i);
+ const bool is_read_only = (first_created[i] == -1) && (last_used[i] != -1);
+ if (tensor->is_variable() || is_read_only) {
+ initialization_status_ = tensor_allocator_->AllocateTensor(
+ *tensor, 0, operators_->Length(), buffers, error_reporter,
+ &context_.tensors[i]);
+ if (initialization_status_ != kTfLiteOk) {
+ return;
+ }
+ }
+ }
+ // Wire the context so kernels can report errors; every other interpreter
+ // callback is unsupported on micro and left null.
+ context_.impl_ = static_cast<void*>(this);
+ context_.GetExecutionPlan = nullptr;
+ context_.ResizeTensor = nullptr;
+ context_.ReportError = ReportOpError;
+ context_.AddTensors = nullptr;
+ context_.GetNodeAndRegistration = nullptr;
+ context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context_.recommended_num_threads = 1;
+ context_.GetExternalContext = nullptr;
+ context_.SetExternalContext = nullptr;
+}
+
+// Executes every op in the subgraph in order. For each op: resolves its
+// registration from the opcode table, parses builtin or custom options into
+// stack storage, builds a TfLiteNode with stack-allocated index arrays, and
+// runs the kernel's init/prepare/invoke/free sequence. Returns the first
+// error encountered, or kTfLiteOk.
+TfLiteStatus MicroInterpreter::Invoke() {
+ if (initialization_status_ != kTfLiteOk) {
+ error_reporter_->Report("Invoke() called after initialization failed\n");
+ return kTfLiteError;
+ }
+ TfLiteStatus status = kTfLiteOk;
+ auto opcodes = model_->operator_codes();
+ for (int i = 0; i < operators_->Length(); ++i) {
+ const auto* op = operators_->Get(i);
+ int index = op->opcode_index();
+ if (index < 0 || index >= opcodes->size()) {
+ error_reporter_->Report("Missing registration for opcode_index %d\n",
+ index);
+ return kTfLiteError;
+ }
+ auto opcode = (*opcodes)[index];
+ const TfLiteRegistration* registration = nullptr;
+ status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+ &registration);
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ if (registration == nullptr) {
+ error_reporter_->Report("Skipping op for opcode_index %d\n", index);
+ return kTfLiteError;
+ }
+ BuiltinOperator op_type =
+ static_cast<BuiltinOperator>(registration->builtin_code);
+
+ // Custom options attached to a builtin op are ignored; warn but proceed.
+ if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
+ error_reporter_->Report(
+ "Found builtin operator %s with custom options.\n",
+ EnumNameBuiltinOperator(op_type));
+ }
+ // Builtin option structs are parsed into this small stack arena instead
+ // of the heap; it only has to live until free() at the bottom.
+ StackDataAllocator stack_data_allocator;
+ const char* custom_data = nullptr;
+ size_t custom_data_size = 0;
+ unsigned char* builtin_data = nullptr;
+ if (op->custom_options()) {
+ custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
+ custom_data_size = op->custom_options()->size();
+ } else {
+ TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
+ &stack_data_allocator,
+ (void**)(&builtin_data)));
+ }
+
+ // init() receives the raw custom buffer for custom ops, or the parsed
+ // builtin struct (with a size of 0) for builtins.
+ const char* init_data;
+ size_t init_data_size;
+ if (registration->builtin_code == BuiltinOperator_CUSTOM) {
+ init_data = custom_data;
+ init_data_size = custom_data_size;
+ } else {
+ init_data = reinterpret_cast<const char*>(builtin_data);
+ init_data_size = 0;
+ }
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context_, init_data, init_data_size);
+ }
+
+ // An int buffer of kMaxInputs + 1 is reinterpreted as a TfLiteIntArray:
+ // first element is the size, the rest hold up to kMaxInputs indices.
+ // NOTE(review): the >= test rejects exactly kMaxInputs inputs even
+ // though the buffer could hold them -- confirm whether > was intended.
+ const int kMaxInputs = 16;
+ int inputs_data[kMaxInputs + 1];
+ TfLiteIntArray* inputs_array =
+ reinterpret_cast<TfLiteIntArray*>(inputs_data);
+ if (op->inputs()->Length() >= kMaxInputs) {
+ error_reporter_->Report("Too many inputs (%d)\n", op->inputs()->Length());
+ return kTfLiteError;
+ }
+ inputs_array->size = op->inputs()->Length();
+ for (int n = 0; n < op->inputs()->Length(); ++n) {
+ inputs_array->data[n] = op->inputs()->Get(n);
+ }
+
+ const int kMaxOutputs = 16;
+ int outputs_data[kMaxOutputs + 1];
+ TfLiteIntArray* outputs_array =
+ reinterpret_cast<TfLiteIntArray*>(outputs_data);
+ if (op->outputs()->Length() >= kMaxOutputs) {
+ error_reporter_->Report("Too many outputs (%d)\n",
+ op->outputs()->Length());
+ return kTfLiteError;
+ }
+ outputs_array->size = op->outputs()->Length();
+ for (int n = 0; n < op->outputs()->Length(); ++n) {
+ outputs_array->data[n] = op->outputs()->Get(n);
+ }
+
+ // Temporaries are not supported here; kernels see an empty array.
+ const int kMaxTemporaries = 16;
+ int temporaries_data[kMaxTemporaries + 1];
+ TfLiteIntArray* temporaries_array =
+ reinterpret_cast<TfLiteIntArray*>(temporaries_data);
+ temporaries_array->size = 0;
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(builtin_data);
+ node.custom_initial_data = custom_data;
+ node.custom_initial_data_size = custom_data_size;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TfLiteStatus prepare_status = registration->prepare(&context_, &node);
+ if (prepare_status != kTfLiteOk) {
+ error_reporter_->Report(
+ "Node %s (number %d) failed to prepare with status %d",
+ OpNameFromRegistration(registration), i, prepare_status);
+ return kTfLiteError;
+ }
+ }
+
+ if (registration->invoke) {
+ TfLiteStatus invoke_status = registration->invoke(&context_, &node);
+ if (invoke_status != kTfLiteOk) {
+ error_reporter_->Report(
+ "Node %s (number %d) failed to invoke with status %d",
+ OpNameFromRegistration(registration), i, invoke_status);
+ return kTfLiteError;
+ }
+ }
+
+ if (registration->free) {
+ registration->free(&context_, user_data);
+ }
+ }
+ return status;
+}
+
+// Returns the tensor backing the model input at |index|, or nullptr (after
+// reporting an error) when the index is out of range.
+TfLiteTensor* MicroInterpreter::input(int index) {
+ const flatbuffers::Vector<int32_t>* inputs = subgraph_->inputs();
+ const size_t length = inputs->Length();
+ // Cast the index for the comparison (index >= 0 is already established)
+ // and the length for reporting: the micro error formatter reads varargs
+ // with va_arg(args, int), so passing a size_t through %d is undefined on
+ // LP64 targets.
+ if ((index < 0) || (static_cast<size_t>(index) >= length)) {
+ error_reporter_->Report("Input index %d out of range (length is %d)", index,
+ static_cast<int>(length));
+ return nullptr;
+ }
+ return &(context_.tensors[inputs->Get(index)]);
+}
+
+// Returns the tensor backing the model output at |index|, or nullptr (after
+// reporting an error) when the index is out of range.
+TfLiteTensor* MicroInterpreter::output(int index) {
+ const flatbuffers::Vector<int32_t>* outputs = subgraph_->outputs();
+ const size_t length = outputs->Length();
+ // Fixed: compare against the cached |length| (the original recomputed
+ // outputs->Length() in the condition and ignored the local), and cast the
+ // size_t length to int for the %d varargs read in the error formatter.
+ if ((index < 0) || (static_cast<size_t>(index) >= length)) {
+ error_reporter_->Report("Output index %d out of range (length is %d)",
+ index, static_cast<int>(length));
+ return nullptr;
+ }
+ return &(context_.tensors[outputs->Get(index)]);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h
new file mode 100644
index 0000000000..a88514cde8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// Interpreter for running TFLite models on microcontrollers: all tensor
+// memory comes from the caller-supplied SimpleTensorAllocator, and only a
+// single subgraph is supported (see the .cc constructor).
+class MicroInterpreter {
+ public:
+ // The lifetime of the model, op resolver, allocator, and error reporter must
+ // be at least as long as that of the interpreter object, since the
+ // interpreter may need to access them at any time. This means that you should
+ // usually create them with the same scope as each other, for example having
+ // them all allocated on the stack as local variables through a top-level
+ // function.
+ // The interpreter doesn't do any deallocation of any of the pointed-to
+ // objects, ownership remains with the caller.
+ MicroInterpreter(const Model* model, const OpResolver& op_resolver,
+ SimpleTensorAllocator* tensor_allocator,
+ ErrorReporter* error_reporter);
+
+ // Runs the whole model once; check initialization_status() first.
+ TfLiteStatus Invoke();
+
+ size_t tensors_size() const { return context_.tensors_size; }
+ TfLiteTensor* tensor(int tensor_index);
+
+ // Accessors for the subgraph's declared inputs/outputs; the pointer
+ // accessors return nullptr (after reporting) for an out-of-range index.
+ TfLiteTensor* input(int index);
+ size_t inputs_size() const { return subgraph_->inputs()->Length(); }
+
+ TfLiteTensor* output(int index);
+ size_t outputs_size() const { return subgraph_->outputs()->Length(); }
+
+ // kTfLiteOk when constructor-time tensor allocation succeeded.
+ TfLiteStatus initialization_status() const { return initialization_status_; }
+
+ ErrorReporter* error_reporter() { return error_reporter_; }
+
+ private:
+ // Unowned; see the constructor comment for lifetime requirements.
+ const Model* model_;
+ const OpResolver& op_resolver_;
+ SimpleTensorAllocator* tensor_allocator_;
+ ErrorReporter* error_reporter_;
+
+ TfLiteStatus initialization_status_;
+ // Cached views into the model's single subgraph.
+ const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors_;
+ const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators_;
+ TfLiteContext context_;
+
+ const SubGraph* subgraph_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc
new file mode 100644
index 0000000000..251e5f7203
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace {
+// Mock kernel init: no per-op state is needed, so return no user_data.
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+// Mock kernel free: MockInit allocates nothing, so there is nothing to do.
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+// Mock kernel prepare: always succeeds.
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+// Mock kernel invoke: output[0] = int32 input[0] + uint8 weight[0],
+// matching the tensor types declared in BuildMockModel (INT32, UINT8,
+// INT32).
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ const int32_t* input_data = input->data.i32;
+ const TfLiteTensor* weight = &context->tensors[node->inputs->data[1]];
+ const uint8_t* weight_data = weight->data.uint8;
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ int32_t* output_data = output->data.i32;
+ output_data[0] = input_data[0] + weight_data[0];
+ return kTfLiteOk;
+}
+
+// Resolver that knows exactly one custom op, "mock_custom"; every builtin
+// lookup fails, so only the custom-op path is exercised.
+class MockOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
+ return nullptr;
+ }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ if (strcmp(op, "mock_custom") == 0) {
+ // Function-local static: one shared registration, no heap use.
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+// flatbuffers::Allocator backed by a fixed static buffer, so the test can
+// build a model without any heap allocation. Allocation is bump-pointer
+// and deallocate() is a no-op, so handed-out buffers are never reclaimed --
+// which also means they stay valid after the FlatBufferBuilder that
+// requested them is destroyed (BuildMockModel relies on this).
+class StackAllocator : public flatbuffers::Allocator {
+ public:
+ StackAllocator() : data_(data_backing_), data_size_(0) {}
+
+ uint8_t* allocate(size_t size) override {
+ if ((data_size_ + size) > kStackAllocatorSize) {
+ // TODO(petewarden): Add error reporting beyond returning null!
+ return nullptr;
+ }
+ uint8_t* result = data_;
+ data_ += size;
+ data_size_ += size;
+ return result;
+ }
+
+ void deallocate(uint8_t* p, size_t) override {}
+
+ static StackAllocator& instance() {
+ // Avoid using true dynamic memory allocation to be portable to bare metal.
+ static char inst_memory[sizeof(StackAllocator)];
+ static StackAllocator* inst = new (inst_memory) StackAllocator;
+ return *inst;
+ }
+
+ static constexpr int kStackAllocatorSize = 4096;
+
+ private:
+ uint8_t data_backing_[kStackAllocatorSize];
+ uint8_t* data_;
+ int data_size_;
+};
+
+// Serializes a one-op model in memory: a single "mock_custom" op reading an
+// INT32 input (tensor 0) and a UINT8 weight (tensor 1, constant 21 from
+// buffer 1), producing an INT32 output (tensor 2). Because the builder
+// allocates through the static StackAllocator (whose memory is never
+// reclaimed), the returned Model pointer remains valid after the local
+// FlatBufferBuilder goes out of scope.
+const Model* BuildMockModel() {
+ using flatbuffers::Offset;
+ flatbuffers::FlatBufferBuilder builder(StackAllocator::kStackAllocatorSize,
+ &StackAllocator::instance());
+ constexpr size_t buffer_data_size = 1;
+ const uint8_t buffer_data[buffer_data_size] = {21};
+ constexpr size_t buffers_size = 2;
+ const Offset<Buffer> buffers[buffers_size] = {
+ CreateBuffer(builder),
+ CreateBuffer(builder,
+ builder.CreateVector(buffer_data, buffer_data_size))};
+ constexpr size_t tensor_shape_size = 1;
+ const int32_t tensor_shape[tensor_shape_size] = {1};
+ constexpr size_t tensors_size = 3;
+ const Offset<Tensor> tensors[tensors_size] = {
+ CreateTensor(builder,
+ builder.CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_INT32, 0,
+ builder.CreateString("test_input_tensor"), 0, false),
+ CreateTensor(builder,
+ builder.CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_UINT8, 1,
+ builder.CreateString("test_weight_tensor"), 0, false),
+ CreateTensor(builder,
+ builder.CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_INT32, 0,
+ builder.CreateString("test_output_tensor"), 0, false),
+ };
+ constexpr size_t inputs_size = 1;
+ const int32_t inputs[inputs_size] = {0};
+ constexpr size_t outputs_size = 1;
+ const int32_t outputs[outputs_size] = {2};
+ constexpr size_t operator_inputs_size = 2;
+ const int32_t operator_inputs[operator_inputs_size] = {0, 1};
+ constexpr size_t operator_outputs_size = 1;
+ const int32_t operator_outputs[operator_outputs_size] = {2};
+ constexpr size_t operators_size = 1;
+ const Offset<Operator> operators[operators_size] = {CreateOperator(
+ builder, 0, builder.CreateVector(operator_inputs, operator_inputs_size),
+ builder.CreateVector(operator_outputs, operator_outputs_size),
+ BuiltinOptions_NONE)};
+ constexpr size_t subgraphs_size = 1;
+ const Offset<SubGraph> subgraphs[subgraphs_size] = {
+ CreateSubGraph(builder, builder.CreateVector(tensors, tensors_size),
+ builder.CreateVector(inputs, inputs_size),
+ builder.CreateVector(outputs, outputs_size),
+ builder.CreateVector(operators, operators_size),
+ builder.CreateString("test_subgraph"))};
+ constexpr size_t operator_codes_size = 1;
+ const Offset<OperatorCode> operator_codes[operator_codes_size] = {
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "mock_custom",
+ 0)};
+ const Offset<Model> model_offset = CreateModel(
+ builder, 0, builder.CreateVector(operator_codes, operator_codes_size),
+ builder.CreateVector(subgraphs, subgraphs_size),
+ builder.CreateString("test_model"),
+ builder.CreateVector(buffers, buffers_size));
+ FinishModelBuffer(builder, model_offset);
+ void* model_pointer = builder.GetBufferPointer();
+ const Model* model = flatbuffers::GetRoot<Model>(model_pointer);
+ return model;
+}
+
+} // namespace
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+// End-to-end check: build the mock model, run one invocation, and verify
+// the output is input (21) + weight (21) = 42 along with tensor metadata.
+TF_LITE_MICRO_TEST(TestInterpreter) {
+ const tflite::Model* model = tflite::BuildMockModel();
+ TF_LITE_MICRO_EXPECT_NE(nullptr, model);
+ tflite::MockOpResolver mock_resolver;
+ constexpr size_t allocator_buffer_size = 1024;
+ uint8_t allocator_buffer[allocator_buffer_size];
+ tflite::SimpleTensorAllocator simple_tensor_allocator(allocator_buffer,
+ allocator_buffer_size);
+ tflite::MicroInterpreter interpreter(
+ model, mock_resolver, &simple_tensor_allocator, micro_test::reporter);
+ TF_LITE_MICRO_EXPECT_EQ(1, interpreter.inputs_size());
+ TF_LITE_MICRO_EXPECT_EQ(1, interpreter.outputs_size());
+
+ // A rank-1, single-element int32 input: 4 bytes.
+ TfLiteTensor* input = interpreter.input(0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, input);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type);
+ TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
+ TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
+ TF_LITE_MICRO_EXPECT_EQ(4, input->bytes);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32);
+ input->data.i32[0] = 21;
+
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
+
+ // MockInvoke adds the constant weight (21) to the input (21).
+ TfLiteTensor* output = interpreter.output(0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, output);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type);
+ TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size);
+ TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
+ TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
+ TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc
new file mode 100644
index 0000000000..40c21c6448
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h"
+
+namespace tflite {
+
+const TfLiteRegistration* MicroMutableOpResolver::FindOp(
+ tflite::BuiltinOperator op, int version) const {
+ for (int i = 0; i < registrations_len_; ++i) {
+ const TfLiteRegistration& registration = registrations_[i];
+ if ((registration.builtin_code == op) &&
+ (registration.version == version)) {
+ return &registration;
+ }
+ }
+ return nullptr;
+}
+
+const TfLiteRegistration* MicroMutableOpResolver::FindOp(const char* op,
+ int version) const {
+ for (int i = 0; i < registrations_len_; ++i) {
+ const TfLiteRegistration& registration = registrations_[i];
+ if ((registration.builtin_code == -1) &&
+ (strcmp(registration.custom_name, op) == 0) &&
+ (registration.version == version)) {
+ return &registration;
+ }
+ }
+ return nullptr;
+}
+
+void MicroMutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration,
+ int min_version, int max_version) {
+ for (int version = min_version; version <= max_version; ++version) {
+ if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
+ // TODO(petewarden) - Add error reporting hooks so we can report this!
+ return;
+ }
+ TfLiteRegistration* new_registration = &registrations_[registrations_len_];
+ registrations_len_ += 1;
+
+ *new_registration = *registration;
+ new_registration->builtin_code = op;
+ new_registration->version = version;
+ }
+}
+
+void MicroMutableOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration,
+ int min_version, int max_version) {
+ for (int version = min_version; version <= max_version; ++version) {
+ if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
+ // TODO(petewarden) - Add error reporting hooks so we can report this!
+ return;
+ }
+ TfLiteRegistration* new_registration = &registrations_[registrations_len_];
+ registrations_len_ += 1;
+
+ *new_registration = *registration;
+ new_registration->builtin_code = -1;
+ new_registration->custom_name = name;
+ new_registration->version = version;
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h
new file mode 100644
index 0000000000..f3750a2484
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+#ifndef TFLITE_REGISTRATIONS_MAX
+#define TFLITE_REGISTRATIONS_MAX (128)
+#endif
+
+namespace tflite {
+
+class MicroMutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX];
+ int registrations_len_ = 0;
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc
new file mode 100644
index 0000000000..5420a33e87
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h"
+
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestOperations) {
+ using tflite::BuiltinOperator_CONV_2D;
+ using tflite::BuiltinOperator_RELU;
+ using tflite::MicroMutableOpResolver;
+ using tflite::OpResolver;
+
+ static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
+ tflite::MockPrepare, tflite::MockInvoke};
+
+ MicroMutableOpResolver micro_mutable_op_resolver;
+ micro_mutable_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r, 0, 2);
+ micro_mutable_op_resolver.AddCustom("mock_custom", &r, 0, 3);
+ OpResolver* resolver = &micro_mutable_op_resolver;
+
+ const TfLiteRegistration* registration =
+ resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp(BuiltinOperator_CONV_2D, 10);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp(BuiltinOperator_RELU, 0);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("mock_custom", 0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp("mock_custom", 10);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("nonexistent_custom", 0);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc
new file mode 100644
index 0000000000..8c090a20a5
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+namespace tflite {
+namespace {
+
+TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size,
+ ErrorReporter* reporter) {
+ switch (type) {
+ case kTfLiteFloat32:
+ *size = sizeof(float);
+ break;
+ case kTfLiteInt16:
+ *size = sizeof(int16_t);
+ break;
+ case kTfLiteInt32:
+ *size = sizeof(int32_t);
+ break;
+ case kTfLiteUInt8:
+ *size = sizeof(uint8_t);
+ break;
+ case kTfLiteInt64:
+ *size = sizeof(int64_t);
+ break;
+ case kTfLiteBool:
+ *size = sizeof(bool);
+ break;
+ case kTfLiteComplex64:
+ *size = sizeof(float) * 2;
+ break;
+ default:
+ reporter->Report(
+ "Only float32, int16, int32, int64, uint8, bool, complex64 "
+ "supported currently.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus BytesRequired(const tflite::Tensor& flatbuffer_tensor,
+ size_t dims_size, size_t* bytes,
+ ErrorReporter* error_reporter) {
+ TfLiteType tf_lite_type;
+ TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(),
+ &tf_lite_type, error_reporter));
+ size_t type_size;
+ TF_LITE_ENSURE_STATUS(
+ TfLiteTypeSizeOf(tf_lite_type, &type_size, error_reporter));
+ *bytes = dims_size * type_size;
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteStatus SimpleTensorAllocator::AllocateTensor(
+ const tflite::Tensor& flatbuffer_tensor, int create_before,
+ int destroy_after,
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ ErrorReporter* error_reporter, TfLiteTensor* result) {
+ TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(),
+ &result->type, error_reporter));
+ result->is_variable = flatbuffer_tensor.is_variable();
+
+ result->data.raw = nullptr;
+ result->bytes = 0;
+ if (auto* buffer = (*buffers)[flatbuffer_tensor.buffer()]) {
+ if (auto* array = buffer->data()) {
+ if (size_t array_size = array->size()) {
+ result->data.raw =
+ const_cast<char*>(reinterpret_cast<const char*>(array->data()));
+ TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, array_size,
+ &result->bytes, error_reporter));
+ }
+ }
+ }
+ if (result->data.raw) {
+ result->allocation_type = kTfLiteMmapRo;
+ } else {
+ int data_size = 1;
+ for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
+ data_size *= flatbuffer_tensor.shape()->Get(n);
+ }
+ TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, data_size,
+ &result->bytes, error_reporter));
+ result->data.raw = reinterpret_cast<char*>(AllocateMemory(result->bytes));
+ if (result->data.raw == nullptr) {
+ const char* tensor_name = flatbuffer_tensor.name()->c_str();
+ if (tensor_name == nullptr) {
+ tensor_name = "<None>";
+ }
+ error_reporter->Report(
+ "Couldn't allocate memory for tensor '%s', wanted %d bytes but only "
+ "%d were available",
+ tensor_name, result->bytes, (data_size_max_ - data_size_));
+ return kTfLiteError;
+ }
+ result->allocation_type = kTfLiteArenaRw;
+ }
+ result->dims = reinterpret_cast<TfLiteIntArray*>(
+ AllocateMemory(sizeof(int) * (flatbuffer_tensor.shape()->Length() + 1)));
+ result->dims->size = flatbuffer_tensor.shape()->Length();
+ for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
+ result->dims->data[n] = flatbuffer_tensor.shape()->Get(n);
+ }
+ if (flatbuffer_tensor.quantization()) {
+ result->params.scale = flatbuffer_tensor.quantization()->scale()->Get(0);
+ result->params.zero_point =
+ flatbuffer_tensor.quantization()->zero_point()->Get(0);
+ }
+ result->allocation = nullptr;
+ if (flatbuffer_tensor.name()) {
+ result->name = flatbuffer_tensor.name()->c_str();
+ } else {
+ result->name = "<No name>";
+ }
+ result->delegate = nullptr;
+ result->buffer_handle = 0;
+ result->data_is_stale = false;
+ return kTfLiteOk;
+}
+
+uint8_t* SimpleTensorAllocator::AllocateMemory(size_t size) {
+ if ((data_size_ + size) > data_size_max_) {
+ // TODO(petewarden): Add error reporting beyond returning null!
+ return nullptr;
+ }
+ uint8_t* result = data_;
+ data_ += size;
+ data_size_ += size;
+ return result;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h
new file mode 100644
index 0000000000..4f16a9d0e5
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// TODO(petewarden): This allocator never frees up or reuses any memory, even
+// though we have enough information about lifetimes of the tensors to do so.
+// This makes it pretty wasteful, so we should use a more intelligent method.
+class SimpleTensorAllocator {
+ public:
+ SimpleTensorAllocator(uint8_t* buffer, int buffer_size)
+ : data_size_(0), data_size_max_(buffer_size), data_(buffer) {}
+
+ TfLiteStatus AllocateTensor(
+ const tflite::Tensor& flatbuffer_tensor, int create_before,
+ int destroy_after,
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ ErrorReporter* error_reporter, TfLiteTensor* result);
+
+ uint8_t* AllocateMemory(size_t size);
+
+ int GetDataSize() const { return data_size_; }
+
+ private:
+ int data_size_;
+ int data_size_max_;
+ uint8_t* data_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc
new file mode 100644
index 0000000000..c835427243
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc
@@ -0,0 +1,144 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace {
+class StackAllocator : public flatbuffers::Allocator {
+ public:
+ StackAllocator() : data_(data_backing_), data_size_(0) {}
+
+ uint8_t* allocate(size_t size) override {
+ if ((data_size_ + size) > kStackAllocatorSize) {
+ // TODO(petewarden): Add error reporting beyond returning null!
+ return nullptr;
+ }
+ uint8_t* result = data_;
+ data_ += size;
+ data_size_ += size;
+ return result;
+ }
+
+ void deallocate(uint8_t* p, size_t) override {}
+
+ static StackAllocator& instance() {
+ // Avoid using true dynamic memory allocation to be portable to bare metal.
+ static char inst_memory[sizeof(StackAllocator)];
+ static StackAllocator* inst = new (inst_memory) StackAllocator;
+ return *inst;
+ }
+
+ static constexpr int kStackAllocatorSize = 4096;
+
+ private:
+ uint8_t data_backing_[kStackAllocatorSize];
+ uint8_t* data_;
+ int data_size_;
+};
+
+flatbuffers::FlatBufferBuilder* BuilderInstance() {
+ static char inst_memory[sizeof(flatbuffers::FlatBufferBuilder)];
+ static flatbuffers::FlatBufferBuilder* inst =
+ new (inst_memory) flatbuffers::FlatBufferBuilder(
+ StackAllocator::kStackAllocatorSize, &StackAllocator::instance());
+ return inst;
+}
+
+const Tensor* Create1dTensor(int size) {
+ using flatbuffers::Offset;
+ flatbuffers::FlatBufferBuilder* builder = BuilderInstance();
+ constexpr size_t tensor_shape_size = 1;
+ const int32_t tensor_shape[tensor_shape_size] = {size};
+ const Offset<Tensor> tensor_offset = CreateTensor(
+ *builder, builder->CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_INT32, 0, builder->CreateString("test_tensor"), 0, false);
+ builder->Finish(tensor_offset);
+ void* tensor_pointer = builder->GetBufferPointer();
+ const Tensor* tensor = flatbuffers::GetRoot<Tensor>(tensor_pointer);
+ return tensor;
+}
+
+const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* CreateBuffers() {
+ using flatbuffers::Offset;
+ flatbuffers::FlatBufferBuilder* builder = BuilderInstance();
+ constexpr size_t buffers_size = 1;
+ const Offset<Buffer> buffers[buffers_size] = {
+ CreateBuffer(*builder),
+ };
+ const flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+ buffers_offset = builder->CreateVector(buffers, buffers_size);
+ builder->Finish(buffers_offset);
+ void* buffers_pointer = builder->GetBufferPointer();
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* result =
+ flatbuffers::GetRoot<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>(
+ buffers_pointer);
+ return result;
+}
+
+} // namespace
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestAllocateTensor) {
+ constexpr size_t arena_size = 1024;
+ uint8_t arena[arena_size];
+ tflite::SimpleTensorAllocator allocator(arena, arena_size);
+
+ const tflite::Tensor* tensor = tflite::Create1dTensor(100);
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>* buffers =
+ tflite::CreateBuffers();
+
+ TfLiteTensor allocated_tensor;
+ TF_LITE_MICRO_EXPECT_EQ(
+ kTfLiteOk,
+ allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter,
+ &allocated_tensor));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type);
+ TF_LITE_MICRO_EXPECT_EQ(1, allocated_tensor.dims->size);
+ TF_LITE_MICRO_EXPECT_EQ(100, allocated_tensor.dims->data[0]);
+ TF_LITE_MICRO_EXPECT_EQ(400, allocated_tensor.bytes);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, allocated_tensor.data.i32);
+}
+
+TF_LITE_MICRO_TEST(TestTooLarge) {
+ constexpr size_t arena_size = 1024;
+ uint8_t arena[arena_size];
+ tflite::SimpleTensorAllocator allocator(arena, arena_size);
+
+ const tflite::Tensor* tensor = tflite::Create1dTensor(10000);
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>* buffers =
+ tflite::CreateBuffers();
+
+ TfLiteTensor allocated_tensor;
+ TF_LITE_MICRO_EXPECT_NE(
+ kTfLiteOk,
+ allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter,
+ &allocated_tensor));
+}
+
+TF_LITE_MICRO_TEST(TestJustFits) {
+ constexpr size_t arena_size = 1024;
+ uint8_t arena[arena_size];
+ tflite::SimpleTensorAllocator allocator(arena, arena_size);
+
+ uint8_t* result = allocator.AllocateMemory(arena_size);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, result);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/BUILD b/tensorflow/contrib/lite/experimental/micro/testing/BUILD
new file mode 100644
index 0000000000..0d23be5712
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/BUILD
@@ -0,0 +1,17 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["test_linux_binary.sh"])
+
+cc_library(
+ name = "micro_test",
+ hdrs = [
+ "micro_test.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill
new file mode 100644
index 0000000000..7d6d81af0f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill
@@ -0,0 +1,21 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This docker configuration file lets you emulate a Blue Pill board
+# on an x86 desktop or laptop, which can be useful for debugging and
+# automated testing.
+FROM antmicro/renode:latest
+
+LABEL maintainer="Pete Warden <petewarden@google.com>" \ No newline at end of file
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc
new file mode 100644
index 0000000000..9333dc42bf
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc
@@ -0,0 +1,36 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+using sysbus
+
+mach create
+machine LoadPlatformDescription @platforms/cpus/stm32f103.repl
+
+# These lines are needed to show the results of DebugLog calls in the output.
+machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
+showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer
+
+logFile @/tmp/renode_bluepill_log.txt
+
+macro reset
+"""
+ sysbus LoadELF $bin
+"""
+
+runMacro $reset
+
+emulation RunFor @1
+
+quit \ No newline at end of file
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl
new file mode 100644
index 0000000000..916e3eeac3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl
@@ -0,0 +1,67 @@
+"""Rules for simple testing without dependencies by parsing output logs."""
+
+def tflite_micro_cc_test(
+ name,
+ expected_in_logs = "~~~ALL TESTS PASSED~~~",
+ srcs = [],
+ includes = [],
+ defines = [],
+ copts = [],
+ nocopts = "",
+ linkopts = [],
+ deps = [],
+ tags = [],
+ visibility = None):
+    """Tests a C/C++ binary without testing framework dependencies.
+
+ Runs a C++ binary, and tests that the output logs contain the
+ expected value. This is a deliberately spartan way of testing, to match
+ what's available when testing microcontroller binaries.
+
+ Args:
+ name: a unique name for this rule.
+ expected_in_logs: A regular expression that is required to be
+ present in the binary's logs for the test to pass.
+ srcs: sources to compile (C, C++, ld scripts).
+ includes: include paths to add to this rule and its dependents.
+ defines: list of `VAR` or `VAR=VAL` to pass to CPP for this rule and
+ its dependents.
+ copts: gcc compilation flags for this rule only.
+ nocopts: list of gcc compilation flags to remove for this rule
+ only. No regexp like for `cc_library`.
+ linkopts: `gcc` flags to add to the linking phase. For "pure" ld flags,
+ prefix them with the `-Wl,` prefix here.
+      deps: dependencies. Only `tflite_bare_metal_cc_library()` dependencies
+ allowed.
+ visibility: visibility.
+ """
+ native.cc_binary(
+ name = name + "_binary",
+ srcs = srcs,
+ includes = includes,
+ defines = defines,
+ copts = copts,
+ nocopts = nocopts,
+ linkopts = linkopts,
+ deps = deps,
+ tags = tags,
+ visibility = visibility,
+ )
+ native.sh_test(
+ name = name,
+ size = "medium",
+ srcs = [
+ "//tensorflow/contrib/lite/experimental/micro/testing:test_linux_binary.sh",
+ ],
+ args = [
+ native.package_name() + "/" + name + "_binary",
+ "'" + expected_in_logs + "'",
+ ],
+ data = [
+ name + "_binary",
+ # Internal test dependency placeholder
+ ],
+ deps = [
+ ],
+ tags = tags,
+ )
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h
new file mode 100644
index 0000000000..104509c9dc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An ultra-lightweight testing framework designed for use with microcontroller
+// applications. Its only dependency is on TensorFlow Lite's ErrorReporter
+// interface, where log messages are output. This is designed to be usable even
+// when no standard C or C++ libraries are available, and without any dynamic
+// memory allocation or reliance on global constructors.
+//
+// To build a test, you use syntax similar to gunit, but with some extra
+// decoration to create a hidden 'main' function containing each of the tests to
+// be run. Your code should look something like:
+// ----------------------------------------------------------------------------
+// #include "path/to/this/header"
+//
+// TF_LITE_MICRO_TESTS_BEGIN
+//
+// TF_LITE_MICRO_TEST(SomeTest) {
+// TF_LITE_LOG_EXPECT_EQ(true, true);
+// }
+//
+// TF_LITE_MICRO_TESTS_END
+// ----------------------------------------------------------------------------
+// If you compile this for your platform, you'll get a normal binary that you
+// should be able to run. Executing it will output logging information like this
+// to stderr (or whatever equivalent is available and written to by
+// ErrorReporter):
+// ----------------------------------------------------------------------------
+// Testing SomeTest
+// 1/1 tests passed
+// ~~~ALL TESTS PASSED~~~
+// ----------------------------------------------------------------------------
+// This is designed to be human-readable, so you can just run tests manually,
+// but the string "~~~ALL TESTS PASSED~~~" should only appear if all of the
+// tests do pass. This makes it possible to integrate with automated test
+// systems by scanning the output logs and looking for that magic value.
+//
+// This framework is intended to be a rudimentary alternative to no testing at
+// all on systems that struggle to run more conventional approaches, so use with
+// caution!
+
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+
+namespace micro_test {
+extern int tests_passed;
+extern int tests_failed;
+extern bool is_test_complete;
+extern bool did_test_fail;
+extern tflite::ErrorReporter* reporter;
+} // namespace micro_test
+
+#define TF_LITE_MICRO_TESTS_BEGIN \
+ namespace micro_test { \
+ int tests_passed; \
+ int tests_failed; \
+ bool is_test_complete; \
+ bool did_test_fail; \
+ tflite::ErrorReporter* reporter; \
+ } \
+ \
+ int main(int argc, char** argv) { \
+ micro_test::tests_passed = 0; \
+ micro_test::tests_failed = 0; \
+ tflite::MicroErrorReporter error_reporter; \
+ micro_test::reporter = &error_reporter;
+
+#define TF_LITE_MICRO_TESTS_END \
+ micro_test::reporter->Report( \
+ "%d/%d tests passed", micro_test::tests_passed, \
+ (micro_test::tests_failed + micro_test::tests_passed)); \
+ if (micro_test::tests_failed == 0) { \
+ micro_test::reporter->Report("~~~ALL TESTS PASSED~~~\n"); \
+ } else { \
+ micro_test::reporter->Report("~~~SOME TESTS FAILED~~~\n"); \
+ } \
+ }
+
+// TODO(petewarden): I'm going to hell for what I'm doing to this poor for loop.
+#define TF_LITE_MICRO_TEST(name) \
+ micro_test::reporter->Report("Testing %s", #name); \
+ for (micro_test::is_test_complete = false, \
+ micro_test::did_test_fail = false; \
+ !micro_test::is_test_complete; micro_test::is_test_complete = true, \
+ micro_test::tests_passed += (micro_test::did_test_fail) ? 0 : 1, \
+ micro_test::tests_failed += (micro_test::did_test_fail) ? 1 : 0)
+
+#define TF_LITE_MICRO_EXPECT(x) \
+ do { \
+ if (!(x)) { \
+ micro_test::reporter->Report(#x " failed at %s:%d", __FILE__, __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_EQ(x, y) \
+ do { \
+ if ((x) != (y)) { \
+ micro_test::reporter->Report(#x " == " #y " failed at %s:%d", __FILE__, \
+ __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_NE(x, y) \
+ do { \
+ if ((x) == (y)) { \
+ micro_test::reporter->Report(#x " != " #y " failed at %s:%d", __FILE__, \
+ __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \
+ do { \
+ auto delta = ((x) > (y)) ? ((x) - (y)) : ((y) - (x)); \
+ if (delta > epsilon) { \
+ micro_test::reporter->Report(#x " near " #y " failed at %s:%d", \
+ __FILE__, __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh
new file mode 100755
index 0000000000..07742a8262
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh
@@ -0,0 +1,54 @@
+#!/bin/bash -e
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests a 'bluepill' STM32F103 ELF by parsing the log output of Renode emulation.
+#
+# First argument is the ELF location.
+# Second argument is a regular expression that's required to be in the output logs
+# for the test to pass.
+
+declare -r ROOT_DIR=`pwd`
+declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/
+declare -r MICRO_LOG_PATH=${TEST_TMPDIR}
+declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
+mkdir -p ${MICRO_LOG_PATH}
+
+docker build -t renode_bluepill \
+ -f ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill \
+ ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/
+
+docker run \
+ --log-driver=none -a stdout -a stderr \
+ -v ${ROOT_DIR}:/workspace \
+ -v /tmp:/tmp \
+ -it renode_bluepill \
+ /bin/bash -c "renode -P 5000 --disable-xwt -e '
+\$bin?=@/workspace/$1
+s @/workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc
+' 2>&1 >${MICRO_LOG_FILENAME}"
+
+echo "LOGS:"
+cat ${MICRO_LOG_FILENAME}
+
+if grep -q "$2" ${MICRO_LOG_FILENAME}
+then
+ echo "$1: PASS"
+ exit 0
+else
+ echo "$1: FAIL - '$2' not found in logs."
+ exit 1
+fi
+
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh
new file mode 100755
index 0000000000..24131a6d2d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh
@@ -0,0 +1,39 @@
+#!/bin/bash -e
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests a Linux binary by parsing the log output.
+#
+# First argument is the binary location.
+# Second argument is a regular expression that's required to be in the output logs
+# for the test to pass.
+
+declare -r ROOT_DIR=`pwd`
+declare -r TEST_TMPDIR=/tmp/test_linux_binary/
+declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
+declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
+mkdir -p ${MICRO_LOG_PATH}
+
+$1 2>&1 | tee ${MICRO_LOG_FILENAME}
+
+if grep -q "$2" ${MICRO_LOG_FILENAME}
+then
+ echo "$1: PASS"
+ exit 0
+else
+ echo "$1: FAIL - '$2' not found in logs."
+ exit 1
+fi
+
diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile b/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile
new file mode 100644
index 0000000000..880bb4763c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile
@@ -0,0 +1,166 @@
+MAKEFILE_DIR := tensorflow/contrib/lite/experimental/micro/tools/make
+
+# Try to figure out the host system
+HOST_OS :=
+ifeq ($(OS),Windows_NT)
+ HOST_OS = windows
+else
+ UNAME_S := $(shell uname -s)
+ ifeq ($(UNAME_S),Linux)
+ HOST_OS := linux
+ endif
+ ifeq ($(UNAME_S),Darwin)
+ HOST_OS := osx
+ endif
+endif
+
+HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
+
+# Override these on the make command line to target a specific architecture. For example:
+# make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l
+TARGET := $(HOST_OS)
+TARGET_ARCH := $(HOST_ARCH)
+
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
+TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh
+
+MICROLITE_LIBS := -lm
+
+# There are no rules for compiling objects for the host system (since we don't
+# generate things like the protobuf compiler that require that), so all of
+# these settings are for the target compiler.
+CXXFLAGS := -O3 -DNDEBUG
+CXXFLAGS += --std=c++11 -g -DTF_LITE_STATIC_MEMORY
+CCFLAGS := -DNDEBUG -g -DTF_LITE_STATIC_MEMORY
+LDOPTS := -L/usr/local/lib
+ARFLAGS := -r
+TARGET_TOOLCHAIN_PREFIX :=
+CC_PREFIX :=
+
+# This library is the main target for this makefile. It will contain a minimal
+# runtime that can be linked in to other programs.
+MICROLITE_LIB_NAME := libtensorflow-microlite.a
+
+# Test binary for the microcontroller speech model.
+MICRO_SPEECH_TEST_SRCS := \
+tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \
+tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
+
+MICROLITE_TEST_SRCS := \
+$(wildcard tensorflow/contrib/lite/experimental/micro/*test.cc) \
+$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*test.cc)
+
+MICROLITE_CC_BASE_SRCS := \
+$(wildcard tensorflow/contrib/lite/experimental/micro/*.cc) \
+$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*.cc) \
+tensorflow/contrib/lite/c/c_api_internal.c \
+tensorflow/contrib/lite/core/api/error_reporter.cc \
+tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc \
+tensorflow/contrib/lite/core/api/op_resolver.cc \
+tensorflow/contrib/lite/kernels/kernel_util.cc \
+tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS))
+
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+ $(MICRO_SPEECH_TEST_SRCS) \
+ $(MICROLITE_CC_SRCS) \
+ $(MICROLITE_TEST_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+MICROLITE_LIB_PATH := $(LIBDIR)$(MICROLITE_LIB_NAME)
+
+MICRO_SPEECH_TEST_BINARY := $(BINDIR)micro_speech_test
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MICRO_SPEECH_TEST_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_TEST_SRCS))))
+
+MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS))))
+
+MICROLITE_TEST_TARGETS := $(addprefix $(BINDIR), \
+$(patsubst %_test.cc,%.test_target,$(MICROLITE_TEST_SRCS)))
+
+# For normal manually-created TensorFlow C++ source files.
+$(OBJDIR)%.o: %.cc
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
+
+# For normal manually-created TensorFlow C source files.
+$(OBJDIR)%.o: %.c
+ @mkdir -p $(dir $@)
+ $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
+
+# The target that's compiled if there's no command-line arguments.
+all: $(MICROLITE_LIB_PATH) $(MICRO_SPEECH_TEST_BINARY)
+
+microlite: $(MICROLITE_LIB_PATH)
+
+# Hack for generating schema file bypassing flatbuffer parsing
+tensorflow/contrib/lite/schema/schema_generated.h:
+ @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h
+
+# Gathers together all the objects we've compiled into a single '.a' archive.
+$(MICROLITE_LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(MICROLITE_LIB_OBJS)
+ @mkdir -p $(dir $@)
+ $(AR) $(ARFLAGS) $(MICROLITE_LIB_PATH) $(MICROLITE_LIB_OBJS)
+
+$(MICRO_SPEECH_TEST_BINARY): $(MICRO_SPEECH_TEST_OBJS) $(MICROLITE_LIB_PATH)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $(MICRO_SPEECH_TEST_BINARY) $(MICRO_SPEECH_TEST_OBJS) \
+ $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
+
+micro_speech_test: $(MICRO_SPEECH_TEST_BINARY)
+micro_speech_test_bin: $(MICRO_SPEECH_TEST_BINARY).bin
+
+test_micro_speech: $(MICRO_SPEECH_TEST_BINARY)
+ $(TEST_SCRIPT) $(MICRO_SPEECH_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
+
+$(BINDIR)%_test : $(OBJDIR)%_test.o $(MICROLITE_LIB_PATH)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $@ $< \
+ $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
+
+$(BINDIR)%.test_target: $(BINDIR)%_test
+ $(TEST_SCRIPT) $< '~~~ALL TESTS PASSED~~~'
+
+$(info $(MICROLITE_TEST_TARGETS))
+
+test: test_micro_speech $(MICROLITE_TEST_TARGETS)
+
+# Gets rid of all generated files.
+clean:
+ rm -rf $(MAKEFILE_DIR)/gen
+
+$(DEPDIR)/%.d: ;
+.PRECIOUS: $(DEPDIR)/%.d
+.PRECIOUS: $(BINDIR)%_test
+
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh
new file mode 100755
index 0000000000..4c2ff8545d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$SCRIPT_DIR/../../../../../../.."
+
+DOWNLOADS_DIR=tensorflow/contrib/lite/experimental/micro/tools/make/downloads
+BZL_FILE_PATH=tensorflow/workspace.bzl
+
+# Ensure it is being run from repo root
+if [ ! -f $BZL_FILE_PATH ]; then
+ echo "Could not find ${BZL_FILE_PATH}":
+ echo "Likely you are not running this from the root directory of the repository.";
+ exit 1;
+fi
+
+GEMMLOWP_URL="https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip"
+FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
+CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip"
+STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/50e0da307a2821bb54af1f57b969e6b76cb89d32.zip"
+
+download_and_extract() {
+ local usage="Usage: download_and_extract URL DIR"
+ local url="${1:?${usage}}"
+ local dir="${2:?${usage}}"
+ echo "downloading ${url}" >&2
+ mkdir -p "${dir}"
+ if [[ "${url}" == *gz ]]; then
+ curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz
+ elif [[ "${url}" == *zip ]]; then
+ tempdir=$(mktemp -d)
+ tempdir2=$(mktemp -d)
+
+ curl -L ${url} > ${tempdir}/zipped.zip
+ unzip ${tempdir}/zipped.zip -d ${tempdir2}
+
+ # If the zip file contains nested directories, extract the files from the
+ # inner directory.
+ if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then
+ # unzip has no strip components, so unzip to a temp dir, and move the
+ # files we want from the tempdir to destination.
+ cp -R ${tempdir2}/*/* ${dir}/
+ else
+ cp -R ${tempdir2}/* ${dir}/
+ fi
+ rm -rf ${tempdir2} ${tempdir}
+ fi
+
+ # Delete any potential BUILD files, which would interfere with Bazel builds.
+ find "${dir}" -type f -name '*BUILD' -delete
+}
+
+download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp"
+download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers"
+download_and_extract "${CMSIS_URL}" "${DOWNLOADS_DIR}/cmsis"
+download_and_extract "${STM32_BARE_LIB_URL}" "${DOWNLOADS_DIR}/stm32_bare_lib"
+
+echo "download_dependencies.sh completed successfully." >&2
diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc
new file mode 100644
index 0000000000..022a8422dc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc
@@ -0,0 +1,65 @@
+# Settings for Blue Pill platforms.
+ifeq ($(TARGET), bluepill)
+ TARGET_ARCH := cortex-m3
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+
+ PLATFORM_FLAGS = \
+ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+ -DTF_LITE_STATIC_MEMORY \
+ -DTF_LITE_MCU_DEBUG_LOG \
+ -fno-rtti \
+ -fmessage-length=0 \
+ -fno-exceptions \
+ -fno-unwind-tables \
+ -fno-builtin \
+ -ffunction-sections \
+ -fdata-sections \
+ -funsigned-char \
+ -MMD \
+ -mcpu=cortex-m3 \
+ -mthumb \
+ -std=gnu++11 \
+ -Wvla \
+ -Wall \
+ -Wextra \
+ -Wno-unused-parameter \
+ -Wno-missing-field-initializers \
+ -Wno-write-strings \
+ -Wno-sign-compare \
+ -fno-delete-null-pointer-checks \
+ -fomit-frame-pointer \
+ -fpermissive \
+ -nostdlib \
+ -g \
+ -Os
+ CXXFLAGS += $(PLATFORM_FLAGS)
+ CCFLAGS += $(PLATFORM_FLAGS)
+ LDFLAGS += \
+ -T $(MAKEFILE_DIR)/downloads/stm32_bare_lib/stm32_linker_layout.lds \
+ -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \
+ -Wl,--gc-sections
+ BUILD_TYPE := micro
+ MICROLITE_LIBS := \
+ -lm
+ INCLUDES += \
+ -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \
+ -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include
+ MICROLITE_CC_SRCS += \
+ $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \
+ $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc)
+ TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh
+ # These are tests that don't currently work on the blue pill.
+ EXCLUDED_TESTS := \
+ tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc \
+ tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc
+ MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS))
+
+# These are microcontroller-specific rules for converting the ELF output
+# of the linker into a binary image that can be loaded directly.
+OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
+
+$(BINDIR)/%.bin: $(BINDIR)/%
+ @mkdir -p $(dir $@)
+ $(OBJCOPY) $< $@ -O binary
+
+endif \ No newline at end of file
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 0ae9400068..6b7943caf8 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -7,12 +7,12 @@ Mobile and embedded devices have limited computational resources and it is impor
Some models may be too large to run on embedded devices. Instead of large models it is better to use a slightly less precise but smaller model for embedded devices. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for
-[image classification] (https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
+[image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
[object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193).
## Profile your model
-Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](../tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
+Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
## Profile and optimize operators in the graph
If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator.
@@ -22,7 +22,7 @@ If a particular operator appears frequently in the model and based on profiling
If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well.
## Tweak the number of threads
-Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](../interpreter.h) threads.
+Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads.
## Eliminate redundant copies
Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
@@ -31,8 +31,8 @@ Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to
Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform.
## Use hardware accelerators available on the device
-Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
-You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable NNAPI call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
+Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
+You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
## Need more help
The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [github](https://github.com/tensorflow/tensorflow/issues) with details of the issue.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index b0f32a8d6c..2eb776d10c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -1,6 +1,22 @@
-
# Building TensorFlow on Android
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
To get you started working with TensorFlow on Android, we'll walk through two
ways to build our TensorFlow mobile demos and deploying them on an Android
device. The first is Android Studio, which lets you build and deploy in an
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index 49ad35d4e6..15f0fd3961 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -1,6 +1,22 @@
-
# Overview
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
index be8b4100c8..d922907cdc 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
@@ -1,6 +1,22 @@
-
# Building TensorFlow on iOS
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
## Using CocoaPods
The simplest way to get started with TensorFlow on iOS is using the CocoaPods
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
index 4d4bb3bc08..fd0e322c93 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
@@ -1,6 +1,22 @@
-
# Integrating TensorFlow libraries
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
Once you have made some progress on a model that addresses the problem you’re
trying to solve, it’s important to test it out inside your application
immediately. There are often unexpected differences between your training data
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
index 7436594fd8..59ff8e774c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
@@ -1,6 +1,22 @@
-
# Optimizing for mobile
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
There are some special issues that you have to deal with when you’re trying to
ship on mobile or embedded devices, and you’ll need to think about these as
you’re developing your model.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
index d1c67d4c61..1d373251dd 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
@@ -1,6 +1,22 @@
-
# Preparing models for mobile deployment
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
The requirements for storing model information during training are very
different from when you want to release it as part of a mobile app. This section
covers the tools involved in converting from a training model to something
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 7ef736d01b..651a97e9dc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -349,6 +349,10 @@ class Interpreter {
return context_.allow_fp32_relax_to_fp16;
}
+ // Owning handle to a TfLiteDelegate instance.
+ using TfLiteDelegatePtr =
+ std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
+
// Allow a delegate to look at the graph and modify the graph to handle
// parts of the graph themselves. After this is called, the graph may
// contain new nodes that replace 1 more nodes.
@@ -574,19 +578,11 @@ class Interpreter {
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
- using TfLiteDelegatePtr =
- std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
-
// Variant of the public ModifyGraphWithDelegate method that additionally
// Assumes ownership of the provided delegate.
// WARNING: This is an experimental API and subject to change.
- template <typename Delegate>
- TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate,
+ TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate,
bool allow_dynamic_tensors = false) {
- TfLiteDelegatePtr delegate(typed_delegate.release(),
- [](TfLiteDelegate* delegate) {
- delete static_cast<Delegate*>(delegate);
- });
// Note that we retain ownership of the delegate even if graph modification
// fails, as delegate use will be in an indeterminate state at that point.
owned_delegates_.push_back(std::move(delegate));
@@ -676,6 +672,7 @@ class Interpreter {
// List of delegates that have been installed and are owned by this
// interpreter instance. Useful if client delegate ownership is burdensome.
// WARNING: This is an experimental API and subject to change.
+ // TODO(b/116667551): Use TfLiteExternalContext for storing state.
std::vector<TfLiteDelegatePtr> owned_delegates_;
std::unique_ptr<MemoryPlanner> memory_planner_;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index cdede430e2..6c71d5a8d7 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test {
template <typename Delegate>
static TfLiteStatus ModifyGraphWithDelegate(
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
- return interpreter->ModifyGraphWithDelegate(std::move(delegate));
+ Interpreter::TfLiteDelegatePtr tflite_delegate(
+ delegate.release(), [](TfLiteDelegate* delegate) {
+ delete reinterpret_cast<Delegate*>(delegate);
+ });
+ return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
}
protected:
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index 098ba7e773..e68cd26f81 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary")
load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni")
+JAVA_SRCS = glob([
+ "src/main/java/org/tensorflow/lite/*.java",
+])
+
# Building tensorflow-lite.aar including 4 variants of .so
# To build an aar for release, run below command:
# bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
@@ -20,28 +24,38 @@ aar_with_jni(
android_library = ":tensorflowlite",
)
+# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite.
+aar_with_jni(
+ name = "tensorflow-lite-flex",
+ android_library = ":tensorflowlite_flex",
+)
+
android_library(
name = "tensorflowlite",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflowlite_native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite.
+android_library(
+ name = "tensorflowlite_flex",
+ srcs = JAVA_SRCS,
manifest = "AndroidManifest.xml",
visibility = ["//visibility:public"],
deps = [
- ":tflite_runtime",
+ ":tensorflowlite_native_flex",
"@org_checkerframework_qual",
],
)
android_library(
name = "tensorflowlite_java",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
visibility = ["//visibility:public"],
deps = [
"@org_checkerframework_qual",
@@ -50,16 +64,23 @@ android_library(
java_library(
name = "tensorflowlitelib",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
javacopts = JAVACOPTS,
visibility = ["//visibility:public"],
deps = [
":libtensorflowlite_jni.so",
- "//tensorflow/contrib/lite/java/src/main/native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite.
+java_library(
+ name = "tensorflowlitelib_flex",
+ srcs = JAVA_SRCS,
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtensorflowlite_flex_jni.so",
"@org_checkerframework_qual",
],
)
@@ -72,7 +93,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.TensorFlowLiteTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -87,7 +107,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.DataTypeTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -110,7 +129,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -125,13 +143,13 @@ java_test(
data = [
"src/testdata/add.bin",
"src/testdata/mobilenet.tflite.bin",
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
],
javacopts = JAVACOPTS,
tags = ["no_oss"],
test_class = "org.tensorflow.lite.InterpreterTest",
visibility = ["//visibility:private"],
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -139,6 +157,24 @@ java_test(
)
java_test(
+ name = "InterpreterFlexTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"],
+ data = [
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
+ ],
+ javacopts = JAVACOPTS,
+ tags = ["no_oss"],
+ test_class = "org.tensorflow.lite.InterpreterFlexTest",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":tensorflowlitelib_flex",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
name = "TensorTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"],
@@ -164,14 +200,29 @@ filegroup(
)
cc_library(
- name = "tflite_runtime",
+ name = "tensorflowlite_native",
srcs = ["libtensorflowlite_jni.so"],
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "tensorflowlite_native_flex",
+ srcs = ["libtensorflowlite_flex_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
tflite_jni_binary(
name = "libtensorflowlite_jni.so",
deps = [
"//tensorflow/contrib/lite/java/src/main/native",
],
)
+
+# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite.
+tflite_jni_binary(
+ name = "libtensorflowlite_flex_jni.so",
+ deps = [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index 9d2aead266..360d622b1b 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -30,7 +30,10 @@ EOF
# In some platforms we don't have an Android SDK/NDK and this target
# can't be built. We need to prevent the build system from trying to
# use the target in that case.
- tags = ["manual"],
+ tags = [
+ "manual",
+ "no_cuda_on_cpu_tap",
+ ],
)
native.genrule(
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index bb0be04ca2..ea9b9ed4b6 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+# Build targets for OVIC classification.
java_test(
name = "OvicClassifierTest",
size = "medium",
@@ -45,8 +46,9 @@ android_library(
name = "ovicbenchmarkerlib",
srcs = [
"src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
+ "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java",
"src/main/java/org/tensorflow/ovic/OvicClassifier.java",
- "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
+ "src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java",
],
manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
tags = ["no_oss"],
@@ -60,8 +62,8 @@ android_library(
java_library(
name = "ovicbenchmarkerlib_java",
srcs = [
+ "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java",
"src/main/java/org/tensorflow/ovic/OvicClassifier.java",
- "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
javacopts = JAVACOPTS,
tags = ["no_oss"],
@@ -73,3 +75,58 @@ java_library(
"@org_checkerframework_qual",
],
)
+
+# Build targets for OVIC detection.
+java_test(
+ name = "OvicDetectorTest",
+ size = "medium",
+ srcs = ["src/test/java/org/tensorflow/ovic/OvicDetectorTest.java"],
+ data = [
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "@tflite_mobilenet_ssd_quant//:detect.tflite",
+ ],
+ javacopts = JAVACOPTS,
+ tags = ["no_oss"],
+ test_class = "org.tensorflow.ovic.OvicDetectorTest",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib_java",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+android_library(
+ name = "ovicdetectionbenchmarkerlib",
+ srcs = [
+ "src/main/java/org/tensorflow/ovic/BoundingBox.java",
+ "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetector.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java",
+ ],
+ manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+ deps = [
+ "//tensorflow/contrib/lite/java:tensorflowlite",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@org_checkerframework_qual",
+ ],
+)
+
+java_library(
+ name = "ovicdetectionbenchmarkerlib_java",
+ srcs = [
+ "src/main/java/org/tensorflow/ovic/BoundingBox.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetector.java",
+ ],
+ javacopts = JAVACOPTS,
+ deps = [
+ "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
+ "//tensorflow/contrib/lite/java:tensorflowlite_java",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@org_checkerframework_qual",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index 058240aada..f567358ea3 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -10,8 +10,10 @@ android_binary(
],
aapt_version = "aapt",
assets = [
- "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt",
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "@tflite_mobilenet_ssd_quant//:detect.tflite",
],
assets_dir = "",
custom_package = "ovic.demo.app",
@@ -25,6 +27,7 @@ android_binary(
deps = [
"//tensorflow/contrib/lite/java:tensorflowlite",
"//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib",
+ "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib",
"@androidsdk//com.android.support:support-v13-25.2.0",
"@androidsdk//com.android.support:support-v4-25.2.0",
],
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
index 4adf94aeb6..48c29ecebe 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -35,19 +35,18 @@ import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.text.DecimalFormat;
import org.tensorflow.ovic.OvicBenchmarker;
-import org.tensorflow.ovic.OvicSingleImageResult;
-
+import org.tensorflow.ovic.OvicClassifierBenchmarker;
+import org.tensorflow.ovic.OvicDetectorBenchmarker;
/** Class that benchmark image classifier models. */
public class OvicBenchmarkerActivity extends Activity {
/** Tag for the {@link Log}. */
private static final String TAG = "OvicBenchmarkerActivity";
- /** Name of the label file stored in Assets. */
- private static final String LABEL_PATH = "labels.txt";
-
- private static final String TEST_IMAGE_PATH = "test_image_224.jpg";
- private static final String MODEL_PATH = "float_model.lite";
+ /** Name of the task-dependent data files stored in Assets. */
+ private static String labelPath = null;
+ private static String testImagePath = null;
+ private static String modelPath = null;
/**
* Each bottom press will launch a benchmarking experiment. The experiment stops when either the
* total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS,
@@ -66,8 +65,6 @@ public class OvicBenchmarkerActivity extends Activity {
private MappedByteBuffer model = null;
private InputStream labelInputStream = null;
private OvicBenchmarker benchmarker;
- /** Inference result of each iteration. */
- OvicSingleImageResult iterResult = null;
private TextView textView = null;
// private Button startButton = null;
@@ -83,21 +80,31 @@ public class OvicBenchmarkerActivity extends Activity {
}
private Bitmap loadTestBitmap() throws IOException {
- InputStream imageStream = getAssets().open(TEST_IMAGE_PATH);
+ InputStream imageStream = getAssets().open(testImagePath);
return BitmapFactory.decodeStream(imageStream);
}
- public void initializeTest() throws IOException {
+ public void initializeTest(boolean benchmarkClassification) throws IOException {
Log.i(TAG, "Initializing benchmarker.");
- benchmarker = new OvicBenchmarker(WALL_TIME);
+ if (benchmarkClassification) {
+ benchmarker = new OvicClassifierBenchmarker(WALL_TIME);
+ labelPath = "labels.txt";
+ testImagePath = "test_image_224.jpg";
+ modelPath = "quantized_model.lite";
+ } else { // Benchmarking detection.
+ benchmarker = new OvicDetectorBenchmarker(WALL_TIME);
+ labelPath = "coco_labels.txt";
+ testImagePath = "test_image_224.jpg";
+ modelPath = "detect.tflite";
+ }
AssetManager am = getAssets();
- AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH);
+ AssetFileDescriptor fileDescriptor = am.openFd(modelPath);
FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = modelInputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- labelInputStream = am.open(LABEL_PATH);
+ labelInputStream = am.open(labelPath);
}
public Boolean doTestIteration() throws IOException, InterruptedException {
@@ -117,24 +124,44 @@ public class OvicBenchmarkerActivity extends Activity {
Log.i(TAG, "Going to do test iter.");
// Start testing.
Bitmap testImageBitmap = loadTestBitmap();
- iterResult = benchmarker.doTestIteration(testImageBitmap);
- testImageBitmap.recycle();
- if (iterResult == null) {
+ try {
+ if (!benchmarker.processBitmap(testImageBitmap)) {
+ throw new RuntimeException("Failed to run test.");
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw e;
+ } finally {
+ testImageBitmap.recycle();
+ }
+ String iterResultString = benchmarker.getLastResultString();
+ if (iterResultString == null) {
throw new RuntimeException("Inference failed to produce a result.");
}
- Log.i(TAG, iterResult.toString());
+ Log.i(TAG, iterResultString);
return true;
}
- public void startPressed(View view) throws IOException {
- Log.i(TAG, "Start pressed");
+ public void detectPressed(View view) throws IOException {
+ benchmarkSession(false);
+ }
+ public void classifyPressed(View view) throws IOException {
+ benchmarkSession(true);
+ }
+
+ private void benchmarkSession(boolean benchmarkClassification) throws IOException {
try {
- initializeTest();
+ initializeTest(benchmarkClassification);
} catch (IOException e) {
Log.e(TAG, "Can't initialize benchmarker.", e);
throw e;
}
String displayText = "";
+ if (benchmarkClassification) {
+ displayText = "Classification benchmark: ";
+ } else {
+ displayText = "Detection benchmark: ";
+ }
try {
setProcessorAffinity(BIG_CORE_MASK);
} catch (IOException e) {
@@ -144,7 +171,6 @@ public class OvicBenchmarkerActivity extends Activity {
Log.i(TAG, "Successfully initialized benchmarker.");
int testIter = 0;
Boolean iterSuccess = false;
- double totalLatency = 0.0f;
while (testIter < MAX_ITERATIONS) {
try {
iterSuccess = doTestIteration();
@@ -153,23 +179,22 @@ public class OvicBenchmarkerActivity extends Activity {
throw e;
} catch (InterruptedException e) {
Log.e(TAG, "Interrupted at iteration " + testIter);
+ displayText += e.getMessage() + "\n";
}
if (!iterSuccess) {
break;
}
testIter++;
- totalLatency += (double) iterResult.latency;
}
- ;
Log.i(TAG, "Benchmarking finished");
if (textView != null) {
if (testIter > 0) {
textView.setText(
displayText
- + MODEL_PATH
+ + modelPath
+ ": Average latency="
- + df2.format(totalLatency / testIter)
+ + df2.format(benchmarker.getTotalRunTime() / testIter)
+ "ms after "
+ testIter
+ " runs.");
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
index e9d83bae54..1bce60ff7d 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
@@ -30,14 +30,14 @@
android:layout_height="wrap_content"
android:text="@string/initial_status_msg"
android:id="@+id/textView"
- android:layout_above="@+id/button_start"
+ android:layout_above="@+id/button_clf_start"
android:layout_alignParentTop="true"/>
<Button
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:text="@string/start_label"
- android:id="@id/button_start"
+ android:text="@string/start_clf_label"
+ android:id="@id/button_clf_start"
android:layout_alignParentBottom="true"
android:layout_alignParentLeft="true"
android:background="@drawable/start_button_color"
@@ -49,6 +49,25 @@
android:textColor="#ffffff"
android:enabled="true"
style="?android:attr/buttonBarButtonStyle"
- android:onClick="startPressed"/>
+ android:onClick="classifyPressed"/>
+
+ <Button
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:text="@string/start_det_label"
+ android:id="@+id/button_det_start"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentRight="true"
+ android:layout_toRightOf="@id/button_clf_start"
+ android:background="@drawable/start_button_color"
+ android:padding="10dp"
+ android:layout_marginRight="100dp"
+ android:layout_marginLeft="30dp"
+ android:layout_marginTop="10dp"
+ android:foreground="#000000"
+ android:textColor="#ffffff"
+ android:enabled="true"
+ style="?android:attr/buttonBarButtonStyle"
+ android:onClick="detectPressed"/>
</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
index d26beb1d27..53525908d3 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
@@ -17,6 +17,7 @@
<resources>
<string name="app_name" translatable="false">Benchmarker</string>
- <string name="start_label" translatable="false">Start</string>
+ <string name="start_clf_label" translatable="false">Clf</string>
+ <string name="start_det_label" translatable="false">Det</string>
<string name="initial_status_msg" translatable="false"> Press start to run the benchmarks.</string>
</resources>
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java
new file mode 100644
index 0000000000..9bf7d005d2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+/** Class for holding a detection bounding box with category and confidence. */
+public class BoundingBox {
+ // Upper left point.
+ public float x1;
+ public float y1;
+
+ // Lower right point.
+ public float x2;
+ public float y2;
+
+ // The area of the box
+ public float area;
+
+ // The object category
+ public int category;
+
+ // The confidence of the detection
+ public float score;
+
+ public BoundingBox(float x1, float y1, float x2, float y2, int category, float score) {
+ this.x1 = x1;
+ this.y1 = y1;
+ this.x2 = x2;
+ this.y2 = y2;
+ this.category = category;
+ this.score = score;
+ // -1 stands for area not initialized
+ this.area = -1;
+ }
+
+ // The intersection area of two bounding boxes
+ public float intersect(BoundingBox bbx) {
+ return Math.max(0, Math.min(x2, bbx.x2) - Math.max(x1, bbx.x1))
+ * Math.max(0, Math.min(y2, bbx.y2) - Math.max(y1, bbx.y1));
+ }
+
+ // The union area of two bounding boxes
+ public float union(BoundingBox bbx) {
+ return bbx.getArea() + this.getArea() - this.intersect(bbx);
+ }
+
+ public float getArea() {
+ if (area < 0) {
+ area = (x2 - x1) * (y2 - y1);
+ }
+ return area;
+ }
+
+ public float computeIoU(BoundingBox bbx) {
+ return (float) (this.intersect(bbx) * 1.0 / this.union(bbx));
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
index 4cda258bee..15d9511f50 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
@@ -20,11 +20,10 @@ import android.util.Log;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
/**
- * Class that benchmarks image classifier models.
+ * Base class that benchmarks image models.
*
* <p>===================== General workflow =======================
*
@@ -33,37 +32,40 @@ import java.nio.MappedByteBuffer;
* benchmarker.getReadyToTest(labelInputStream, model);
* while (!benchmarker.shouldStop()) {
* Bitmap bitmap = ...
- * benchmarker.doTestIteration(bitmap);
+ * imgId = ...
+ * benchmarker.processBitmap(bitmap, imgId);
* }
* }</pre>
*/
-public class OvicBenchmarker {
+public abstract class OvicBenchmarker {
/** Tag for the {@link Log}. */
private static final String TAG = "OvicBenchmarker";
- /** Evaluation transformation parameters. */
- private static final float CENTRAL_FRACTION = 0.875f;
-
/** Dimensions of inputs. */
- private static final int DIM_BATCH_SIZE = 1;
- private static final int DIM_PIXEL_SIZE = 3;
- private int imgHeight = 224;
- private int imgWidth = 224;
+ protected static final int DIM_BATCH_SIZE = 1;
+ protected static final int DIM_PIXEL_SIZE = 3;
+ protected int imgHeight = 224;
+ protected int imgWidth = 224;
+
+ /** Preprocess parameters (only used when input is float). */
+ protected static final float IMAGE_MEAN = 127.5f;
+ protected static final float IMAGE_STD = 127.5f;
+
+ /** Whether input is float or quantized. */
+ protected Boolean quantizedInput = null;
/* Preallocated buffers for storing image data in. */
- private int[] intValues = null;
+ protected int[] intValues = null;
/** A ByteBuffer to hold image data, to be feed into classifier as inputs. */
- private ByteBuffer imgData = null;
-
- private OvicClassifier classifier;
+ protected ByteBuffer imgData = null;
/** Total runtime in ms. */
- private double totalRuntime = 0.0;
+ protected double totalRuntime = 0.0;
/** Total allowed runtime in ms. */
- private double wallTime = 20000 * 30.0;
-
- private Boolean benchmarkStarted = null;
+ protected double wallTime = 20000 * 30.0;
+ /** Record whether benchmark has started (used to skip the first image). */
+ protected boolean benchmarkStarted = false;
/**
* Initializes an {@link OvicBenchmarker}
@@ -76,6 +78,11 @@ public class OvicBenchmarker {
this.wallTime = wallTime;
}
+ /** Return the cumulative latency of all runs so far. */
+ public double getTotalRunTime() {
+ return totalRuntime;
+ }
+
/** Check whether the benchmarker should stop. */
public Boolean shouldStop() {
if (totalRuntime >= wallTime) {
@@ -90,105 +97,62 @@ public class OvicBenchmarker {
return false;
}
- /** Check whether the benchmarker is ready to start classifying images. */
- public Boolean readyToTest() {
- return (classifier != null);
- }
+ /** Abstract class for checking whether the benchmarker is ready to start processing images */
+ public abstract boolean readyToTest();
/**
- * Getting the benchmarker ready for classifying images.
+ * Abstract class for getting the benchmarker ready.
*
* @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
* read from.
* @param model: a {@link MappedByteBuffer} model to benchmark.
*/
- public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
- try {
- Log.i(TAG, "Creating classifier.");
- classifier = new OvicClassifier(labelInputStream, model);
- int [] inputDims = classifier.getInputDims();
- imgHeight = inputDims[1];
- imgWidth = inputDims[2];
- // Only accept QUANTIZED_UINT8 input.
- imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
- imgData.order(ByteOrder.nativeOrder());
- intValues = new int[imgHeight * imgWidth];
- } catch (Exception e) {
- Log.e(TAG, e.getMessage());
- Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker.");
- }
- }
-
- /** Return how many classes are predicted per image. */
- public int getNumPredictions() {
- return classifier.getNumPredictions();
- }
+ public abstract void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model);
/**
* Perform test on a single bitmap image.
*
- * @param bitmap: a {@link Bitmap} image to classify.
+ * @param bitmap: a {@link Bitmap} image to process.
+ * @param imageId: an ID uniquely representing the image.
*/
- public OvicSingleImageResult doTestIteration(Bitmap bitmap)
- throws IOException, InterruptedException {
- if (shouldStop() || !readyToTest()) {
- return null;
- }
- OvicSingleImageResult iterResult = null;
- try {
- Log.i(TAG, "Converting bitmap.");
- convertBitmapToInput(bitmap);
- Log.i(TAG, "Classifying image.");
- iterResult = classifier.classifyByteBuffer(imgData);
- } catch (RuntimeException e) {
- Log.e(TAG, e.getMessage());
- Log.e(TAG, "Failed to classify image.");
- }
- if (iterResult == null || iterResult.latency == null) {
- throw new RuntimeException("Classification result or timing is invalid.");
- }
- Log.d(TAG, "Native inference latency: " + iterResult.latency);
- Log.i(TAG, iterResult.toString());
+ public abstract boolean processBitmap(Bitmap bitmap, int imageId)
+ throws IOException, InterruptedException;
- if (!benchmarkStarted) { // Skip the first image to discount warming-up time.
- benchmarkStarted = true;
- } else {
- totalRuntime += (double) iterResult.latency;
- }
- return iterResult;
+ /** Perform test on a single bitmap image without an image ID. */
+ public boolean processBitmap(Bitmap bitmap) throws IOException, InterruptedException {
+ return processBitmap(bitmap, /* imageId = */ 0);
}
+ /** Returns the last inference results as string. */
+ public abstract String getLastResultString();
+
/**
- * Writes Image data into a {@link ByteBuffer}.
- *
- * @param bitmap: a {@link Bitmap} source image.
- */
- private void convertBitmapToInput(Bitmap bitmap) throws RuntimeException {
- if (imgData == null) {
+ * Loads input buffer from intValues into ByteBuffer for the interpreter.
+ * Input buffer must be loaded in intValues and output will be placed in imgData.
+ */
+ protected void loadsInputToByteBuffer() {
+ if (imgData == null || intValues == null || quantizedInput == null) {
throw new RuntimeException("Benchmarker is not yet ready to test.");
}
- imgData.rewind();
- // Perform transformations corresponding to evaluation mode.
- float width = (float) bitmap.getWidth();
- float height = (float) bitmap.getHeight();
- int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2);
- int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2);
- int newWidth = Math.round(width - stWidth * 2);
- int newHeight = Math.round(height - stHeight * 2);
- bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight);
- bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
- bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
-
// Convert the image to ByteBuffer.
+ imgData.rewind();
int pixel = 0;
long startTime = SystemClock.uptimeMillis();
for (int i = 0; i < imgHeight; ++i) {
for (int j = 0; j < imgWidth; ++j) {
- final int val = intValues[pixel++];
- imgData.put((byte) ((val >> 16) & 0xFF));
- imgData.put((byte) ((val >> 8) & 0xFF));
- imgData.put((byte) (val & 0xFF));
+ final int pixelValue = intValues[pixel++];
+ if (quantizedInput) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else {
+ // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
}
}
long endTime = SystemClock.uptimeMillis();
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java
index 4af9a65c2f..5ab804e6ee 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java
@@ -1,4 +1,4 @@
-/*Copyright 2018 Google LLC
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,17 +17,17 @@ package org.tensorflow.ovic;
import java.util.ArrayList;
/** Result class for inference run on a single image. */
-public class OvicSingleImageResult {
+public class OvicClassificationResult {
/** Top K classes and probabilities. */
- public ArrayList<String> topKClasses;
- public ArrayList<Float> topKProbs;
- public ArrayList<Integer> topKIndices;
+ public final ArrayList<String> topKClasses;
+ public final ArrayList<Float> topKProbs;
+ public final ArrayList<Integer> topKIndices;
/** Latency (ms). */
public Long latency;
- OvicSingleImageResult() {
+ OvicClassificationResult() {
topKClasses = new ArrayList<>();
topKProbs = new ArrayList<>();
topKIndices = new ArrayList<>();
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index fd610b054f..d8a54c1f3b 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -31,7 +31,7 @@ import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.TestHelper;
-/** Benchmark ImageNet Classifier with Tensorflow Lite. */
+/** Class for running ImageNet classification with a TfLite model. */
public class OvicClassifier {
/** Tag for the {@link Log}. */
@@ -106,7 +106,7 @@ public class OvicClassifier {
/** Classifies a {@link ByteBuffer} image. */
// @throws RuntimeException if model is uninitialized.
- public OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) {
+ public OvicClassificationResult classifyByteBuffer(ByteBuffer imgData) {
if (tflite == null) {
throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
}
@@ -122,7 +122,7 @@ public class OvicClassifier {
labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f;
}
}
- OvicSingleImageResult iterResult = computeTopKLabels();
+ OvicClassificationResult iterResult = computeTopKLabels();
iterResult.latency = getLastNativeInferenceLatencyMilliseconds();
return iterResult;
}
@@ -174,7 +174,7 @@ public class OvicClassifier {
}
/** Computes top-K labels. */
- private OvicSingleImageResult computeTopKLabels() {
+ private OvicClassificationResult computeTopKLabels() {
if (labelList == null) {
throw new RuntimeException("Label file has not been loaded.");
}
@@ -184,7 +184,7 @@ public class OvicClassifier {
sortedLabels.poll();
}
}
- OvicSingleImageResult singleImageResult = new OvicSingleImageResult();
+ OvicClassificationResult singleImageResult = new OvicClassificationResult();
if (sortedLabels.size() != RESULTS_TO_SHOW) {
throw new RuntimeException(
"Number of returned labels does not match requirement: "
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java
new file mode 100644
index 0000000000..0cdd0f7bec
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import android.graphics.Bitmap;
+import android.util.Log;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+
+/** Class that benchmarks image classifier models. */
+public final class OvicClassifierBenchmarker extends OvicBenchmarker {
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicClassifierBenchmarker";
+
+ /** ImageNet preprocessing parameters. */
+ private static final float CENTRAL_FRACTION = 0.875f;
+ private OvicClassifier classifier;
+ private OvicClassificationResult iterResult = null;
+
+ public OvicClassifierBenchmarker(double wallTime) {
+ super(wallTime);
+ }
+
+ /** Test if the classifier is ready for benchmarking. */
+ @Override
+ public boolean readyToTest() {
+ return (classifier != null);
+ }
+
+ /**
+ * Getting the benchmarker ready for classifying images.
+ *
+ * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
+ * read from.
+ * @param model: a {@link MappedByteBuffer} model to benchmark.
+ */
+ @Override
+ public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
+ try {
+ Log.i(TAG, "Creating classifier.");
+ classifier = new OvicClassifier(labelInputStream, model);
+ int [] inputDims = classifier.getInputDims();
+ imgHeight = inputDims[1];
+ imgWidth = inputDims[2];
+ quantizedInput = true;
+ // Only accept QUANTIZED_UINT8 input.
+ imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
+ imgData.order(ByteOrder.nativeOrder());
+ intValues = new int[imgHeight * imgWidth];
+ } catch (Exception e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker.");
+ }
+ }
+
+ /**
+ * Perform classification on a single bitmap image.
+ *
+ * @param bitmap: a {@link Bitmap} image to process.
+ * @param imageId: an ID uniquely representing the image.
+ */
+ @Override
+ public boolean processBitmap(Bitmap bitmap, int imageId)
+ throws IOException, InterruptedException {
+ if (shouldStop() || !readyToTest()) {
+ return false;
+ }
+ try {
+ Log.i(TAG, "Converting bitmap.");
+ convertBitmapToInput(bitmap);
+ Log.i(TAG, "Classifying image: " + imageId);
+ iterResult = classifier.classifyByteBuffer(imgData);
+ } catch (RuntimeException e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to classify image.");
+ }
+ if (iterResult == null || iterResult.latency == null) {
+ throw new RuntimeException("Classification result or timing is invalid.");
+ }
+ Log.d(TAG, "Native inference latency: " + iterResult.latency);
+ Log.i(TAG, iterResult.toString());
+
+ if (!benchmarkStarted) { // Skip the first image to discount warming-up time.
+ benchmarkStarted = true;
+ } else {
+ totalRuntime += ((double) iterResult.latency);
+ }
+ return true;
+ }
+
+ /** Return how many classes are predicted per image. */
+ public int getNumPredictions() {
+ return classifier.getNumPredictions();
+ }
+
+ public OvicClassificationResult getLastClassificationResult() {
+ return iterResult;
+ }
+
+ @Override
+ public String getLastResultString() {
+ if (iterResult == null) {
+ return null;
+ } else {
+ return iterResult.toString();
+ }
+ }
+
+ /**
+ * Preprocess bitmap according to ImageNet protocol then writes result into a {@link ByteBuffer}.
+ *
+ * @param bitmap: a {@link Bitmap} source image.
+ */
+ private void convertBitmapToInput(Bitmap bitmap) {
+ // Perform transformations corresponding to evaluation mode.
+ float width = (float) bitmap.getWidth();
+ float height = (float) bitmap.getHeight();
+ int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2);
+ int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2);
+ int newWidth = Math.round(width - stWidth * 2);
+ int newHeight = Math.round(height - stHeight * 2);
+ bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight);
+ bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ loadsInputToByteBuffer();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java
new file mode 100644
index 0000000000..cf2902a5cb
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import java.util.ArrayList;
+
+/** Result class for inference run on a single image. */
+public class OvicDetectionResult {
+
+ // Top K classes and probabilities.
+ public final ArrayList<BoundingBox> detections;
+ // Latency (ms).
+ public Long latency = -1L;
+ // id of the image.
+ public int id = -1;
+ // Number of valid detections (separately maintained, maybe different from detections.size()).
+ public int count = 0;
+
+ // Create OvicDetectionResult object with pre-filled capacity. Note that detections.size() will
+ // be equal to capacity after this call.
+ OvicDetectionResult(int capacity) {
+ detections = new ArrayList<BoundingBox>(capacity);
+ for (int i = 0; i < capacity; i++) {
+ detections.add(new BoundingBox(-1.0f, -1.0f, -1.0f, -1.0f, -1, -1.0f));
+ }
+ }
+
+ public void resetTo(Long latency, int id) {
+ count = 0;
+ this.latency = latency;
+ this.id = id;
+ }
+
+ public void addBox(float x1, float y1, float x2, float y2, int category, float score) {
+ detections.get(count).x1 = x1;
+ detections.get(count).y1 = y1;
+ detections.get(count).x2 = x2;
+ detections.get(count).y2 = y2;
+ detections.get(count).category = category;
+ detections.get(count).score = score;
+ count += 1;
+ }
+
+ public void scaleUp(double scaleFactorWidth, double scaleFactorHeight) {
+ for (BoundingBox box : detections) {
+ box.x1 = (float) (box.x1 * scaleFactorWidth);
+ box.y1 = (float) (box.y1 * scaleFactorHeight);
+ box.x2 = (float) (box.x2 * scaleFactorWidth);
+ box.y2 = (float) (box.y2 * scaleFactorHeight);
+ }
+ }
+
+ @Override
+ public String toString() {
+ String textToShow = latency + "ms";
+ int k = 0;
+ for (BoundingBox box : detections) {
+ textToShow +=
+ "\nPrediction ["
+ + k
+ + "] = Class "
+ + box.category
+ + " ("
+ + box.x1
+ + ", "
+ + box.y1
+ + ", "
+ + box.x2
+ + ", "
+ + box.y2
+ + ") : "
+ + box.score;
+ k++;
+ }
+
+
+ return textToShow;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java
new file mode 100644
index 0000000000..56836a79e5
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.TestHelper;
+
+/** Class for running COCO detection with a TfLite model. */
+public class OvicDetector implements AutoCloseable {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicDetector";
+
+ /** An instance of the driver class to run model inference with Tensorflow Lite. */
+ private Interpreter tflite;
+
+ /** Labels corresponding to the output of the vision model. */
+ private final List<String> labelList;
+
+ /** Define the output format. */
+ private final Boolean inputIsFloat;
+
+ /** Number of detections per image. 10 for demo, 100 for the actual competition. */
+ private static final int NUM_RESULTS = 10;
+
+ /** The output arrays for the mobilenet SSD. */
+ private float[][][] outputLocations;
+ private float[][] outputClasses;
+ private float[][] outputScores;
+ private float[] numDetections;
+ private Map<Integer, Object> outputMap;
+
+ /** Input resolution. */
+ private final int[] inputDims;
+
+ /** Final result. */
+ public OvicDetectionResult result = null;
+
+ OvicDetector(InputStream labelInputStream, MappedByteBuffer model) throws IOException {
+ // Load the label list.
+ labelList = loadLabelList(labelInputStream);
+
+ // Create the TfLite interpreter.
+ tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
+ inputDims = TestHelper.getInputDims(tflite, 0);
+ inputIsFloat = TestHelper.getInputDataType(tflite, 0).equals("float");
+ if (inputDims.length != 4) {
+ throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
+ }
+ if (inputDims[0] != 1) {
+ throw new RuntimeException(
+ "The model must have a batch size of 1, got " + inputDims[0] + " instead.");
+ }
+ if (inputDims[3] != 3) {
+ throw new RuntimeException(
+ "The model must have three color channels, got " + inputDims[3] + " instead.");
+ }
+ // Check the resolution.
+ int minSide = Math.min(inputDims[1], inputDims[2]);
+ int maxSide = Math.max(inputDims[1], inputDims[2]);
+ if (minSide <= 0 || maxSide > 1000) {
+ throw new RuntimeException("The model's resolution must be between (0, 1000].");
+ }
+
+ // Initialize the input array and result arrays. The input images are stored in a list of
+    // Object. Since this function analyzes one image at a time, there is only 1 item.
+    // The output is formulated as a map of int -> Object. The output arrays are added to the map.
+ outputLocations = new float[1][NUM_RESULTS][4];
+ outputClasses = new float[1][NUM_RESULTS];
+ outputScores = new float[1][NUM_RESULTS];
+ numDetections = new float[1];
+ outputMap = new HashMap<>();
+ outputMap.put(0, outputLocations);
+ outputMap.put(1, outputClasses);
+ outputMap.put(2, outputScores);
+ outputMap.put(3, numDetections);
+ // Preallocate the result. This will be where inference result is stored after each
+ // detectByteBuffer call.
+ result = new OvicDetectionResult(NUM_RESULTS);
+ }
+
+ public Boolean quantizedInput() {
+ return !inputIsFloat;
+ }
+
+ /** Reads label list from Assets. */
+ private static List<String> loadLabelList(InputStream labelInputStream) throws IOException {
+ List<String> labelList = new ArrayList<>();
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(labelInputStream, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ labelList.add(line);
+ }
+ }
+ return labelList;
+ }
+
+ /**
+   * The interface to run the detection. This method currently only supports the float mobilenet_ssd
+ * model. The quantized models will be added in the future.
+ *
+ * @param imgData The image buffer in ByteBuffer format.
+ * @return boolean indicator of whether detection was a success. If success, the detection results
+ * is available in the result member variable.
+ * See OvicDetectionResult.java for details.
+ */
+ boolean detectByteBuffer(ByteBuffer imgData, int imageId) {
+ if (tflite == null) {
+ throw new RuntimeException(TAG + ": Detector has not been initialized; Failed.");
+ }
+ if (inputIsFloat == null) {
+ throw new RuntimeException(TAG + ": Detector input type has not been resolved.");
+ }
+
+ Object[] inputArray = {imgData};
+ tflite.runForMultipleInputsOutputs(inputArray, outputMap);
+
+ Long latency = getLastNativeInferenceLatencyMilliseconds();
+
+ // Update the results.
+ result.resetTo(latency, imageId);
+ for (int i = 0; i < NUM_RESULTS; i++) {
+ result.addBox(outputLocations[0][i][1] * inputDims[1],
+ outputLocations[0][i][0] * inputDims[1],
+ outputLocations[0][i][3] * inputDims[2],
+ outputLocations[0][i][2] * inputDims[2],
+ Math.round(outputClasses[0][i] + 1 /* Label offset */),
+ outputScores[0][i]);
+ }
+ return true; // Marks that the result is available.
+ }
+
+ /*
+ * Get native inference latency of last image detection run.
+ * @throws RuntimeException if model is uninitialized.
+ * @return The inference latency in millisecond.
+ */
+ public Long getLastNativeInferenceLatencyMilliseconds() {
+ if (tflite == null) {
+ throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
+ }
+ Long latency = tflite.getLastNativeInferenceDurationNanoseconds();
+ return (latency == null) ? null : (Long) (latency / 1000000);
+ }
+
+ public int[] getInputDims() {
+ return inputDims;
+ }
+
+ public List<String> getLabels() {
+ return labelList;
+ }
+
+ /** Closes tflite to release resources. */
+ @Override
+ public void close() {
+ tflite.close();
+ tflite = null;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java
new file mode 100644
index 0000000000..1a4e193ff2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import android.graphics.Bitmap;
+import android.util.Log;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+
+/**
+ * Class that benchmarks object detection models.
+ */
+public final class OvicDetectorBenchmarker extends OvicBenchmarker {
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicDetectorBenchmarker";
+
+ public double scaleFactorWidth = 1.0f;
+ public double scaleFactorHeight = 1.0f;
+ private Bitmap scaledBitmap = null; // Preallocate bitmap for scaling.
+
+ private OvicDetector detector;
+
+ /**
+   * Initializes an {@link OvicDetectorBenchmarker}.
+ *
+ * @param wallTime: a double number specifying the total amount of time to benchmark.
+ */
+ public OvicDetectorBenchmarker(double wallTime) {
+ super(wallTime);
+ }
+
+ /** Check to see if the detector is ready to test. */
+ @Override
+ public boolean readyToTest() {
+ return (detector != null);
+ }
+
+ /**
+ * Getting the benchmarker ready for detecting images.
+ *
+ * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
+ * read from.
+ * @param model: a {@link MappedByteBuffer} model to benchmark.
+ */
+ @Override
+ public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
+ try {
+ Log.i(TAG, "Creating detector.");
+ detector = new OvicDetector(labelInputStream, model);
+ quantizedInput = detector.quantizedInput();
+ int[] inputDims = detector.getInputDims();
+ imgHeight = inputDims[1];
+ imgWidth = inputDims[2];
+ if (quantizedInput) {
+ imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
+ } else {
+ imgData =
+ ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE * 4);
+ }
+ imgData.order(ByteOrder.nativeOrder());
+ intValues = new int[imgHeight * imgWidth];
+ benchmarkStarted = false;
+ } catch (Exception e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to initialize COCO detector for the benchmarker.", e);
+ }
+ }
+
+ /**
+   * Perform detection on a single {@link ByteBuffer} image. The image must have the
+   * same dimensions that the model expects.
+ *
+ * @param image: a {@link ByteBuffer} image to process.
+ * @param imageId: an ID uniquely representing the image.
+ */
+ public boolean processBuffer(ByteBuffer image, int imageId) {
+ if (!readyToTest()) {
+ return false;
+ }
+ try {
+ if (!detector.detectByteBuffer(image, imageId)) {
+ return false;
+ }
+ } catch (RuntimeException e) {
+ Log.e(TAG, e.getMessage());
+ return false;
+ }
+
+ if (!benchmarkStarted) { // Skip the first image to discount warming-up time.
+ benchmarkStarted = true;
+ } else {
+ totalRuntime += ((double) detector.result.latency);
+ }
+ return true; // Indicating that result is ready.
+ }
+
+ /**
+ * Perform detection on a single bitmap image.
+ *
+ * @param bitmap: a {@link Bitmap} image to process.
+ * @param imageId: an ID uniquely representing the image.
+ */
+ @Override
+ public boolean processBitmap(Bitmap bitmap, int imageId)
+ throws IOException, InterruptedException {
+ if (shouldStop() || !readyToTest()) {
+ return false;
+ }
+ convertBitmapToInput(bitmap); // Scale bitmap if needed, store result in imgData.
+ if (!processBuffer(imgData, imageId)) {
+ return false;
+ }
+ // Scale results back to original image coordinates.
+ detector.result.scaleUp(scaleFactorWidth, scaleFactorHeight);
+ return true; // Indicating that result is ready.
+ }
+
+ public OvicDetectionResult getLastDetectionResult() {
+ return detector.result;
+ }
+
+ @Override
+ public String getLastResultString() {
+ if (detector.result == null) {
+ return null;
+ }
+ return detector.result.toString();
+ }
+
+ /**
+ * Preprocess bitmap image into {@link ByteBuffer} format for the detector.
+ *
+ * @param bitmap: a {@link Bitmap} source image.
+ */
+ private void convertBitmapToInput(Bitmap bitmap) {
+ int originalWidth = bitmap.getWidth();
+ int originalHeight = bitmap.getHeight();
+ scaledBitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
+ scaleFactorWidth = originalWidth * 1.0 / imgWidth;
+ scaleFactorHeight = originalHeight * 1.0 / imgHeight;
+ scaledBitmap.getPixels(intValues, 0, imgWidth, 0, 0, imgWidth, imgHeight);
+ scaledBitmap.recycle();
+ loadsInputToByteBuffer();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java
index a504ec74a9..baa14baf92 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java
@@ -51,7 +51,7 @@ public class OvicValidator {
MappedByteBuffer model = loadModelFile(modelFile);
OvicClassifier classifier = new OvicClassifier(labelsInputStream, model);
ByteBuffer imgData = createByteBufferForClassifier(classifier);
- OvicSingleImageResult testResult = classifier.classifyByteBuffer(imgData);
+ OvicClassificationResult testResult = classifier.classifyByteBuffer(imgData);
if (testResult.topKClasses.isEmpty()) {
throw new RuntimeException("Failed to return top K predictions.");
}
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 1587c3c56f..99e874ca78 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -1,4 +1,4 @@
-/*Copyright 2018 Google LLC
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ public final class OvicClassifierTest {
private MappedByteBuffer lowResModel = null;
private ByteBuffer testImage = null;
private ByteBuffer lowResTestImage = null;
- private OvicSingleImageResult testResult = null;
+ private OvicClassificationResult testResult = null;
private static final String LABELS_PATH =
"tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt";
private static final String QUANTIZED_MODEL_PATH =
@@ -147,7 +147,7 @@ public final class OvicClassifierTest {
return imgData;
}
- private static void assertCorrectTopK(OvicSingleImageResult testResult) {
+ private static void assertCorrectTopK(OvicClassificationResult testResult) {
assertThat(testResult.topKClasses.size() > 0).isTrue();
Boolean topKAccurate = false;
// Assert that the correct class is in the top K.
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java
new file mode 100644
index 0000000000..4681e26052
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.awt.Graphics2D;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import javax.imageio.ImageIO;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit test for {@link org.tensorflow.ovic.OvicDetector}. */
+@RunWith(JUnit4.class)
+public final class OvicDetectorTest {
+ private OvicDetector detector = null;
+ private InputStream labelsInputStream = null;
+ private MappedByteBuffer model = null;
+ private ByteBuffer testImage = null;
+
+ private static final float IMAGE_MEAN = 128f;
+ private static final float IMAGE_STD = 128f;
+
+ private Boolean quantizedInput = null;
+ private static final String LABELS_PATH =
+ "tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt";
+ private static final String MODEL_PATH =
+ "external/tflite_mobilenet_ssd_quant/detect.tflite";
+ private static final String TEST_IMAGE_PATH =
+ "external/tflite_ovic_testdata/test_image_224.jpg";
+ private static final int GROUNDTRUTH = 1 /* Person */;
+
+ @Before
+ public void setUp() {
+ try {
+ // load models.
+ model = loadModelFile(MODEL_PATH);
+
+ // Load label files;
+ File labelsfile = new File(LABELS_PATH);
+ labelsInputStream = new FileInputStream(labelsfile);
+
+ // Create detector.
+ detector = new OvicDetector(labelsInputStream, model);
+ quantizedInput = detector.quantizedInput();
+
+ // Load test image and convert into byte buffer.
+ File imageFile = new File(TEST_IMAGE_PATH);
+ BufferedImage rawimg = ImageIO.read(imageFile);
+ int[] inputDims = detector.getInputDims();
+ BufferedImage img = new BufferedImage(inputDims[1], inputDims[2], rawimg.getType());
+ Graphics2D g = img.createGraphics();
+ g.drawImage(rawimg, 0, 0, inputDims[1], inputDims[2], null);
+ g.dispose();
+ testImage = toByteBuffer(img);
+ } catch (IOException e) {
+ System.out.println(e.getMessage());
+ }
+
+ System.out.println("Successfully setup");
+ }
+
+ private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException {
+ File modelfile = new File(modelFilePath);
+ FileInputStream inputStream = new FileInputStream(modelfile);
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = 0L;
+ long declaredLength = fileChannel.size();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+
+ private ByteBuffer toByteBuffer(BufferedImage image) {
+ ByteBuffer imgData;
+ if (quantizedInput) {
+ imgData = ByteBuffer.allocateDirect(image.getHeight() * image.getWidth() * 3);
+ } else {
+ imgData = ByteBuffer.allocateDirect(image.getHeight() * image.getWidth() * 12);
+ }
+ imgData.order(ByteOrder.nativeOrder());
+ for (int y = 0; y < image.getHeight(); y++) {
+ for (int x = 0; x < image.getWidth(); x++) {
+ int pixelValue = image.getRGB(x, y);
+ if (quantizedInput) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else {
+ // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
+ }
+ }
+ return imgData;
+ }
+
+ @Test
+ public void ovicDetector_detectSuccess() throws Exception {
+ assertThat(detector.detectByteBuffer(testImage, 1)).isTrue();
+ assertThat(detector.result != null).isTrue();
+ }
+
+ @Test
+ public void ovicDetector_simpleBatchTest() throws Exception {
+ final int numRepeats = 5;
+ for (int i = 0; i < numRepeats; i++) {
+ assertThat(detector.detectByteBuffer(testImage, 1)).isTrue();
+ OvicDetectionResult result = detector.result;
+ Boolean detectWithinTop5 = false;
+ for (int j = 0; j < Math.min(5, result.count); j++) {
+ if (result.detections.get(j).category == GROUNDTRUTH) {
+ detectWithinTop5 = true;
+ break;
+ }
+ }
+ if (!detectWithinTop5) {
+ System.out.println("---------------- Image " + i + " ---------------------");
+ System.out.println("Expect category " + GROUNDTRUTH);
+ System.out.println("Detection results: ");
+ System.out.println(result.toString());
+ }
+ assertThat(detectWithinTop5).isTrue();
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
index 1021ea30dd..051aa2204e 100644
--- a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
@@ -14,6 +14,9 @@ filegroup(
)
exports_files(
- ["labels.txt"],
+ [
+ "labels.txt",
+ "coco_labels.txt",
+ ],
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt b/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt
new file mode 100644
index 0000000000..d91f535b1a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt
@@ -0,0 +1,91 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+empty
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+empty
+backpack
+umbrella
+empty
+empty
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+empty
+wine glasses
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+empty
+dining table
+empty
+empty
+toilet
+empty
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+empty
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
+empty
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 9bc44bf797..6f03e7853a 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -18,7 +18,6 @@ package org.tensorflow.lite;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
-import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -83,6 +82,19 @@ final class NativeInterpreterWrapper implements AutoCloseable {
/** Releases resources associated with this {@code NativeInterpreterWrapper}. */
@Override
public void close() {
+ // Close the tensors first as they may reference the native interpreter.
+ for (int i = 0; i < inputTensors.length; ++i) {
+ if (inputTensors[i] != null) {
+ inputTensors[i].close();
+ inputTensors[i] = null;
+ }
+ }
+ for (int i = 0; i < outputTensors.length; ++i) {
+ if (outputTensors[i] != null) {
+ outputTensors[i].close();
+ outputTensors[i] = null;
+ }
+ }
delete(errorHandle, modelHandle, interpreterHandle);
errorHandle = 0;
modelHandle = 0;
@@ -91,8 +103,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
inputsIndexes = null;
outputsIndexes = null;
isMemoryAllocated = false;
- Arrays.fill(inputTensors, null);
- Arrays.fill(outputTensors, null);
}
/** Sets inputs, runs model inference and returns outputs. */
@@ -260,7 +270,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
Tensor inputTensor = inputTensors[index];
if (inputTensor == null) {
inputTensor =
- inputTensors[index] = Tensor.fromHandle(getInputTensor(interpreterHandle, index));
+ inputTensors[index] =
+ Tensor.fromIndex(interpreterHandle, getInputTensorIndex(interpreterHandle, index));
}
return inputTensor;
}
@@ -282,7 +293,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
Tensor outputTensor = outputTensors[index];
if (outputTensor == null) {
outputTensor =
- outputTensors[index] = Tensor.fromHandle(getOutputTensor(interpreterHandle, index));
+ outputTensors[index] =
+ Tensor.fromIndex(interpreterHandle, getOutputTensorIndex(interpreterHandle, index));
}
return outputTensor;
}
@@ -317,9 +329,9 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native long allocateTensors(long interpreterHandle, long errorHandle);
- private static native long getInputTensor(long interpreterHandle, int inputIdx);
+ private static native int getInputTensorIndex(long interpreterHandle, int inputIdx);
- private static native long getOutputTensor(long interpreterHandle, int outputIdx);
+ private static native int getOutputTensorIndex(long interpreterHandle, int outputIdx);
private static native int getInputCount(long interpreterHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index f174178d98..6ca47aa3ed 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -23,13 +23,26 @@ import java.util.Arrays;
/**
* A typed multi-dimensional array used in Tensorflow Lite.
*
- * <p>The native handle of a {@code Tensor} belongs to {@code NativeInterpreterWrapper}, thus not
- * needed to be closed here.
+ * <p>The native handle of a {@code Tensor} is managed by {@code NativeInterpreterWrapper}, and does
+ * not need to be closed by the client. However, once the {@code NativeInterpreterWrapper} has
+ * been closed, the tensor handle will be invalidated.
*/
public final class Tensor {
- static Tensor fromHandle(long nativeHandle) {
- return new Tensor(nativeHandle);
+ /**
+ * Creates a Tensor wrapper from the provided interpreter instance and tensor index.
+ *
+ * <p>The caller is responsible for closing the created wrapper, and ensuring the provided
+ * native interpreter is valid until the tensor is closed.
+ */
+ static Tensor fromIndex(long nativeInterpreterHandle, int tensorIndex) {
+ return new Tensor(create(nativeInterpreterHandle, tensorIndex));
+ }
+
+ /** Disposes of any resources used by the Tensor wrapper. */
+ void close() {
+ delete(nativeHandle);
+ nativeHandle = 0;
}
/** Returns the {@link DataType} of elements stored in the Tensor. */
@@ -235,7 +248,7 @@ public final class Tensor {
return o instanceof ByteBuffer;
}
- private final long nativeHandle;
+ private long nativeHandle;
private final DataType dtype;
private int[] shapeCopy;
@@ -249,6 +262,10 @@ public final class Tensor {
return buffer(nativeHandle).order(ByteOrder.nativeOrder());
}
+ private static native long create(long interpreterHandle, int tensorIndex);
+
+ private static native void delete(long handle);
+
private static native ByteBuffer buffer(long handle);
private static native void writeDirectBuffer(long handle, ByteBuffer src);
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
index 711638a9f9..d5447b3bf8 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -18,7 +18,8 @@ package org.tensorflow.lite;
/** Static utility methods loading the TensorFlowLite runtime. */
public final class TensorFlowLite {
- private static final String LIBNAME = "tensorflowlite_jni";
+ private static final String PRIMARY_LIBNAME = "tensorflowlite_jni";
+ private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni";
private TensorFlowLite() {}
@@ -29,13 +30,24 @@ public final class TensorFlowLite {
* Load the TensorFlowLite runtime C library.
*/
static boolean init() {
+ Throwable primaryLibException;
try {
- System.loadLibrary(LIBNAME);
+ System.loadLibrary(PRIMARY_LIBNAME);
return true;
} catch (UnsatisfiedLinkError e) {
- System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
- return false;
+ primaryLibException = e;
}
+
+ try {
+ System.loadLibrary(FALLBACK_LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ // If the fallback fails, log the error for the primary load instead.
+ System.err.println(
+ "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage());
+ }
+
+ return false;
}
static {
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index abb7320bc5..4dc73fbcf8 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -159,26 +159,20 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
}
}
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index) {
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_index) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
- return reinterpret_cast<jlong>(
- interpreter->tensor(interpreter->inputs()[index]));
+ return interpreter->inputs()[input_index];
}
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index) {
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_index) {
tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
if (interpreter == nullptr) return 0;
- return reinterpret_cast<jlong>(
- interpreter->tensor(interpreter->outputs()[index]));
+ return interpreter->outputs()[output_index];
}
JNIEXPORT jint JNICALL
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index aa809dff8a..f8f3e7028c 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -46,25 +46,21 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_allocateTensors(
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method: getInputTensor
- * Signature: (JI)J
+ * Method: getInputTensorIndex
+ * Signature: (JI)I
*/
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index);
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getInputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint input_index);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
- * Method: getOutputTensor
- * Signature: (JI)J
+ * Method: getOutputTensorIndex
+ * Signature: (JI)I
*/
-JNIEXPORT jlong JNICALL
-Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensor(JNIEnv* env,
- jclass clazz,
- jlong handle,
- jint index);
+JNIEXPORT jint JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputTensorIndex(
+ JNIEnv* env, jclass clazz, jlong handle, jint output_index);
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
index 7ff96a3172..d3378f5f14 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.cc
@@ -16,17 +16,36 @@ limitations under the License.
#include "tensorflow/contrib/lite/java/src/main/native/tensor_jni.h"
#include <cstring>
#include <memory>
+#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/java/src/main/native/exception_jni.h"
namespace {
-TfLiteTensor* convertLongToTensor(JNIEnv* env, jlong handle) {
+// Convenience handle for obtaining a TfLiteTensor given an interpreter and
+// tensor index.
+//
+// Historically, the Java Tensor class used a TfLiteTensor pointer as its native
+// handle. However, this approach isn't generally safe, as the interpreter may
+// invalidate all TfLiteTensor* handles during inference or allocation.
+class TensorHandle {
+ public:
+ TensorHandle(tflite::Interpreter* interpreter, int tensor_index)
+ : interpreter_(interpreter), tensor_index_(tensor_index) {}
+
+ TfLiteTensor* tensor() const { return interpreter_->tensor(tensor_index_); }
+
+ private:
+ tflite::Interpreter* const interpreter_;
+ const int tensor_index_;
+};
+
+TfLiteTensor* GetTensorFromHandle(JNIEnv* env, jlong handle) {
if (handle == 0) {
throwException(env, kIllegalArgumentException,
"Internal error: Invalid handle to TfLiteTensor.");
return nullptr;
}
- return reinterpret_cast<TfLiteTensor*>(handle);
+ return reinterpret_cast<TensorHandle*>(handle)->tensor();
}
size_t elementByteSize(TfLiteType data_type) {
@@ -192,10 +211,23 @@ size_t writeMultiDimensionalArray(JNIEnv* env, jobject src, TfLiteType type,
} // namespace
+JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index) {
+ tflite::Interpreter* interpreter =
+ reinterpret_cast<tflite::Interpreter*>(interpreter_handle);
+ return reinterpret_cast<jlong>(new TensorHandle(interpreter, tensor_index));
+}
+
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_delete(JNIEnv* env,
+ jclass clazz,
+ jlong handle) {
+ delete reinterpret_cast<TensorHandle*>(handle);
+}
+
JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
jclass clazz,
jlong handle) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return nullptr;
if (tensor->data.raw == nullptr) {
throwException(env, kIllegalArgumentException,
@@ -208,7 +240,7 @@ JNIEXPORT jobject JNICALL Java_org_tensorflow_lite_Tensor_buffer(JNIEnv* env,
JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_writeDirectBuffer(
JNIEnv* env, jclass clazz, jlong handle, jobject src) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
char* src_data_raw = static_cast<char*>(env->GetDirectBufferAddress(src));
@@ -226,7 +258,7 @@ Java_org_tensorflow_lite_Tensor_readMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
jobject value) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
int num_dims = tensor->dims->size;
if (num_dims == 0) {
@@ -243,7 +275,7 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
jclass clazz,
jlong handle,
jobject src) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return;
if (tensor->data.raw == nullptr) {
throwException(env, kIllegalArgumentException,
@@ -262,14 +294,14 @@ Java_org_tensorflow_lite_Tensor_writeMultiDimensionalArray(JNIEnv* env,
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_dtype(JNIEnv* env,
jclass clazz,
jlong handle) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return 0;
return static_cast<jint>(tensor->type);
}
JNIEXPORT jintArray JNICALL
Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
- TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return nullptr;
int num_dims = tensor->dims->size;
jintArray result = env->NewIntArray(num_dims);
@@ -280,7 +312,7 @@ Java_org_tensorflow_lite_Tensor_shape(JNIEnv* env, jclass clazz, jlong handle) {
JNIEXPORT jint JNICALL Java_org_tensorflow_lite_Tensor_numBytes(JNIEnv* env,
jclass clazz,
jlong handle) {
- const TfLiteTensor* tensor = convertLongToTensor(env, handle);
+ const TfLiteTensor* tensor = GetTensorFromHandle(env, handle);
if (tensor == nullptr) return 0;
return static_cast<jint>(tensor->bytes);
}
diff --git a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
index 2f73128bdf..c5e9690e9a 100644
--- a/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/tensor_jni.h
@@ -25,6 +25,23 @@ extern "C" {
/*
* Class: org_tensorflow_lite_Tensor
+ * Method: create
+ * Signature: (JI)J
+ */
+JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_Tensor_create(
+ JNIEnv* env, jclass clazz, jlong interpreter_handle, jint tensor_index);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
+ * Method: delete
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_tensorflow_lite_Tensor_delete(JNIEnv* env,
+ jclass clazz,
+ jlong handle);
+
+/*
+ * Class: org_tensorflow_lite_Tensor
* Method: buffer
* Signature: (J)Ljava/nio/ByteBuffer;
*/
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
new file mode 100644
index 0000000000..2791c3864b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.io.File;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
+ * have TensorFlow ops.
+ */
+@RunWith(JUnit4.class)
+public final class InterpreterFlexTest {
+
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
+ /** Smoke test validating that flex model loading works when the flex delegate is linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) {
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.run(new float[1], new float[1]);
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index a98fca0132..f8b73c7cf3 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -43,6 +43,9 @@ public final class InterpreterTest {
private static final File MOBILENET_MODEL_FILE =
new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
@Test
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
@@ -345,4 +348,15 @@ public final class InterpreterTest {
interpreter.close();
interpreter.close();
}
+
+ /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try {
+ new Interpreter(FLEX_MODEL_FILE);
+ fail();
+ } catch (IllegalStateException e) {
+ // Expected failure.
+ }
+ }
}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index 85ad393d89..56a38ea3e2 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -182,7 +182,7 @@ public final class TensorTest {
dataType = Tensor.dataTypeOf(testFloatArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = Tensor.dataTypeOf(testFloatArray);
+ dataType = Tensor.dataTypeOf(testMultiDimArray);
assertThat(dataType).isEqualTo(DataType.FLOAT32);
try {
double[] testDoubleArray = {0.783, 0.251};
@@ -238,4 +238,15 @@ public final class TensorTest {
assertThat(shape[1]).isEqualTo(3);
assertThat(shape[2]).isEqualTo(1);
}
+
+ @Test
+ public void testUseAfterClose() {
+ tensor.close();
+ try {
+ tensor.numBytes();
+ fail();
+ } catch (IllegalArgumentException e) {
+ // Expected failure.
+ }
+ }
}
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index daaf6714cc..d2d8073abd 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -210,6 +210,7 @@ cc_library(
"slice.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
+ "sparse_output_fully_connected.cc",
"sparse_to_dense.cc",
"split.cc",
"squeeze.cc",
@@ -233,11 +234,11 @@ cc_library(
":activation_functor",
":eigen_support",
":kernel_util",
+ ":lstm_eval",
":op_macros",
":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite:util",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
@@ -254,6 +255,18 @@ cc_library(
)
cc_library(
+ name = "lstm_eval",
+ srcs = ["lstm_eval.cc"],
+ hdrs = ["lstm_eval.h"],
+ deps = [
+ ":op_macros",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ ],
+)
+
+cc_library(
name = "builtin_ops",
srcs = ["register.cc"],
hdrs = ["register.h"],
@@ -334,6 +347,23 @@ tf_cc_test(
)
tf_cc_test(
+ name = "sparse_output_fully_connected_test",
+ size = "small",
+ srcs = ["sparse_output_fully_connected_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc
index cf9441aee3..9aed4f09b8 100644
--- a/tensorflow/contrib/lite/kernels/activations.cc
+++ b/tensorflow/contrib/lite/kernels/activations.cc
@@ -616,13 +616,15 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
switch (input->type) {
- case kTfLiteFloat32:
+ case kTfLiteFloat32: {
SoftmaxParams op_params;
optimized_ops::LogSoftmax(
op_params, GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));
return kTfLiteOk;
- case kTfLiteUInt8:
+ }
+ case kTfLiteUInt8: {
+ SoftmaxParams op_params;
op_params.input_multiplier = data->input_multiplier;
op_params.input_left_shift = data->input_left_shift;
op_params.reverse_scaling_divisor = data->reverse_scaling_divisor;
@@ -632,6 +634,7 @@ TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
GetTensorShape(output), GetTensorData<uint8_t>(output));
return kTfLiteOk;
+ }
default:
context->ReportError(context, "Only float32 supported currently., got %d",
input->type);
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 66b947771c..a326827b1e 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -119,7 +120,7 @@ constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set.
// Temporary tensors.
enum TemporaryTensor {
@@ -162,7 +163,8 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -347,10 +349,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -368,6 +373,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
n_input);
+ const TfLiteTensor* bw_input_to_output_weights =
+ GetInput(context, node, kBwInputToOutputWeightsTensor);
+ const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
+ n_input);
+
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
@@ -375,6 +387,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_fw_cell);
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
+ const TfLiteTensor* bw_recurrent_to_output_weights =
+ GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
+ n_bw_cell);
+ const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
+
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
@@ -440,7 +459,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
- fw_output_size->data[2] = n_fw_output;
+ fw_output_size->data[2] =
+ params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
@@ -479,39 +499,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
fw_scratch_buffer_size));
// Same for the backward cell.
- const TfLiteTensor* bw_input_to_output_weights =
- GetInput(context, node, kBwInputToOutputWeightsTensor);
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
- n_input);
-
- const TfLiteTensor* bw_recurrent_to_output_weights =
- GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
- n_bw_cell);
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
n_bw_cell));
- // Get the pointer to output, activation_state and cell_state buffer tensors.
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ // Get the pointer to activation_state and cell_state buffer tensors.
TfLiteTensor* bw_activation_state =
GetVariableInput(context, node, kBwInputActivationStateTensor);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
// Resize the output tensors.
- TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
- bw_output_size->data[0] = max_time;
- bw_output_size->data[1] = n_batch;
- bw_output_size->data[2] = n_bw_output;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, bw_output, bw_output_size));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
+ bw_output_size->data[0] = max_time;
+ bw_output_size->data[1] = n_batch;
+ bw_output_size->data[2] = n_bw_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_output, bw_output_size));
+ }
// Check the shape of input state tensors.
// These tensor may be 1D or 2D. It's fine as long as the total size is
@@ -686,332 +695,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
- const TfLiteTensor* aux_input_to_input_weights,
- const TfLiteTensor* aux_input_to_forget_weights,
- const TfLiteTensor* aux_input_to_cell_weights,
- const TfLiteTensor* aux_input_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
- TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- float* aux_input_ptr = nullptr;
- float* aux_input_to_input_weights_ptr = nullptr;
- float* aux_input_to_forget_weights_ptr = nullptr;
- float* aux_input_to_cell_weights_ptr = nullptr;
- float* aux_input_to_output_weights_ptr = nullptr;
- if (aux_input_size > 0) {
- aux_input_ptr = aux_input->data.f;
- aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
- aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
- aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
- aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
- }
-
- // Loop through the sequence.
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * n_output;
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr_time = output->data.f + t_rel * output_step;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
- input_to_cell_weights->data.f, input_to_output_weights->data.f,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
- aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
- const TfLiteTensor* aux_input_to_input_weights,
- const TfLiteTensor* aux_input_to_forget_weights,
- const TfLiteTensor* aux_input_to_cell_weights,
- const TfLiteTensor* aux_input_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
- TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
- TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* output_state_ptr = output_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_aux_input_ptr =
- (aux_input_quantized == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Auxiliary input and weights.
- float* aux_input_ptr = nullptr;
- int8_t* aux_input_to_input_weights_ptr = nullptr;
- int8_t* aux_input_to_forget_weights_ptr = nullptr;
- int8_t* aux_input_to_cell_weights_ptr = nullptr;
- int8_t* aux_input_to_output_weights_ptr = nullptr;
- float aux_input_to_input_weights_scale = 0.0f;
- float aux_input_to_forget_weights_scale = 0.0f;
- float aux_input_to_cell_weights_scale = 0.0f;
- float aux_input_to_output_weights_scale = 0.0f;
- if (aux_input_size > 0) {
- aux_input_ptr = aux_input->data.f;
- aux_input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
- aux_input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
- aux_input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
- aux_input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
- aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
- aux_input_to_forget_weights_scale =
- aux_input_to_forget_weights->params.scale;
- aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
- aux_input_to_output_weights_scale =
- aux_input_to_output_weights->params.scale;
- }
-
- // Feed the sequence into the LSTM step-by-step.
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * n_output;
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr = output->data.f + t_rel * output_step;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, aux_input_size, n_output, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
- }
-
- return kTfLiteOk;
-}
-
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -1107,7 +794,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwInputActivationStateTensor);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
// Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
@@ -1135,9 +824,17 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_aux_input_to_output_weights =
GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+ // Populate a TfLiteLSTMParams struct for the evaluation functions.
+ TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
+ params->proj_clip, kTfLiteLSTMFullKernel};
+
+ const int bw_output_offset =
+ params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
+ const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
+
switch (fw_input_to_output_weights->type) {
case kTfLiteFloat32: {
- TfLiteStatus fw_pass_status = EvalFloat(
+ TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1147,12 +844,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
- fw_projection_weights, fw_projection_bias, params,
- /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state,
- fw_cell_state, fw_output);
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
- TfLiteStatus bw_pass_status = EvalFloat(
+ TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
@@ -1162,9 +859,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
bw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
- bw_projection_weights, bw_projection_bias, params,
- /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state,
- bw_cell_state, bw_output);
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ bw_activation_state, bw_cell_state, actual_bw_output);
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}
@@ -1188,7 +885,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
- TfLiteStatus fw_pass_status = EvalHybrid(
+ TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1198,15 +895,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
- fw_projection_weights, fw_projection_bias, params,
- /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_cell_weights, input_quantized,
- aux_input_quantized, fw_activation_state_quantized,
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, fw_activation_state_quantized,
fw_cell_state_quantized, fw_activation_state, fw_cell_state,
fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
- TfLiteStatus bw_pass_status = EvalHybrid(
+ TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
@@ -1216,12 +913,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
- bw_projection_weights, bw_projection_bias, params,
- /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_cell_weights, input_quantized,
- aux_input_quantized, bw_activation_state_quantized,
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, bw_activation_state_quantized,
bw_cell_state_quantized, bw_activation_state, bw_cell_state,
- bw_output);
+ actual_bw_output);
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index 74ba8021c2..9cc04907e1 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -35,8 +35,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
int sequence_length, bool use_cifg,
bool use_peephole, bool use_projection_weights,
- bool use_projection_bias, float cell_clip,
- float proj_clip,
+ bool use_projection_bias, bool merge_outputs,
+ float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes)
: n_batch_(n_batch),
n_input_(n_input),
@@ -175,7 +175,9 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
aux_input_ = AddNullInput();
fw_aux_input_to_input_weights_ = AddNullInput();
@@ -188,9 +190,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_aux_input_to_output_weights_ = AddNullInput();
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ CreateBidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip,
+ proj_clip, merge_outputs)
.Union());
BuildInterpreter(input_shapes);
}
@@ -380,7 +383,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -526,6 +530,162 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with a single merged output tensor.
+TEST(LSTMOpTest, BlackBoxTestMergedOutput) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
+
+ float* batch0_start = lstm_input;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int k = 0; k < lstm.sequence_length(); k++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_fw_golden_output + k * lstm.num_fw_outputs(),
+ lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs());
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_bw_golden_output + k * lstm.num_bw_outputs(),
+ lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs());
+ }
+ EXPECT_THAT(lstm.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
const int n_batch = 1;
const int n_input = 2;
@@ -537,7 +697,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -696,7 +857,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -845,7 +1007,8 @@ TEST(LSTMOpTest,
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -994,7 +1157,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/true, /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 2f896c5289..c22a457a71 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -47,7 +47,7 @@ constexpr int kFwAuxWeightsTensor = 10; // Optional.
constexpr int kBwAuxWeightsTensor = 11; // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false.
// Temporary tensors.
enum TemporaryTensor {
@@ -70,9 +70,13 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -109,6 +113,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// input configuration.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
@@ -142,9 +147,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_weights->dims->data[1]);
}
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
-
const bool is_hybrid_op =
(fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
@@ -233,18 +235,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Resize outputs.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = batch_size;
fw_output_size_array->data[1] = max_time;
- fw_output_size_array->data[2] = fw_num_units;
+ fw_output_size_array->data[2] =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
- TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
- bw_output_size_array->data[0] = batch_size;
- bw_output_size_array->data[1] = max_time;
- bw_output_size_array->data[2] = bw_num_units;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_output, bw_output_size_array));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
+ bw_output_size_array->data[0] = batch_size;
+ bw_output_size_array->data[1] = max_time;
+ bw_output_size_array->data[2] = bw_num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
+ bw_output_size_array));
+ }
return kTfLiteOk;
}
@@ -256,9 +263,9 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
const TfLiteTensor* bw_aux_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
- TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -281,10 +288,15 @@ TfLiteStatus EvalFloat(
? bw_aux_input_weights->data.f
: nullptr;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -292,8 +304,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
@@ -304,6 +315,10 @@ TfLiteStatus EvalFloat(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -311,8 +326,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
@@ -331,11 +345,12 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
const TfLiteTensor* aux_bw_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
- TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
+ TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -384,10 +399,15 @@ TfLiteStatus EvalHybrid(
reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -395,8 +415,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
@@ -411,6 +430,10 @@ TfLiteStatus EvalHybrid(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -418,8 +441,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
@@ -436,8 +458,8 @@ TfLiteStatus EvalHybrid(
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params =
- reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -465,7 +487,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwHiddenStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
switch (fw_input_weights->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 3e34ba6196..f555c472f5 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size)
+ int bw_units, int input_size, bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -675,12 +675,15 @@ class BidirectionalRNNOpModel : public SingleOpModel {
aux_bw_weights_ = AddNullInput();
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_SequenceRNNOptions,
- CreateSequenceRNNOptions(builder_, /*time_major=*/false,
- ActivationFunctionType_RELU)
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, /*time_major=*/false,
+ ActivationFunctionType_RELU, merge_outputs)
.Union());
BuildInterpreter({
{batches_, sequence_len_, input_size_}, // input
@@ -767,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with merged outputs.
+TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*merge_outputs=*/true);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ float* batch_start = rnn_input;
+ float* batch_end = batch_start + input_sequence_size;
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(input_sequence_size, batch_start, batch_end);
+
+ rnn.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int bid = 0; bid < rnn.num_batches(); bid++) {
+ for (int step = 0; step < rnn.sequence_len(); step++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_fw_output + rnn.num_fw_units() * step,
+ rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_bw_output + rnn.num_bw_units() * step,
+ rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
+ }
+ }
+ EXPECT_THAT(rnn.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
// Check that if the input sequence is reversed the outputs are the same just
// forward and backward are swapped (and reversed).
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -851,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index f765235e04..3926af5b97 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
if (input1->type == kTfLiteUInt8) { \
auto input1_offset = -input1->params.zero_point; \
auto input2_offset = -input2->params.zero_point; \
- const int left_shift = 20; \
- const double twice_max_input_scale = \
- 2 * std::max(input1->params.scale, input2->params.scale); \
- const double real_input1_multiplier = \
- input1->params.scale / twice_max_input_scale; \
- const double real_input2_multiplier = \
- input2->params.scale / twice_max_input_scale; \
+ const int left_shift = 8; \
\
int32 input1_multiplier; \
int input1_shift; \
- QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
+ QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
&input1_multiplier, &input1_shift); \
int32 input2_multiplier; \
int input2_shift; \
- QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
+ QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
&input2_multiplier, &input2_shift); \
\
ComparisonParams op_params; \
op_params.left_shift = left_shift; \
op_params.input1_offset = input1_offset; \
op_params.input1_multiplier = input1_multiplier; \
- op_params.input1_shift = -input1_shift; \
+ op_params.input1_shift = input1_shift; \
op_params.input2_offset = input2_offset; \
op_params.input2_multiplier = input2_multiplier; \
- op_params.input2_shift = -input2_shift; \
+ op_params.input2_shift = input2_shift; \
if (requires_broadcast) { \
reference_ops::Broadcast4DSlow##opname##WithScaling( \
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 67a91c17fd..04c8bf2e30 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
+TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
+ TensorType_UINT8, BuiltinOperator_GREATER);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+}
+
TEST(ComparisonsTest, GreaterEqualQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
index b87cf2b60d..7c176e0fa1 100644
--- a/tensorflow/contrib/lite/kernels/internal/compatibility.h
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -84,4 +84,27 @@ using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
+// TFLITE_DEPRECATED()
+//
+// Duplicated from absl/base/macros.h to avoid pulling in that library.
+// Marks a deprecated class, struct, enum, function, method and variable
+// declarations. The macro argument is used as a custom diagnostic message (e.g.
+// suggestion of a better alternative).
+//
+// Example:
+//
+// class TFLITE_DEPRECATED("Use Bar instead") Foo {...};
+// TFLITE_DEPRECATED("Use Baz instead") void Bar() {...}
+//
+// Every usage of a deprecated entity will trigger a warning when compiled with
+// clang's `-Wdeprecated-declarations` option. This option is turned off by
+// default, but the warnings will be reported by clang-tidy.
+#if defined(__clang__) && __cplusplus >= 201103L
+#define TFLITE_DEPRECATED(message) __attribute__((deprecated(message)))
+#endif
+
+#ifndef TFLITE_DEPRECATED
+#define TFLITE_DEPRECATED(message)
+#endif
+
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 56e9367878..083e5839bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -169,603 +169,5 @@ void RnnBatchStep(
hidden_state_ptr_batch);
}
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
- n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-}
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // If auxiliary input is available then compute aux_input_weight * aux_input
- if (aux_input_ptr_batch != nullptr) {
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, input_gate_scratch,
- /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
- }
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_batch, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_input_weights_scale=*/0.0f,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_scale=*/0.0f,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_scale=*/0.0f,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_scale=*/0.0f,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input,
- /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors,
- product_scaling_factors, recovered_cell_weights,
- quantized_input_ptr_batch,
- /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
-
- void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr,
- float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr,
- float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch,
- int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
- float* output_state_ptr, float* cell_state_ptr,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we
- // can check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
- n_batch, input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
- n_batch, output_gate_scratch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input,
- quantized_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (aux_input_ptr_batch != nullptr &&
- !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- aux_input_ptr_batch + offset, n_input,
- quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(
- output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Save quantization and matmul computation for all zero input.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
- cell_state_ptr, n_batch * n_cell,
- cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell,
- cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell,
- output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch,
- n_batch * n_cell)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell,
- quantized_cell_state_ptr, product_scaling_factors, n_batch,
- output_ptr_batch,
- /*result_stride=*/1);
- }
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
- }
-
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index b5558cce55..74e0a4a53d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -76,190 +76,6 @@ void RnnBatchStep(
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch);
-// Performs an LSTM batch inference step for input specified by input_ptr_batch.
-// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
-// biases (*_bias_ptr), and buffers (*_scratch), along with additional
-// parameters:
-// - params: various LSTM params including activation, clipping, etc.,
-// - n_batch: size of batch,
-// - n_cell: number of cells (or units),
-// - n_input: the input size,
-// - n_output: the output size.
-//
-// The pointers to the cell and output state and the output are updated.
-//
-// The pointers with the suffix "_batch" point to data aligned in batch_major
-// order, and each step processes batch_size many inputs from input_ptr_batch,
-// and updates batch_size many cell and output states.
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but includes an auxiliary input with the corresponding weights.
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but with quantized weight matrices. In detail:
-// Input of size 'n_batch * n_input':
-// input_ptr_batch
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-// input_to_input_weights - optional (can be nullptr)
-// input_to_forget_weights
-// input_to_cell_weights
-// input_to_input_weights
-// Quantized recurrent weights of size 'n_cell * n_output':
-// recurrent_to_input_weights - optional
-// recurrent_to_forget_weights
-// recurrent_to_cell_weights
-// recurrent_to_input_weights
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-// cell_to_input_weights - optional
-// cell_to_cell_weights - optional
-// cell_to_output_weights - optional
-// Quantized projection weights of size 'n_output * n_cell'
-// projection_weights_ptr - optional
-// Weight scales (scalars) for each of the weights above.
-// input_to_input_weights_scale - optional
-// input_to_forget_weights_scale
-// input_to_cell_weights_scale
-// input_to_output_weights_scale
-// recurrent_to_input_weights_scale - optional
-// recurrent_to_forget_weights_scale
-// recurrent_to_cell_weights_scale
-// recurrent_to_output_weights_scale
-// cell_to_input_weights_scale,
-// cell_to_forget_weights_scale,
-// cell_to_output_weights_scale,
-// projection_weights_scale - optional
-// Gate biases of size 'n_cell':
-// input_gate_bias_ptr - optional
-// forget_gate_bias_ptr
-// cell_gate_bias_ptr
-// output_gate_bias_ptr
-//
-// Temporary pre-allocated storage for quantized values:
-// quantized_input_ptr_batch (same size as input_ptr_batch)
-// quantized_output_state_ptr (same size as output_state_ptr)
-// quantized_cell_state_ptr (same size as cell_state_ptr)
-// Temporary pre-allocated storage for recovered values:
-// recovered_cell_weights (same size as cell_to_*_weights)
-//
-// Outputs:
-// output_state_ptr - size 'n_batch * n_output'
-// cell_state_ptr - size 'n_batch * n_cell'
-// output_ptr_batch - size 'n_batch * n_output'
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* scaling_factors, float* product_scaling_factors,
- float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 14281f25c6..25ea72b886 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -259,7 +259,7 @@ TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
EXPECT_EQ(double_shift, 1);
result = IntegerFrExp(123.45, &shift);
- EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000);
+ EXPECT_NEAR(result, (0.964453 * (1LL << 31)), 1000);
EXPECT_EQ(shift, 7);
double_result = std::frexp(123.45, &double_shift);
EXPECT_NEAR(double_result, 0.964453, 1e-5);
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index b39347758a..c6bc6074d4 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include <cstring>
-#include "absl/base/macros.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -269,8 +268,9 @@ class RuntimeShape {
// This creates a shape padded to the desired size with the specified value.
RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
: size_(0) {
+ // If the following check fails, it is likely because a 4D-only kernel is
+ // being used with an array of larger dimension count.
TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
- TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
Resize(new_shape_size);
const int size_increase = new_shape_size - shape.DimensionsCount();
for (int i = 0; i < size_increase; ++i) {
@@ -441,7 +441,7 @@ inline int FlatSize(const Dims<N>& dims) {
return flat_size;
}
-ABSL_DEPRECATED("Prefer FlatSize.")
+TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 5b996d00bc..16d67a1a93 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -424,263 +425,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// The LSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- activation_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_activation_state_ptr, quantized_cell_state_ptr,
- activation_state_ptr, cell_state_ptr, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
@@ -738,15 +482,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): add a check that weights are all uint8s or all floats.
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -759,17 +509,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
new file mode 100644
index 0000000000..20a4e30009
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -0,0 +1,912 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+namespace {
+
+// Performs an LSTM batch inference step for input specified by input_ptr_batch.
+// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
+// biases (*_bias_ptr), and buffers (*_scratch), along with additional
+// parameters:
+// - params: various LSTM params including activation, clipping, etc.,
+// - n_batch: size of batch,
+// - n_cell: number of cells (or units),
+// - n_input: the input size,
+// - n_output: the output size.
+//
+// The pointers to the cell and output state and the output are updated.
+//
+// The pointers with the suffix "_batch" point to data aligned in batch_major
+// order, and each step processes batch_size many inputs from input_ptr_batch,
+// and updates batch_size many cell and output states.
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // If auxiliary input is available then compute aux_input_weight * aux_input
+ if (aux_input_ptr_batch != nullptr) {
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+// input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+// input_to_input_weights - optional (can be nullptr)
+// input_to_forget_weights
+// input_to_cell_weights
+// input_to_output_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+// recurrent_to_input_weights - optional
+// recurrent_to_forget_weights
+// recurrent_to_cell_weights
+// recurrent_to_output_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+// cell_to_input_weights - optional
+// cell_to_forget_weights - optional
+// cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+// projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional
+// input_to_forget_weights_scale
+// input_to_cell_weights_scale
+// input_to_output_weights_scale
+// recurrent_to_input_weights_scale - optional
+// recurrent_to_forget_weights_scale
+// recurrent_to_cell_weights_scale
+// recurrent_to_output_weights_scale
+// cell_to_input_weights_scale,
+// cell_to_forget_weights_scale,
+// cell_to_output_weights_scale,
+// projection_weights_scale - optional
+// Gate biases of size 'n_cell':
+// input_gate_bias_ptr - optional
+// forget_gate_bias_ptr
+// cell_gate_bias_ptr
+// output_gate_bias_ptr
+//
+// Temporary pre-allocated storage for quantized values:
+// quantized_input_ptr_batch (same size as input_ptr_batch)
+// quantized_output_state_ptr (same size as output_state_ptr)
+// quantized_cell_state_ptr (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+// recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+// output_state_ptr - size 'n_batch * n_output'
+// cell_state_ptr - size 'n_batch * n_cell'
+// output_ptr_batch - size 'n_batch * n_output'
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* scaling_factors, float* product_scaling_factors,
+ float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we
+ // can check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (aux_input_ptr_batch != nullptr &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, n_input,
+ quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // Save quantization and matmul computation for all zero input.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+} // namespace
+
+// Evaluates a float LSTM over an entire (optionally time-major) sequence by
+// invoking LstmStepWithAuxInput once per time step. Handles the optional
+// variants: CIFG (no input gate), peephole connections, projection layer and
+// an auxiliary input. `forward_sequence` selects the iteration direction;
+// `output_offset` shifts where results are written inside each output row
+// (used by the bidirectional wrapper). Always returns kTfLiteOk; shape
+// validation is assumed to have happened in Prepare.
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output) {
+  // 2-D input is a single step {n_batch, n_input}; 3-D input is a sequence
+  // {max_time, n_batch, n_input}.
+  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
+  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+  const int n_batch = input->dims->data[input->dims->size - 2];
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffers pointers to the global scratch buffer.
+  // With CIFG there is no input gate, so only three scratch regions are used.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  const float* input_to_input_weights_ptr =
+      (use_cifg) ? nullptr : input_to_input_weights->data.f;
+  const float* recurrent_to_input_weights_ptr =
+      (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+  const float* input_gate_bias_ptr =
+      (use_cifg) ? nullptr : input_gate_bias->data.f;
+  const float* cell_to_input_weights_ptr =
+      (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+  const float* cell_to_forget_weights_ptr =
+      (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+  const float* cell_to_output_weights_ptr =
+      (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+  const float* projection_weights_ptr =
+      (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Auxiliary-input pointers stay null when no aux input is supplied; the
+  // step function treats null aux pointers as "no auxiliary contribution".
+  float* aux_input_ptr = nullptr;
+  float* aux_input_to_input_weights_ptr = nullptr;
+  float* aux_input_to_forget_weights_ptr = nullptr;
+  float* aux_input_to_cell_weights_ptr = nullptr;
+  float* aux_input_to_output_weights_ptr = nullptr;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+    aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+    aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+    aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+  }
+
+  // Loop through the sequence.
+  const int input_step = n_batch * n_input;
+  const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+  for (int t = 0; t < max_time; t++) {
+    // If this is the forward_sequence, step forward, otherwise step backwards.
+    const int t_rel = forward_sequence ? t : max_time - t - 1;
+    const float* input_ptr = input->data.f + t_rel * input_step;
+    float* output_ptr_time =
+        output->data.f + t_rel * output_step + output_offset;
+
+    LstmStepWithAuxInput(
+        input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+        input_to_cell_weights->data.f, input_to_output_weights->data.f,
+        aux_input_ptr, aux_input_to_input_weights_ptr,
+        aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+        aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+        recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+        recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+        cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+        input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+        output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+        params, n_batch, n_cell, n_input, aux_input_size, n_output,
+        activation_state->data.f, cell_state->data.f, input_gate_scratch,
+        forget_gate_scratch, cell_scratch, output_gate_scratch,
+        output_ptr_time);
+  }
+  return kTfLiteOk;
+}
+
+// Hybrid counterpart of EvalFloat: weights are 8-bit symmetric-quantized
+// (stored as uint8, reinterpreted as int8) while activations and states stay
+// float. Inputs/states are quantized on the fly into the provided temporary
+// tensors and per-batch scaling factors are computed each step. Iterates the
+// sequence step-by-step via the hybrid LstmStepWithAuxInput overload. Always
+// returns kTfLiteOk.
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state,
+    TfLiteTensor* output) {
+  // 2-D input is a single step {n_batch, n_input}; 3-D input is a sequence
+  // {max_time, n_batch, n_input}.
+  TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
+  const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+  const int n_batch = input->dims->data[input->dims->size - 2];
+  const int n_input = input->dims->data[input->dims->size - 1];
+  const int aux_input_size =
+      (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffer pointers into the shared scratch tensor; with
+  // CIFG there is no input gate, so only three scratch regions are needed.
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  // Check optional tensors, the respective pointers can be null.
+  // Scales default to 1.0f so unused scales are harmless multipliers.
+  int8_t* input_to_input_weights_ptr = nullptr;
+  float input_to_input_weights_scale = 1.0f;
+  int8_t* recurrent_to_input_weights_ptr = nullptr;
+  float recurrent_to_input_weights_scale = 1.0f;
+  float* input_gate_bias_ptr = nullptr;
+  if (!use_cifg) {
+    input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+    recurrent_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+    input_gate_bias_ptr = input_gate_bias->data.f;
+    input_to_input_weights_scale = input_to_input_weights->params.scale;
+    recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+  }
+
+  int8_t* cell_to_input_weights_ptr = nullptr;
+  int8_t* cell_to_forget_weights_ptr = nullptr;
+  int8_t* cell_to_output_weights_ptr = nullptr;
+  float cell_to_input_weights_scale = 1.0f;
+  float cell_to_forget_weights_scale = 1.0f;
+  float cell_to_output_weights_scale = 1.0f;
+  if (use_peephole) {
+    if (!use_cifg) {
+      cell_to_input_weights_ptr =
+          reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+      cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+    }
+    cell_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+    cell_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+    cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+    cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+  }
+
+  const int8_t* projection_weights_ptr =
+      (projection_weights == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+  const float projection_weights_scale =
+      (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+  const float* projection_bias_ptr =
+      (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+  // Required tensors, pointers are non-null.
+  const int8_t* input_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+  const float input_to_forget_weights_scale =
+      input_to_forget_weights->params.scale;
+  const int8_t* input_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+  const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+  const int8_t* input_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+  const float input_to_output_weights_scale =
+      input_to_output_weights->params.scale;
+  const int8_t* recurrent_to_forget_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+  const float recurrent_to_forget_weights_scale =
+      recurrent_to_forget_weights->params.scale;
+  const int8_t* recurrent_to_cell_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+  const float recurrent_to_cell_weights_scale =
+      recurrent_to_cell_weights->params.scale;
+  const int8_t* recurrent_to_output_weights_ptr =
+      reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+  const float recurrent_to_output_weights_scale =
+      recurrent_to_output_weights->params.scale;
+  const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+  const float* cell_bias_ptr = cell_bias->data.f;
+  const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+  float* output_state_ptr = output_state->data.f;
+  float* cell_state_ptr = cell_state->data.f;
+
+  // Temporary storage for quantized values and scaling factors.
+  int8_t* quantized_input_ptr =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  int8_t* quantized_aux_input_ptr =
+      (aux_input_quantized == nullptr)
+          ? nullptr
+          : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+  int8_t* quantized_output_state_ptr =
+      reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+  int8_t* quantized_cell_state_ptr =
+      reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+  float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+  float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+  // Auxiliary input and weights.
+  float* aux_input_ptr = nullptr;
+  int8_t* aux_input_to_input_weights_ptr = nullptr;
+  int8_t* aux_input_to_forget_weights_ptr = nullptr;
+  int8_t* aux_input_to_cell_weights_ptr = nullptr;
+  int8_t* aux_input_to_output_weights_ptr = nullptr;
+  float aux_input_to_input_weights_scale = 0.0f;
+  float aux_input_to_forget_weights_scale = 0.0f;
+  float aux_input_to_cell_weights_scale = 0.0f;
+  float aux_input_to_output_weights_scale = 0.0f;
+  if (aux_input_size > 0) {
+    aux_input_ptr = aux_input->data.f;
+    aux_input_to_input_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+    aux_input_to_forget_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+    aux_input_to_cell_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+    aux_input_to_output_weights_ptr =
+        reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+    aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
+    aux_input_to_forget_weights_scale =
+        aux_input_to_forget_weights->params.scale;
+    aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+    aux_input_to_output_weights_scale =
+        aux_input_to_output_weights->params.scale;
+  }
+
+  // Feed the sequence into the LSTM step-by-step.
+  const int input_step = n_batch * n_input;
+  const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+  for (int t = 0; t < max_time; t++) {
+    // If this is the forward_sequence, step forward, otherwise step backwards.
+    const int t_rel = forward_sequence ? t : max_time - t - 1;
+    const float* input_ptr = input->data.f + t_rel * input_step;
+    float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+
+    LstmStepWithAuxInput(
+        input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+        input_to_forget_weights_ptr, input_to_forget_weights_scale,
+        input_to_cell_weights_ptr, input_to_cell_weights_scale,
+        input_to_output_weights_ptr, input_to_output_weights_scale,
+        aux_input_ptr, aux_input_to_input_weights_ptr,
+        aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+        aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+        aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+        aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+        recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+        recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+        recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+        recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+        cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+        cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+        cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+        cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+        projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+        n_input, aux_input_size, n_output, input_gate_scratch,
+        forget_gate_scratch, cell_scratch, output_gate_scratch,
+        scaling_factors_ptr, prod_scaling_factors_ptr,
+        recovered_cell_weights_ptr, quantized_input_ptr,
+        quantized_aux_input_ptr, quantized_output_state_ptr,
+        quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+  }
+
+  return kTfLiteOk;
+}
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/contrib/lite/kernels/lstm_eval.h
new file mode 100644
index 0000000000..adf8cf0f64
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+// Evaluates a float LSTM over a whole sequence. Optional tensors (input gate
+// weights for CIFG, peephole weights, projection, aux input) may be null.
+TfLiteStatus EvalFloat(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+    TfLiteTensor* cell_state, TfLiteTensor* output);
+
+// Hybrid (8-bit quantized weights, float activations) LSTM evaluation over a
+// whole sequence. The extra temporaries hold quantized inputs/states and
+// per-batch scaling factors.
+TfLiteStatus EvalHybrid(
+    const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+    const TfLiteTensor* input_to_forget_weights,
+    const TfLiteTensor* input_to_cell_weights,
+    const TfLiteTensor* input_to_output_weights,
+    const TfLiteTensor* recurrent_to_input_weights,
+    const TfLiteTensor* recurrent_to_forget_weights,
+    const TfLiteTensor* recurrent_to_cell_weights,
+    const TfLiteTensor* recurrent_to_output_weights,
+    const TfLiteTensor* cell_to_input_weights,
+    const TfLiteTensor* cell_to_forget_weights,
+    const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+    const TfLiteTensor* aux_input_to_input_weights,
+    const TfLiteTensor* aux_input_to_forget_weights,
+    const TfLiteTensor* aux_input_to_cell_weights,
+    const TfLiteTensor* aux_input_to_output_weights,
+    const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+    const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+    const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+    const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+    TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+    TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+    TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+    TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+    TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
new file mode 100644
index 0000000000..843ed0768c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// SparseOutputFullyConnected is a fully connected layer that uses a single
+// row in the weights and bias via a lookup.
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sparse_output_fully_connected {
+
+// Input tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+// Auxiliary input tensor of size { 1 }
+constexpr int kInputLookupTensor = 1;
+
+// Weights tensor of size { n_embeddings , n_input }
+constexpr int kWeightsTensor = 2;
+// Bias tensor of size { n_embeddings }
+constexpr int kBiasTensor = 3;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kScalingFactors = 1,
+ kNumTemporaryTensors = 2
+};
+
+// Struct to hold op data.
+struct OpData {
+ int scratch_tensor_index;
+};
+
+// Allocates the op's OpData and reserves indices for the temporary tensors
+// (quantized input + scaling factors) used by the hybrid path. The custom-op
+// options `buffer` is unused. Paired with Free below.
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+  auto* data = new OpData;
+  context->AddTensors(context, /*tensors_to_add=*/kNumTemporaryTensors,
+                      &data->scratch_tensor_index);
+  return data;
+}
+
+// Releases the OpData allocated in Init.
+void Free(TfLiteContext* context, void* buffer) {
+  delete reinterpret_cast<OpData*>(buffer);
+}
+
+// Validates input/weights/bias shapes and, for the hybrid case (uint8 weights
+// with float input), resizes the two temporary tensors used at Eval time:
+// a quantized copy of the input and per-batch scaling factors.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  // Input must be float and 2-D: {n_batch, n_input}.
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+  const int n_batch = SizeOfDimension(input, 0);
+  const int n_input = SizeOfDimension(input, 1);
+
+  const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+  TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+  // Only support single lookup.
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(lookup, 0), 1);
+
+  // Weights are {n_embeddings, n_input}; the lookup selects one row.
+  const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 2);
+  TF_LITE_ENSURE_EQ(context, SizeOfDimension(weights, 1), n_input);
+
+  // Bias has one entry per embedding row.
+  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+  TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(weights, 0));
+
+  const bool is_hybrid_op =
+      (weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+  if (is_hybrid_op) {
+    TfLiteIntArrayFree(node->temporaries);
+    node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+
+    // Allocate temporary tensors to store quantized values of input.
+    node->temporaries->data[kInputQuantized] = op_data->scratch_tensor_index;
+    TfLiteTensor* input_quantized =
+        GetTemporary(context, node, /*index=*/kInputQuantized);
+    input_quantized->type = kTfLiteUInt8;
+    input_quantized->allocation_type = kTfLiteArenaRw;
+    // Resize only when needed to avoid churning the arena.
+    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+                                                       input_quantized_size));
+    }
+
+    // Tell interpreter to allocate temporary tensors to store scaling factors.
+    node->temporaries->data[kScalingFactors] =
+        op_data->scratch_tensor_index + kScalingFactors;
+    TfLiteTensor* scaling_factors =
+        GetTemporary(context, node, /*index=*/kScalingFactors);
+    scaling_factors->type = kTfLiteFloat32;
+    scaling_factors->allocation_type = kTfLiteArenaRw;
+    // One scaling factor per batch row.
+    TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+    scaling_factors_size->data[0] = n_batch;
+    if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+                                                       scaling_factors_size));
+    }
+  }
+  return kTfLiteOk;
+}
+
+// Float path: for each batch row computes
+//   output[b] = dot(weights[lookup], input[b]) + bias[lookup]
+// i.e. a fully connected layer restricted to the single weights/bias row
+// selected by the lookup value.
+TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* lookup,
+                       const TfLiteTensor* weights, const TfLiteTensor* bias,
+                       TfLiteTensor* output) {
+  const int n_batch = SizeOfDimension(input, 0);
+  const int n_input = SizeOfDimension(input, 1);
+
+  const float* input_ptr_batch = input->data.f;
+
+  // Initialize pointer to right row according to lookup value.
+  int32 lookup_index = lookup->data.i32[0];
+  const float* weights_ptr = weights->data.f + lookup_index * n_input;
+
+  // Initialize output to bias.
+  if (bias) {
+    float* bias_ptr = bias->data.f + lookup_index;
+    tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+  }
+
+  // Accumulate the single-row matrix-vector product on top of the bias.
+  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+      weights_ptr, /*m_rows=*/1, n_input, input_ptr_batch, n_batch,
+      output->data.f, /*result_stride=*/1);
+
+  return kTfLiteOk;
+}
+
+// Hybrid path: same computation as EvalFloat but with 8-bit symmetric-
+// quantized weights. The float input is quantized per batch row into
+// `input_quantized`, each row's scaling factor is folded together with the
+// weights scale, and the int8 dot product is accumulated into the float
+// output. All-zero inputs skip quantization and matmul entirely.
+TfLiteStatus EvalHybrid(const TfLiteTensor* input, const TfLiteTensor* lookup,
+                        const TfLiteTensor* weights, const TfLiteTensor* bias,
+                        TfLiteTensor* scaling_factors,
+                        TfLiteTensor* input_quantized, TfLiteTensor* output) {
+  const int n_batch = SizeOfDimension(input, 0);
+  const int n_input = SizeOfDimension(input, 1);
+
+  const float* input_ptr_batch = input->data.f;
+  // Initialize the pointer to storage for quantized values and
+  // scaling factors.
+  int8_t* quantized_input_ptr_batch =
+      reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+  float* scaling_factors_ptr = scaling_factors->data.f;
+
+  // Initialize pointer to right row according to lookup value.
+  int32 lookup_index = lookup->data.i32[0];
+  int8_t* weights_ptr =
+      reinterpret_cast<int8_t*>(weights->data.uint8) + lookup_index * n_input;
+
+  // Initialize output to bias.
+  if (bias) {
+    float* bias_ptr = bias->data.f + lookup_index;
+    tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+  } else {
+    tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+  }
+
+  if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+    // Quantize input from float to int8.
+    float unused_min, unused_max;
+    for (int b = 0; b < n_batch; ++b) {
+      const int offset = b * n_input;
+      tensor_utils::SymmetricQuantizeFloats(
+          input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+          &unused_min, &unused_max, &scaling_factors_ptr[b]);
+      // Fold the weights scale into the per-batch factor so the matmul
+      // result dequantizes in a single multiply.
+      scaling_factors_ptr[b] *= weights->params.scale;
+    }
+
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        weights_ptr, /*m_rows=*/1, n_input, quantized_input_ptr_batch,
+        scaling_factors_ptr, n_batch, output->data.f, /*result_stride=*/1);
+  }
+  return kTfLiteOk;
+}
+
+// Dispatches to the float or hybrid kernel based on the weights type.
+// Unsupported weight types report an error and return kTfLiteError.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+  const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+  const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  switch (weights->type) {
+    case kTfLiteFloat32: {
+      return EvalFloat(input, lookup, weights, bias, output);
+    }
+    case kTfLiteUInt8: {
+      // Temporaries were sized in Prepare for the hybrid case.
+      TfLiteTensor* input_quantized =
+          GetTemporary(context, node, /*index=*/kInputQuantized);
+      TfLiteTensor* scaling_factors =
+          GetTemporary(context, node, /*index=*/kScalingFactors);
+      return EvalHybrid(input, lookup, weights, bias, scaling_factors,
+                        input_quantized, output);
+    }
+    default:
+      context->ReportError(context, "Type %d is not currently supported.",
+                           weights->type);
+      return kTfLiteError;
+  }
+  return kTfLiteOk;
+}
+
+} // namespace sparse_output_fully_connected
+
+// Returns the registration for the SPARSE_OUTPUT_FULLY_CONNECTED custom op,
+// wiring up the Init/Free/Prepare/Eval lifecycle callbacks defined above.
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED() {
+  static TfLiteRegistration r = {sparse_output_fully_connected::Init,
+                                 sparse_output_fully_connected::Free,
+                                 sparse_output_fully_connected::Prepare,
+                                 sparse_output_fully_connected::Eval};
+  return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
new file mode 100644
index 0000000000..365986a5c1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse output fully connected op.
+#include <iomanip>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+// Test harness wrapping the SPARSE_OUTPUT_FULLY_CONNECTED custom op in a
+// single-op interpreter. Wires up the four inputs (input, lookup, weights,
+// bias) and one output, and exposes setters for populating them.
+class BaseSparseOutputFullyConnectedOpModel : public SingleOpModel {
+ public:
+  BaseSparseOutputFullyConnectedOpModel(const TensorData& input,
+                                        const TensorData& weights,
+                                        const TensorData& output = {
+                                            TensorType_FLOAT32}) {
+    input_ = AddInput(input);
+    lookup_ = AddInput({TensorType_INT32, {1}});
+    weights_ = AddInput(weights);
+    // Bias size must match the number of weight rows (embeddings).
+    int bias_size = GetShape(weights_)[0];
+    bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+    output_ = AddOutput(output);
+
+    // Create empty (required) options map.
+    flexbuffers::Builder fbb;
+    fbb.Map([&]() {});
+    fbb.Finish();
+
+    SetCustomOp("SPARSE_OUTPUT_FULLY_CONNECTED", fbb.GetBuffer(),
+                Register_SPARSE_OUTPUT_FULLY_CONNECTED);
+    BuildInterpreter({GetShape(input_), GetShape(lookup_), GetShape(weights_),
+                      GetShape(bias_)});
+  }
+
+  void SetInput(const std::vector<float>& data) {
+    PopulateTensor(input_, data);
+  }
+
+  void SetLookup(const std::vector<int32>& f) { PopulateTensor(lookup_, f); }
+
+  void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+  // Tensor indices within the interpreter, assigned by Add{Input,Output}.
+  int input_;
+  int lookup_;
+  int weights_;
+  int bias_;
+  int output_;
+};
+
+// Model variant with float weights (exercises the float kernel path).
+class FloatSparseOutputFullyConnectedOpModel
+    : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+  using BaseSparseOutputFullyConnectedOpModel::
+      BaseSparseOutputFullyConnectedOpModel;
+
+  void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
+};
+
+// Model variant with uint8 weights: float weight values are symmetrically
+// quantized before population (exercises the hybrid kernel path).
+class HybridSparseOutputFullyConnectedOpModel
+    : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+  using BaseSparseOutputFullyConnectedOpModel::
+      BaseSparseOutputFullyConnectedOpModel;
+
+  void SetWeights(const std::vector<float>& f) {
+    SymmetricQuantizeAndPopulate(weights_, f);
+  }
+};
+
+// Float path: lookup row 2 selects weights {1,2,3,4,5} and bias 3.0;
+// dot({-1,0,1,2,3}, {1,2,3,4,5}) + 3 = 25 + 3 = 28 exactly.
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestFloat) {
+  FloatSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+                                           {TensorType_FLOAT32, {3, 5}},
+                                           {TensorType_FLOAT32, {}});
+
+  m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+  m.SetLookup({2});
+
+  m.SetWeights({
+      -1.0, 0.0, 1.0, 2.0, 3.0,  //
+      0.0, 1.0, 2.0, 3.0, 4.0,   //
+      1.0, 2.0, 3.0, 4.0, 5.0,   //
+  });
+
+  m.SetBias({1.0, 2.0, 3.0});
+
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({28}));
+}
+
+// Hybrid path: same data as the float test, but weights are quantized, so
+// the result carries quantization error — checked within tolerance below.
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybrid) {
+  HybridSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+                                            {TensorType_UINT8, {3, 5}},
+                                            {TensorType_FLOAT32, {}});
+
+  m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+  m.SetLookup({2});
+
+  m.SetWeights({
+      -1.0, 0.0, 1.0, 2.0, 3.0,  //
+      0.0, 1.0, 2.0, 3.0, 4.0,   //
+      1.0, 2.0, 3.0, 4.0, 5.0,   //
+  });
+
+  m.SetBias({1.0, 2.0, 3.0});
+
+  m.Invoke();
+
+  // We get 28.0552 instead of 28.
+  //
+  // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3.
+  // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5.
+  //
+  // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0
+  // gives us the expected result.
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({28}, 0.0553)));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+// Standard gtest entry point; routes TFLite logging to stderr first.
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 63817bd886..ec9cf38b83 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -429,273 +430,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// The LSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_forget_weights_ptr, input_to_cell_weights_ptr,
- input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr,
- recurrent_to_output_weights_ptr, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
- cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_input_weights_scale, input_to_forget_weights_ptr,
- input_to_forget_weights_scale, input_to_cell_weights_ptr,
- input_to_cell_weights_scale, input_to_output_weights_ptr,
- input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors_ptr,
- prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_activation_state_ptr,
- quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -750,15 +484,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -771,17 +511,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index d50c345194..d7b109ac1a 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -27,9 +27,6 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
-#if defined(TFLITE_FLEX)
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
@@ -43,6 +40,25 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
+// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
+// we avoid the absl dependency for binary size reasons.
+#ifdef __has_attribute
+#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define TFLITE_HAS_ATTRIBUTE(x) 0
+#endif
+
+#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
+// Using weak symbols for the flex delegate allows automatic injection of the
+// delegate simply by adding it as a dependency. See also the strong override in
+// lite/delegates/flex/delegate.cc.
+__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
+ return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
+}
+#else
+Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
+#endif
+
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -450,13 +466,14 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
-#if defined(TFLITE_FLEX)
- if (auto delegate = FlexDelegate::Create()) {
- (**interpreter)
- .ModifyGraphWithDelegate(std::move(delegate),
- /*allow_dynamic_tensors=*/true);
+ // TODO(b/116667551): Only create the flex delegate if the model has flex ops.
+ if (AcquireFlexDelegate != nullptr) {
+ if (auto flex_delegate = AcquireFlexDelegate()) {
+ (**interpreter)
+ .ModifyGraphWithDelegate(std::move(flex_delegate),
+ /*allow_dynamic_tensors=*/true);
+ }
}
-#endif
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/model_flex_test.cc b/tensorflow/contrib/lite/model_flex_test.cc
new file mode 100644
index 0000000000..52e76bee49
--- /dev/null
+++ b/tensorflow/contrib/lite/model_flex_test.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/model.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+
+// Ensures that a model with TensorFlow ops can be imported as long as the
+// appropriate delegate is linked into the client.
+TEST(FlexModel, WithFlexDelegate) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model,
+ ops::builtin::BuiltinOpResolver{})(&interpreter),
+ kTfLiteOk);
+ ASSERT_TRUE(interpreter);
+
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index ec7d46af7c..b969bea5dc 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
@@ -193,6 +194,27 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) {
}
}
+// Test that loading a model with TensorFlow ops fails when the flex delegate is
+// not linked into the target.
+TEST(FlexModel, FailureWithoutFlexDelegate) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+ ASSERT_TRUE(model);
+
+ // Note that creation will succeed when using the BuiltinOpResolver, but
+ // unless the appropriate delegate is linked into the target or the client
+ // explicitly installs the delegate, execution will fail.
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model,
+ ops::builtin::BuiltinOpResolver{})(&interpreter),
+ kTfLiteOk);
+ ASSERT_TRUE(interpreter);
+
+ // As the flex ops weren't resolved implicitly by the flex delegate, runtime
+ // allocation and execution will fail.
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
+}
+
// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
// buffer. But the buffer is provided to be only 1 element.
TEST(BasicFlatBufferModel, TestBrokenMmap) {
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index 687944023b..eccf4aefb6 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -180,6 +180,14 @@ enum {
};
/**
+ * Implicit padding algorithms.
+ */
+enum {
+ ANEURALNETWORKS_PADDING_SAME = 1,
+ ANEURALNETWORKS_PADDING_VALID = 2,
+};
+
+/**
* ANeuralNetworksMemory is an opaque type that represents memory.
*
* This type is used to represent shared memory, memory mapped files,
diff --git a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
index 67a5eecfa0..465c294962 100644
--- a/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
+++ b/tensorflow/contrib/lite/profiling/profile_summarizer_test.cc
@@ -31,6 +31,8 @@ namespace profiling {
namespace {
+const char* kOpName = "SimpleOpEval";
+
#ifdef TFLITE_PROFILING_ENABLED
TfLiteStatus SimpleOpEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input1 = tflite::GetInput(context, node, /*index=*/0);
@@ -63,7 +65,7 @@ TfLiteRegistration* RegisterSimpleOpWithProfilingDetails() {
SimpleOpEval,
SimpleOpProfilingString,
tflite::BuiltinOperator_CUSTOM,
- "SimpleOpEval",
+ kOpName,
1};
return &registration;
}
@@ -89,7 +91,7 @@ void SimpleOpModel::Init(
inputs_[0] = AddInput({TensorType_INT32, {1}});
inputs_[1] = AddInput({TensorType_INT32, {1}});
output_ = AddOutput({TensorType_INT32, {}});
- SetCustomOp("SimpleAdd", {}, registration);
+ SetCustomOp(kOpName, {}, registration);
BuildInterpreter({GetShape(inputs_[0]), GetShape(inputs_[1])});
}
diff --git a/tensorflow/contrib/lite/profiling/profiler_test.cc b/tensorflow/contrib/lite/profiling/profiler_test.cc
index 0fba0450a0..cf56eed2a4 100644
--- a/tensorflow/contrib/lite/profiling/profiler_test.cc
+++ b/tensorflow/contrib/lite/profiling/profiler_test.cc
@@ -83,8 +83,8 @@ TEST(ProfilingTest, ProfilesAreCollected) {
EXPECT_EQ("SleepForQuarter", profile_events[4]->tag);
#ifndef ADDRESS_SANITIZER
- // ASAN build is sometimes very slow.
- const int eps_ms = 10;
+ // ASAN build is sometimes very slow. Set a large epsilon to avoid flakiness.
+ const int eps_ms = 50;
AssertDurationOfEventAroundMs(profile_events[0], /*expected_ms*/ 500, eps_ms);
AssertDurationOfEventAroundMs(profile_events[1], /*expected_ms*/ 250, eps_ms);
AssertDurationOfEventAroundMs(profile_events[2], /*expected_ms*/ 250, eps_ms);
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 613a1530f7..1bf42d7551 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -155,7 +155,8 @@ def build_toco_convert_protos(input_tensors,
post_training_quantize=False,
dump_graphviz_dir=None,
dump_graphviz_video=False,
- converter_mode=ConverterMode.DEFAULT):
+ converter_mode=ConverterMode.DEFAULT,
+ allow_nonexistent_arrays=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -212,6 +213,8 @@ def build_toco_convert_protos(input_tensors,
every graph transformation. (default False)
converter_mode: Experimental flag, subject to change. ConverterMode
indicating which converter to use. (default ConverterMode.DEFAULT)
+ allow_nonexistent_arrays: Allow specifying array names that don't exist
+ or are unused in the final graph. (default False)
Returns:
model_flags, toco_flags: two protocol buffers describing the conversion
@@ -261,6 +264,9 @@ def build_toco_convert_protos(input_tensors,
for output_tensor in output_tensors:
model.output_arrays.append(tensor_name(output_tensor))
+
+ model.allow_nonexistent_arrays = allow_nonexistent_arrays
+
return model, toco
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 5700bf7892..6300552cbe 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -129,6 +129,23 @@ class Interpreter(object):
return details
+ def get_tensor_details(self):
+ """Gets tensor details for every tensor with valid tensor details.
+
+    Tensors for which the required information cannot be found are omitted
+    from the list. This includes temporary tensors without a name.
+
+ Returns:
+ A list of dictionaries containing tensor information.
+ """
+ tensor_details = []
+ for idx in range(self._interpreter.NumTensors()):
+ try:
+ tensor_details.append(self._get_tensor_details(idx))
+ except ValueError:
+ pass
+ return tensor_details
+
def get_input_details(self):
"""Gets model input details.
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 418f19a179..1e2384b6d2 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -277,13 +277,20 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
Py_RETURN_NONE;
}
+int InterpreterWrapper::NumTensors() const {
+ if (!interpreter_) {
+ return 0;
+ }
+ return interpreter_->tensors_size();
+}
+
std::string InterpreterWrapper::TensorName(int i) const {
if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
return "";
}
const TfLiteTensor* tensor = interpreter_->tensor(i);
- return tensor->name;
+ return tensor->name ? tensor->name : "";
}
PyObject* InterpreterWrapper::TensorType(int i) const {
@@ -291,6 +298,11 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
+ if (tensor->type == kTfLiteNoType) {
+ PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
+ return nullptr;
+ }
+
int code = TfLiteTypeToPyArrayType(tensor->type);
if (code == -1) {
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
@@ -302,7 +314,12 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
PyObject* InterpreterWrapper::TensorSize(int i) const {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
+
const TfLiteTensor* tensor = interpreter_->tensor(i);
+ if (tensor->dims == nullptr) {
+ PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
+ return nullptr;
+ }
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index f5ca81e62a..b98046fe8a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -59,6 +59,7 @@ class InterpreterWrapper {
PyObject* OutputIndices() const;
PyObject* ResizeInputTensor(int i, PyObject* value);
+ int NumTensors() const;
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 3da3188c3a..ff8430827c 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -248,6 +248,8 @@ union BuiltinOptions {
SquareOptions,
ZerosLikeOptions,
FillOptions,
+ BidirectionalSequenceLSTMOptions,
+ BidirectionalSequenceRNNOptions,
}
enum Padding : byte { SAME, VALID }
@@ -327,6 +329,7 @@ table SequenceRNNOptions {
table BidirectionalSequenceRNNOptions {
time_major:bool;
fused_activation_function:ActivationFunctionType;
+ merge_outputs: bool;
}
enum FullyConnectedOptionsWeightsFormat: byte {
@@ -391,6 +394,15 @@ table LSTMOptions {
kernel_type: LSTMKernelType = FULL;
}
+table BidirectionalSequenceLSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+
+ // If true, store the outputs of both directions into the first output.
+ merge_outputs: bool;
+}
+
table ResizeBilinearOptions {
new_height: int (deprecated);
new_width: int (deprecated);
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 23ac8484de..f3cb113c9c 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT;
struct LSTMOptions;
struct LSTMOptionsT;
+struct BidirectionalSequenceLSTMOptions;
+struct BidirectionalSequenceLSTMOptionsT;
+
struct ResizeBilinearOptions;
struct ResizeBilinearOptionsT;
@@ -676,11 +679,13 @@ enum BuiltinOptions {
BuiltinOptions_SquareOptions = 66,
BuiltinOptions_ZerosLikeOptions = 67,
BuiltinOptions_FillOptions = 68,
+ BuiltinOptions_BidirectionalSequenceLSTMOptions = 69,
+ BuiltinOptions_BidirectionalSequenceRNNOptions = 70,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_FillOptions
+ BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -750,7 +755,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
BuiltinOptions_FloorDivOptions,
BuiltinOptions_SquareOptions,
BuiltinOptions_ZerosLikeOptions,
- BuiltinOptions_FillOptions
+ BuiltinOptions_FillOptions,
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ BuiltinOptions_BidirectionalSequenceRNNOptions
};
return values;
}
@@ -826,6 +833,8 @@ inline const char * const *EnumNamesBuiltinOptions() {
"SquareOptions",
"ZerosLikeOptions",
"FillOptions",
+ "BidirectionalSequenceLSTMOptions",
+ "BidirectionalSequenceRNNOptions",
nullptr
};
return names;
@@ -1112,6 +1121,14 @@ template<> struct BuiltinOptionsTraits<FillOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_FillOptions;
};
+template<> struct BuiltinOptionsTraits<BidirectionalSequenceLSTMOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions;
+};
+
+template<> struct BuiltinOptionsTraits<BidirectionalSequenceRNNOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1687,6 +1704,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_FillOptions ?
reinterpret_cast<const FillOptionsT *>(value) : nullptr;
}
+ BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() {
+ return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
+ const BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() const {
+ return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<const BidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
+ BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() {
+ return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
+ reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
+ }
+ const BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() const {
+ return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
+ reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -2834,9 +2867,11 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable {
typedef BidirectionalSequenceRNNOptions TableType;
bool time_major;
ActivationFunctionType fused_activation_function;
+ bool merge_outputs;
BidirectionalSequenceRNNOptionsT()
: time_major(false),
- fused_activation_function(ActivationFunctionType_NONE) {
+ fused_activation_function(ActivationFunctionType_NONE),
+ merge_outputs(false) {
}
};
@@ -2844,7 +2879,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
typedef BidirectionalSequenceRNNOptionsT NativeTableType;
enum {
VT_TIME_MAJOR = 4,
- VT_FUSED_ACTIVATION_FUNCTION = 6
+ VT_FUSED_ACTIVATION_FUNCTION = 6,
+ VT_MERGE_OUTPUTS = 8
};
bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
@@ -2852,10 +2888,14 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ bool merge_outputs() const {
+ return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
verifier.EndTable();
}
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2872,6 +2912,9 @@ struct BidirectionalSequenceRNNOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_merge_outputs(bool merge_outputs) {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
+ }
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2887,8 +2930,10 @@ struct BidirectionalSequenceRNNOptionsBuilder {
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalSequenceRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb,
bool time_major = false,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ bool merge_outputs = false) {
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
+ builder_.add_merge_outputs(merge_outputs);
builder_.add_fused_activation_function(fused_activation_function);
builder_.add_time_major(time_major);
return builder_.Finish();
@@ -3424,6 +3469,96 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
+ typedef BidirectionalSequenceLSTMOptions TableType;
+ ActivationFunctionType fused_activation_function;
+ float cell_clip;
+ float proj_clip;
+ bool merge_outputs;
+ BidirectionalSequenceLSTMOptionsT()
+ : fused_activation_function(ActivationFunctionType_NONE),
+ cell_clip(0.0f),
+ proj_clip(0.0f),
+ merge_outputs(false) {
+ }
+};
+
+struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BidirectionalSequenceLSTMOptionsT NativeTableType;
+ enum {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8,
+ VT_MERGE_OUTPUTS = 10
+ };
+ ActivationFunctionType fused_activation_function() const {
+ return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const {
+ return GetField<float>(VT_CELL_CLIP, 0.0f);
+ }
+ float proj_clip() const {
+ return GetField<float>(VT_PROJ_CLIP, 0.0f);
+ }
+ bool merge_outputs() const {
+ return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
+ verifier.EndTable();
+ }
+ BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BidirectionalSequenceLSTMOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+ fbb_.AddElement<int8_t>(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip) {
+ fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip) {
+ fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ void add_merge_outputs(bool merge_outputs) {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
+ }
+ explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BidirectionalSequenceLSTMOptionsBuilder &operator=(const BidirectionalSequenceLSTMOptionsBuilder &);
+ flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BidirectionalSequenceLSTMOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ float cell_clip = 0.0f,
+ float proj_clip = 0.0f,
+ bool merge_outputs = false) {
+ BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_merge_outputs(merge_outputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ResizeBilinearOptionsT : public flatbuffers::NativeTable {
typedef ResizeBilinearOptions TableType;
bool align_corners;
@@ -6347,6 +6482,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FillOptions *builtin_options_as_FillOptions() const {
return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast<const FillOptions *>(builtin_options()) : nullptr;
}
+ const BidirectionalSequenceLSTMOptions *builtin_options_as_BidirectionalSequenceLSTMOptions() const {
+ return builtin_options_type() == BuiltinOptions_BidirectionalSequenceLSTMOptions ? static_cast<const BidirectionalSequenceLSTMOptions *>(builtin_options()) : nullptr;
+ }
+ const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const {
+ return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast<const BidirectionalSequenceRNNOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6650,6 +6791,14 @@ template<> inline const FillOptions *Operator::builtin_options_as<FillOptions>()
return builtin_options_as_FillOptions();
}
+template<> inline const BidirectionalSequenceLSTMOptions *Operator::builtin_options_as<BidirectionalSequenceLSTMOptions>() const {
+ return builtin_options_as_BidirectionalSequenceLSTMOptions();
+}
+
+template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_options_as<BidirectionalSequenceRNNOptions>() const {
+ return builtin_options_as_BidirectionalSequenceRNNOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7407,6 +7556,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp
(void)_resolver;
{ auto _e = time_major(); _o->time_major = _e; };
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = merge_outputs(); _o->merge_outputs = _e; };
}
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7419,10 +7569,12 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _time_major = _o->time_major;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _merge_outputs = _o->merge_outputs;
return tflite::CreateBidirectionalSequenceRNNOptions(
_fbb,
_time_major,
- _fused_activation_function);
+ _fused_activation_function,
+ _merge_outputs);
}
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -7657,6 +7809,41 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
_kernel_type);
}
+inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new BidirectionalSequenceLSTMOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = cell_clip(); _o->cell_clip = _e; };
+ { auto _e = proj_clip(); _o->proj_clip = _e; };
+ { auto _e = merge_outputs(); _o->merge_outputs = _e; };
+}
+
+inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateBidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _fused_activation_function = _o->fused_activation_function;
+ auto _cell_clip = _o->cell_clip;
+ auto _proj_clip = _o->proj_clip;
+ auto _merge_outputs = _o->merge_outputs;
+ return tflite::CreateBidirectionalSequenceLSTMOptions(
+ _fbb,
+ _fused_activation_function,
+ _cell_clip,
+ _proj_clip,
+ _merge_outputs);
+}
+
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ResizeBilinearOptionsT();
UnPackTo(_o, _resolver);
@@ -9425,6 +9612,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const FillOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9715,6 +9910,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const FillOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9993,6 +10196,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const FillOptionsT *>(value);
return CreateFillOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptionsT *>(value);
+ return CreateBidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value);
+ return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -10271,6 +10482,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new FillOptionsT(*reinterpret_cast<FillOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ value = new BidirectionalSequenceLSTMOptionsT(*reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10618,6 +10837,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testdata/multi_add_flex.bin b/tensorflow/contrib/lite/testdata/multi_add_flex.bin
new file mode 100644
index 0000000000..9aac2155fe
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add_flex.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 18036fac6f..3f2255c454 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -762,8 +762,11 @@ def make_constant_tests(zip_path):
dtype=parameters["dtype"],
name="input1",
shape=parameters["input_shape"])
- out = tf.constant(
+ constant = tf.constant(
create_tensor_data(parameters["dtype"], parameters["input_shape"]))
+ # This maximum node is here to avoid the situation where a graph output is
+ # a constant, which is an error in toco.
+ out = tf.maximum(dummy_input, constant)
return [dummy_input], [out]
def build_inputs(parameters, sess, inputs, outputs):
@@ -2848,7 +2851,14 @@ def make_zeros_like_tests(zip_path):
dtype=parameters["input_dtype"],
name="input",
shape=parameters["input_shape"])
- out = tf.zeros_like(input_tensor)
+ zeros = tf.zeros_like(input_tensor)
+ # This maximum node is so that toco can perform the constants-propagation
+ # through the above zeros_like, which it can't do if the output of the
+ # zeros_like as an output of the whole graphs (graph outputs can't be
+ # constants). If toco does not perform such constants-propagation then
+ # the resulting tflite graph retains the zeros_like as a Fill op, which
+ # is unsupported by TFLite, even as a custom op.
+ out = tf.maximum(zeros, input_tensor)
return [input_tensor], [out]
def build_inputs(parameters, sess, inputs, outputs):
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
index 5ca57d083d..72029ed03c 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -35,9 +35,9 @@ def _convert(converter, **kwargs):
"""Converts the model.
Args:
- converter: TocoConverter object.
+ converter: TFLiteConverter object.
**kwargs: Additional arguments to be passed into the converter. Supported
- flags are {"converter_mode", "post_training_quant"}.
+ flags are {"converter_mode", "post_training_quantize"}.
Returns:
The converted TFLite model in serialized format.
@@ -174,7 +174,7 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
tflite_model: Serialized TensorFlow Lite model.
tf_eval_func: Lambda function that takes in input data and outputs the
results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
- tolerance: Decimal place to check accuracy to.
+ tolerance: Decimal place to check accuracy to. (default 5)
"""
input_data = _generate_random_input_data(tflite_model)
tf_results = tf_eval_func(input_data)
@@ -183,6 +183,71 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+def test_frozen_graph_quant(filename,
+ input_arrays,
+ output_arrays,
+ input_shapes=None,
+ **kwargs):
+ """Sanity check to validate post quantize flag alters the graph.
+
+ This test does not check correctness of the converted model. It converts the
+ TensorFlow frozen graph to TFLite with and without the post_training_quantized
+ flag. It ensures some tensors have different types between the float and
+ quantized models in the case of an all TFLite model or mix-and-match model.
+ It ensures tensor types do not change in the case of an all Flex model.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" : None}).
+ (default None)
+ **kwargs: Additional arguments to be passed into the converter.
+
+ Raises:
+ ValueError: post_training_quantize flag doesn't act as intended.
+ """
+ # Convert and load the float model.
+ converter = _lite.TFLiteConverter.from_frozen_graph(
+ filename, input_arrays, output_arrays, input_shapes)
+ tflite_model_float = _convert(converter, **kwargs)
+
+ interpreter_float = _lite.Interpreter(model_content=tflite_model_float)
+ interpreter_float.allocate_tensors()
+ float_tensors = interpreter_float.get_tensor_details()
+
+ # Convert and load the quantized model.
+ converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays,
+ output_arrays)
+ tflite_model_quant = _convert(
+ converter, post_training_quantize=True, **kwargs)
+
+ interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant)
+ interpreter_quant.allocate_tensors()
+ quant_tensors = interpreter_quant.get_tensor_details()
+ quant_tensors_map = {
+ tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
+ }
+
+ # Check if weights are of different types in the float and quantized models.
+ num_tensors_float = len(float_tensors)
+ num_tensors_same_dtypes = sum(
+ float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"]
+ for float_tensor in float_tensors)
+ has_quant_tensor = num_tensors_float != num_tensors_same_dtypes
+
+ if ("converter_mode" in kwargs and
+ kwargs["converter_mode"] == _lite.ConverterMode.TOCO_FLEX_ALL):
+ if has_quant_tensor:
+ raise ValueError("--post_training_quantize flag unexpectedly altered the "
+ "full Flex mode graph.")
+ elif not has_quant_tensor:
+ raise ValueError("--post_training_quantize flag was unable to quantize the "
+ "graph as expected in TFLite and mix-and-match mode.")
+
+
def test_frozen_graph(filename,
input_arrays,
output_arrays,
@@ -203,8 +268,8 @@ def test_frozen_graph(filename,
(default None)
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
- output_arrays, input_shapes)
+ converter = _lite.TFLiteConverter.from_frozen_graph(
+ filename, input_arrays, output_arrays, input_shapes)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)
@@ -224,8 +289,8 @@ def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs):
signature_key: Key identifying SignatureDef containing inputs and outputs.
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TocoConverter.from_saved_model(directory, tag_set,
- signature_key)
+ converter = _lite.TFLiteConverter.from_saved_model(directory, tag_set,
+ signature_key)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key)
@@ -242,7 +307,7 @@ def test_keras_model(filename, **kwargs):
filename: Full filepath of HDF5 file containing the tf.keras model.
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TocoConverter.from_keras_model_file(filename)
+ converter = _lite.TFLiteConverter.from_keras_model_file(filename)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_keras_model(filename)
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
index 1498f86c6f..e07202b1a6 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
import tempfile
+import numpy as np
from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage
@@ -66,6 +67,43 @@ class EvaluateFrozenGraph(test.TestCase):
model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
['add', 'Mean'])
+ def _getQuantizedModel(self):
+ np.random.seed(0)
+ with session.Session().as_default() as sess:
+ # The tensor needs to have more than 1024 elements for quantize_weights to
+ # kick in. Thus, the [33, 33] shape.
+ in_tensor_1 = array_ops.placeholder(
+ shape=[33, 33], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = constant_op.constant(
+ np.random.uniform(low=-10., high=10., size=(33, 33)),
+ shape=[33, 33],
+ dtype=dtypes.float32,
+ name='inputB')
+ _ = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+
+ filename = self._saveFrozenGraph(sess)
+ return filename
+
+ def testQuantized(self):
+ filename = self._getQuantizedModel()
+ model_coverage.test_frozen_graph_quant(filename, ['inputA', 'inputB'],
+ ['output'])
+
+ def testQuantizedInputShapes(self):
+ filename = self._getQuantizedModel()
+ model_coverage.test_frozen_graph_quant(
+ filename, ['inputA', 'inputB'], ['output'],
+ input_shapes={
+ 'inputA': [33, 33],
+ 'inputB': [33, 33],
+ })
+
+ def testQuantizedFlexAll(self):
+ filename = self._getQuantizedModel()
+ model_coverage.test_frozen_graph_quant(
+ filename, ['inputA', 'inputB'], ['output'],
+ converter_mode=lite.ConverterMode.TOCO_FLEX_ALL)
+
class EvaluateSavedModel(test.TestCase):
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 1bc366f555..fb299c31b7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -97,15 +97,6 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
// to allow easily trying out quantization even if the graph
// lacks some minmax information.
if (array.buffer != nullptr) {
- LOG(WARNING)
- << "Constant array " << array_name
- << " lacks MinMax information. To make up for that, we will now compute"
- << " the MinMax from actual array elements. That will result in"
- << " quantization parameters that probably do not match whichever "
- "arithmetic"
- << " was used during training, and thus will probably be a cause of "
- "poor"
- << " inference accuracy.";
CHECK(array.buffer->type == ArrayDataType::kFloat);
const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
// We always want [min, max] to contain 0.
@@ -120,6 +111,27 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
// to not be equal.
max = 1.f;
}
+ // No need to warn about accuracy if all array values are equal to either
+ // min or max:
+ // in that case, quantization is exact, and such arrays are not learned
+ // weights arrays for which fake-quantization would make sense, rather
+ // they tend to be hardcoded arrays of zeros or ones used in some graphs.
+ bool is_quantization_trivially_exact = true;
+ for (auto val : data) {
+ is_quantization_trivially_exact &= (val == min || val == max);
+ }
+ if (!is_quantization_trivially_exact) {
+ LOG(WARNING)
+ << "Constant array " << array_name
+ << " lacks MinMax information. To make up for that, we will now "
+ "compute"
+ << " the MinMax from actual array elements. That will result in"
+ << " quantization parameters that probably do not match whichever "
+ "arithmetic"
+ << " was used during training, and thus will probably be a cause of "
+ "poor"
+ << " inference accuracy.";
+ }
auto& minmax = array.GetOrCreateMinMax();
minmax.min = min;
minmax.max = max;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
index 5b41c49bfa..eaa9d3bcda 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -71,8 +71,10 @@ bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
CHECK(fq_op->minmax);
CHECK_EQ(1, fq_op->inputs.size());
- return ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]) ||
- ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
+ bool changed = false;
+ changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]);
+ changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
+ return changed;
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
index fc49fbda59..d5983a1f12 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -29,20 +29,34 @@ namespace {
// array instead. from_array is assumed to be discardable, and consequently
// this only updates operator edges (since discardable arrays only
// appear there, and not e.g. in model flags).
-void RerouteEdges(const string& from_array, const string& to_array,
- Model* model) {
+void Reroute(const string& from, const string& to, Model* model) {
for (const auto& op : model->operators) {
for (auto& output : op->outputs) {
- if (output == from_array) {
- output = to_array;
+ if (output == from) {
+ output = to;
}
}
for (auto& input : op->inputs) {
- if (input == from_array) {
- input = to_array;
+ if (input == from) {
+ input = to;
}
}
}
+ const Array& from_array = model->GetArray(from);
+ Array& to_array = model->GetOrCreateArray(to);
+ // Preserve minmax information if to_array didn't already have any.
+ if (from_array.minmax && !to_array.minmax) {
+ to_array.GetOrCreateMinMax() = from_array.GetMinMax();
+ // If we're copying minmax info, then we should also be copying
+ // narrow_range, which affects how minmax info is to be interpreted.
+ to_array.narrow_range = from_array.narrow_range;
+ }
+ // Separately, also preserve final_data_type if to_array didn't already
+ // have any.
+ if (from_array.final_data_type != ArrayDataType::kNone &&
+ to_array.final_data_type == ArrayDataType::kNone) {
+ to_array.final_data_type = from_array.final_data_type;
+ }
}
} // namespace
@@ -90,14 +104,14 @@ bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
transformation->AddMessageF(
"Removing %s, keeping its non-constant input array %s and removing %s",
LogName(*passthru_op), main_input_name, output_name);
- RerouteEdges(output_name, main_input_name, model);
+ Reroute(output_name, main_input_name, model);
} else if (IsDiscardableArray(*model, main_input_name) &&
!IsConstantParameterArray(*model, main_input_name)) {
transformation->AddMessageF(
"Removing %s, keeping its output array %s and removing non-constant "
"input %s",
LogName(*passthru_op), output_name, main_input_name);
- RerouteEdges(main_input_name, output_name, model);
+ Reroute(main_input_name, output_name, model);
} else {
transformation->AddMessageF(
"Cannot remove %s, neither its main input nor its output may be "
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
index 4bb1217828..b2b2ea151b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
@@ -60,6 +60,10 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
+ if (!IsDiscardableArray(*model, output_array_name)) {
+ return false;
+ }
+
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
@@ -139,14 +143,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
}
// Erase input arrays to the multiply if no longer used
- if (IsDiscardableArray(*model, mul_op->inputs[0]) &&
- CountOpsWithInput(*model, mul_op->inputs[0]) == 1) {
- model->EraseArray(mul_op->inputs[0]);
- }
- if (IsDiscardableArray(*model, mul_op->inputs[1]) &&
- CountOpsWithInput(*model, mul_op->inputs[1]) == 1) {
- model->EraseArray(mul_op->inputs[1]);
- }
+ DeleteArrayIfUsedOnce(mul_op->inputs[0], model);
+ DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
// Erase the multiply operator.
model->operators.erase(mul_it);
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5eaf6e27fc..133ef79a34 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -477,6 +477,30 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
+// Retain TensorFlow NodeDef in Toco Operator.
+//
+// If an op is supported by Toco but not supported by TFLite, TFLite exporter
+// will use the retained NodeDef to populate a Flex op when Flex mode is
+// enabled.
+//
+// This can't be easily applied to all operations, because a TensorFlow node
+// may become multiple Toco operators. Thus we need to call this function in
+// operator conversion functions one by one whenever feasible.
+//
+// This may cause problems if a graph transformation rule changes parameters
+// of the node. When calling this function, please check if any existing
+// graph transformation rule will change an existing operator with the same
+// type.
+//
+// This provides a route to handle Toco-supported & TFLite-unsupported ops
+// in Flex mode. However it's not a solid solution. Eventually we should
+// get rid of this.
+// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
+// this function.
+void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
+ node.SerializeToString(&op->tensorflow_node_def);
+}
+
tensorflow::Status ConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -990,6 +1014,10 @@ tensorflow::Status ConvertBatchMatMulOperator(
auto* batch_matmul = new BatchMatMulOperator;
batch_matmul->inputs = {node.input(0), node.input(1)};
batch_matmul->outputs = {node.name()};
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, batch_matmul);
+
model->operators.emplace_back(batch_matmul);
return tensorflow::Status::OK();
}
@@ -1081,7 +1109,10 @@ tensorflow::Status ConvertUnsupportedOperator(
auto* op = new TensorFlowUnsupportedOperator;
op->tensorflow_op = node.op();
- node.SerializeToString(&op->tensorflow_node_def);
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
// Parse inputs.
@@ -1605,6 +1636,10 @@ tensorflow::Status ConvertRangeOperator(
op->inputs.push_back(node.input(1));
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 6e207fdf54..61f1f095e9 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -376,6 +376,13 @@ struct Operator {
// looks unused.
bool unresolved_outputs = false;
+ // A serialized tensorflow::NodeDef string.
+ // The field is filled only when importing from TensorFlow.
+ // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`.
+ // It's not guaranteed to be filled for other ops. Ops created by graph
+ // transformations won't have TensorFlow NodeDef.
+ string tensorflow_node_def;
+
protected:
// Constructor used by subclasses for specific OperatorType's.
explicit Operator(OperatorType t)
@@ -1535,8 +1542,6 @@ struct TensorFlowUnsupportedOperator : Operator {
// The original TF operation type. Used for diagnostic purposes.
string tensorflow_op;
- // A serialized tensorflow::NodeDef string.
- string tensorflow_node_def;
// A boolean indicating if the unsupported op should be treated as quantized.
bool quantized = false;
// A boolean indicating if the unsupported op output should allow float values
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index d34da63e43..b6a401aaf2 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -394,12 +394,18 @@ void ReadModelFlagsFromCommandLineFlags(
}
}
- model_flags->set_allow_nonascii_arrays(
- parsed_model_flags.allow_nonascii_arrays.value());
- model_flags->set_allow_nonexistent_arrays(
- parsed_model_flags.allow_nonexistent_arrays.value());
- model_flags->set_change_concat_input_ranges(
- parsed_model_flags.change_concat_input_ranges.value());
+ if (!model_flags->has_allow_nonascii_arrays()) {
+ model_flags->set_allow_nonascii_arrays(
+ parsed_model_flags.allow_nonascii_arrays.value());
+ }
+ if (!model_flags->has_allow_nonexistent_arrays()) {
+ model_flags->set_allow_nonexistent_arrays(
+ parsed_model_flags.allow_nonexistent_arrays.value());
+ }
+ if (!model_flags->has_change_concat_input_ranges()) {
+ model_flags->set_change_concat_input_ranges(
+ parsed_model_flags.change_concat_input_ranges.value());
+ }
if (parsed_model_flags.arrays_extra_info_file.specified()) {
string arrays_extra_info_file_contents;
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 0c9fac249c..3b34cd6285 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -47,29 +47,37 @@ using ::tflite::Tensor;
namespace {
-details::OperatorKey GetOperatorKey(
- const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_flex_ops) {
- string custom_code;
- if (op.type == OperatorType::kUnsupported) {
- const TensorFlowUnsupportedOperator& unsupported_op =
- static_cast<const TensorFlowUnsupportedOperator&>(op);
+// Check if a TensorFlow Op is a control flow op by its name.
+bool IsControlFlowOp(const string& tensorflow_op) {
+ // Technically this is equivalent to `::tensorflow::Node::IsControlFlow()`.
+ // It requires to construct a `::tensorflow::Graph` to use that helper
+ // function, so we simply hardcode the list of control flow ops here.
+ if (tensorflow_op == "Switch" || tensorflow_op == "RefSwitch" ||
+ tensorflow_op == "Merge" || tensorflow_op == "RefMerge" ||
+ tensorflow_op == "Enter" || tensorflow_op == "RefEnter" ||
+ tensorflow_op == "Exit" || tensorflow_op == "RefExit" ||
+ tensorflow_op == "NextIteration" || tensorflow_op == "RefNextIteration") {
+ return true;
+ }
+ // TODO(ycling): Also check how to handle Variable ops and Assign ops.
+ return false;
+}
- // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
- // to populate a regular custom op. We need to find a way to fix this.
- if (allow_flex_ops) {
- custom_code = string(::tflite::kFlexCustomCodePrefix) +
- unsupported_op.tensorflow_op;
- } else {
- custom_code = unsupported_op.tensorflow_op;
+// Map from operator name to TF Lite enum value, for all builtins.
+const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() {
+ static std::map<string, BuiltinOperator>* builtin_ops = nullptr;
+ if (builtin_ops == nullptr) {
+ builtin_ops = new std::map<string, BuiltinOperator>();
+
+ for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
+ BuiltinOperator op = static_cast<BuiltinOperator>(i);
+ string name = EnumNameBuiltinOperator(op);
+ if (op != BuiltinOperator_CUSTOM && !name.empty()) {
+ (*builtin_ops)[name] = op;
+ }
}
}
- int version = 1;
- if (ops_by_type.count(op.type) != 0) {
- version = ops_by_type.at(op.type)->GetVersion(op);
- }
- return details::OperatorKey(op.type, custom_code, version);
+ return *builtin_ops;
}
void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
@@ -83,6 +91,72 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
namespace details {
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops) {
+ // Get the op name (by Toco definition).
+ string name = HelpfulOperatorTypeName(op);
+
+ bool is_builtin = false;
+ OperatorKey key;
+
+ const auto& builtin_ops = GetBuiltinOpsMap();
+ if (ops_by_type.count(op.type) != 0) {
+ key.version = ops_by_type.at(op.type)->GetVersion(op);
+ name = ops_by_type.at(op.type)->name();
+ is_builtin = (builtin_ops.count(name) > 0);
+ }
+
+ if (is_builtin) {
+ // For TFLite supported builtin ops, find out its BuiltinOperator enum used
+ // in FlatBuffer.
+ key.type = builtin_ops.at(name);
+ return key;
+ }
+
+ // The logic below is all for custom ops.
+ key.is_custom_op = true;
+ key.type = BuiltinOperator_CUSTOM;
+
+ if (op.type == OperatorType::kUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ const auto tensorflow_op = unsupported_op.tensorflow_op;
+
+ // TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
+ // to populate a regular custom op. We need to find a way to fix this.
+ if (allow_flex_ops) {
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = tensorflow_op;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
+ } else {
+ key.custom_code = tensorflow_op;
+ }
+ } else if (allow_flex_ops && !op.tensorflow_node_def.empty()) {
+ // For Toco-supported/TFLite-unsupported ops, if the TensorFlow NodeDef
+ // is retained in the Toco Operator, we produce a Flex op if Flex mode
+ // is enabled.
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = name;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
+ } else {
+ // If Flex is disabled or the original TensorFlow NodeDef isn't available,
+ // we produce a custom op. This gives developers a chance to implement
+ // custom ops.
+ key.custom_code = name;
+ }
+
+ if (key.is_flex_op) {
+ if (IsControlFlowOp(key.flex_tensorflow_op)) {
+ key.is_unsupported_flex_op = true;
+ }
+ }
+ return key;
+}
+
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
// First find a list of unique array names.
std::set<string> names;
@@ -114,6 +188,7 @@ void LoadOperatorsMap(
++index;
}
}
+
} // namespace details
Offset<Vector<Offset<Tensor>>> ExportTensors(
@@ -199,7 +274,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* error_summary, const ExportParams& params) {
+ const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -216,37 +291,16 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
int op_index = operators_map.at(operator_key);
- int op_version = operator_key.version;
- string name = HelpfulOperatorTypeName(*op);
- bool is_builtin = false;
- if (ops_by_type.count(op->type) != 0) {
- name = ops_by_type.at(op->type)->name();
- is_builtin = (builtin_ops.count(name) > 0);
+ flatbuffers::Offset<flatbuffers::String> custom_code = 0;
+ if (!operator_key.custom_code.empty()) {
+ custom_code = builder->CreateString(operator_key.custom_code);
}
- if (is_builtin) {
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
- } else {
- // This could be a kUnsupported, in which case we should be
- // able to retrieve the original Tensorflow name from the OperatorKey, or
- // this could be a proper TOCO operator that is completely unknown to TF
- // Lite.
- if (!operator_key.custom_code.empty()) {
- name = operator_key.custom_code;
- }
- // Either way, this is an operator that is not supported by TF Lite,
- // so we output it as a custom op and add it to the error summary.
- if (error_summary) {
- error_summary->insert(name);
- }
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
- builder->CreateString(name), op_version);
- }
+ ordered_opcodes[op_index] = CreateOperatorCode(
+ *builder, operator_key.type, custom_code, operator_key.version);
}
std::vector<Offset<OperatorCode>> opcode_vector;
@@ -280,8 +334,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
+ const auto key =
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ int op_index = operators_map.at(key);
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -306,6 +361,11 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
variable_tensor_indices->insert(variable_tensor_index);
}
}
+ } else if (key.is_flex_op && !op->tensorflow_node_def.empty()) {
+ auto fbb = WriteFlexOpOptions(op->tensorflow_node_def);
+ if (fbb) {
+ options = Options::Custom(builder->CreateVector(fbb->GetBuffer()));
+ }
}
// The only supported CustomOptionFormat is FLEXBUFFERS now.
op_vector.push_back(CreateOperator(
@@ -355,9 +415,8 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- std::set<string> error_summary;
- auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &error_summary, params);
+ auto op_codes =
+ ExportOperatorCodes(model, ops_by_type, operators_map, &builder, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -367,30 +426,66 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!params.allow_custom_ops && !error_summary.empty()) {
- // Remove ExpandDims and ReorderAxes from unimplemented list unless they
- // compose the list. Both ops are removed during graph transformations.
- // However, if an op is unimplemented earlier in the model, the graph
- // transformation is unable to run because the output shape is not defined.
- // This causes unnecessary confusion during model conversion time.
- std::set<string> error_summary_final;
- for (const auto& op_type : error_summary) {
- if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
- error_summary_final.insert(op_type);
+
+ std::set<string> custom_ops;
+ std::set<string> unsupported_flex_ops;
+ for (const auto& it : operators_map) {
+ const details::OperatorKey& key = it.first;
+ if (key.is_custom_op) {
+ custom_ops.insert(key.custom_code);
+ }
+ if (key.is_unsupported_flex_op) {
+ unsupported_flex_ops.insert(key.flex_tensorflow_op);
+ }
+ }
+
+ if (!custom_ops.empty()) {
+ if (!params.allow_custom_ops) {
+ // Remove ExpandDims and ReorderAxes from unimplemented list unless they
+ // compose the list. Both ops are removed during graph transformations.
+ // However, if an op is unimplemented earlier in the model, the graph
+ // transformation is unable to run because the output shape is not
+ // defined. This causes unnecessary confusion during model conversion
+ // time.
+ std::set<string> custom_ops_final;
+ for (const auto& op_type : custom_ops) {
+ if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
+ custom_ops_final.insert(op_type);
+ }
+ }
+ if (custom_ops_final.empty()) {
+ custom_ops_final = custom_ops;
}
+
+ LOG(QFATAL)
+ << "Some of the operators in the model are not supported by "
+ "the standard TensorFlow Lite runtime. If you have a custom "
+ "implementation for them you can disable this error with "
+ "--allow_custom_ops, or by setting allow_custom_ops=True "
+ "when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
+ "of operators for which you will need custom implementations: "
+ << absl::StrJoin(custom_ops_final, ", ") << ".";
}
- if (error_summary_final.empty()) {
- error_summary_final = error_summary;
+
+ std::set<string> unsupported_control_flow_ops;
+ // Check if unsupported ops contains control flow ops. It's impossible
+ // to implement these ops as custom ops at the moment.
+ for (const auto& op : custom_ops) {
+ if (IsControlFlowOp(op)) {
+ unsupported_control_flow_ops.insert(op);
+ }
+ }
+ if (!unsupported_control_flow_ops.empty()) {
+ LOG(QFATAL)
+ << "TensorFlow Lite currently doesn't support control flow ops: "
+ << absl::StrJoin(unsupported_control_flow_ops, ", ") << ".";
}
+ }
- LOG(QFATAL)
- << "Some of the operators in the model are not supported by "
- "the standard TensorFlow Lite runtime. If you have a custom "
- "implementation for them you can disable this error with "
- "--allow_custom_ops, or by setting allow_custom_ops=True "
- "when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
- "of operators for which you will need custom implementations: "
- << absl::StrJoin(error_summary_final, ", ") << ".";
+ if (!unsupported_flex_ops.empty()) {
+ LOG(QFATAL) << "Some of the operators in the model are not supported by "
+ "TensorFlow Flex runtime: "
+ << absl::StrJoin(unsupported_flex_ops, ", ") << ".";
}
std::set<int32_t> variable_tensor_indices;
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 29d6de4049..c627f48086 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -81,11 +81,25 @@ using TensorsMap = std::unordered_map<string, int>;
// Only when `type` is `kUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
- OperatorKey(OperatorType type, const std::string& custom_code, int version)
+ OperatorKey() {}
+ OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
+ int version)
: type(type), custom_code(custom_code), version(version) {}
- const OperatorType type;
- const std::string custom_code;
- const int version;
+
+ // Only `type`, `custom_code` and `version` is used to compute hash and
+ // identity.
+ ::tflite::BuiltinOperator type = ::tflite::BuiltinOperator_CUSTOM;
+ std::string custom_code;
+ int version = 1;
+
+ // The fields below are not used to compute hash and identity.
+ // TODO(ycling): Consider to change these fields to accessor functions.
+ bool is_custom_op = false;
+ bool is_flex_op = false;
+ bool is_unsupported_flex_op = false;
+ // The original TensorFlow op name for the flex op. Filled only when
+ // `is_flex_op` is true.
+ std::string flex_tensorflow_op;
bool operator<(const OperatorKey& other) const {
if (type < other.type) return true;
@@ -114,6 +128,11 @@ struct OperatorKey {
};
};
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops);
+
// A maps from operator type to its final position in the TF Lite buffer.
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index 93882a91a7..eda1aa78a3 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/core/framework/node_def.pb.h"
namespace toco {
namespace tflite {
@@ -105,13 +106,15 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- // TODO(ycling): Add a test for allow_flex_ops.
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
- EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
- EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
- EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
- EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported,
+ EXPECT_EQ(
+ 0, operators[details::OperatorKey(::tflite::BuiltinOperator_ADD, "", 1)]);
+ EXPECT_EQ(1, operators[details::OperatorKey(::tflite::BuiltinOperator_CONV_2D,
+ "", 1)]);
+ EXPECT_EQ(2, operators[details::OperatorKey(::tflite::BuiltinOperator_CUSTOM,
"MyCrazyOp", 1)]);
+ EXPECT_EQ(
+ 3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
}
TEST_F(ExportTest, Export) {
@@ -133,7 +136,7 @@ TEST_F(ExportTest, Export) {
}
EXPECT_THAT(names, ElementsAre("builtin:ADD", "builtin:CONV_2D",
- "builtin:SUB", "custom:MyCrazyOp"));
+ "custom:MyCrazyOp", "builtin:SUB"));
std::vector<uint32_t> indices;
auto operators = (*model->subgraphs())[0]->operators();
@@ -142,7 +145,7 @@ TEST_F(ExportTest, Export) {
indices.push_back(op->opcode_index());
}
- EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
+ EXPECT_THAT(indices, ElementsAre(1, 0, 2, 3));
}
TEST_F(ExportTest, QuantizeWeights) {
@@ -257,7 +260,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 1)));
}
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
@@ -268,7 +272,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
@@ -280,8 +285,10 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
- EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 1)));
+ EXPECT_EQ(1, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}
TEST_F(VersionedOpExportTest, Export) {
@@ -313,6 +320,102 @@ TEST_F(VersionedOpExportTest, Export) {
EXPECT_EQ(1, (*operators)[1]->opcode_index());
}
+TEST(OperatorKeyTest, TestBuiltinOp) {
+ auto op = absl::make_unique<ConvOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CONV_2D);
+ EXPECT_EQ(key.custom_code, "");
+ EXPECT_EQ(key.version, 1);
+}
+
+TEST(OperatorKeyTest, TestCustomOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "MyCrazyCustomOp";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "MyCrazyCustomOp");
+ EXPECT_EQ(key.version, 1);
+}
+
+TEST(OperatorKeyTest, TestFlexOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "BatchMatMul";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ {
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+ // It shouldn't be converted to Flex op if `allow_flex_op` is false.
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "BatchMatMul");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_FALSE(key.is_flex_op);
+ }
+
+ {
+ // Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
+ // is true.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexBatchMatMul");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ }
+}
+
+TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "Merge";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexMerge");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ // The control flow ops should be marked as unsupported.
+ EXPECT_TRUE(key.is_unsupported_flex_op);
+}
+
+TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
+ // Test Toco-supported/TFLite-unsupported operators.
+ // TODO(ycling): The test will be broken if Range is implemented in TFLite.
+ // Find a more robust way to test the fallback logic.
+ auto op = absl::make_unique<RangeOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+
+ {
+ // If NodeDef isn't retained in the Toco op, a regular custom op
+ // will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "Range");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_FALSE(key.is_flex_op);
+ }
+
+ ::tensorflow::NodeDef node_def;
+ node_def.set_name("Range");
+ node_def.set_op("Range");
+ node_def.SerializeToString(&op->tensorflow_node_def);
+
+ {
+ // If NodeDef is retained in the Toco op, a Flex op will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexRange");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ }
+}
+
// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
} // namespace
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 9addbb81e7..ed37535fe0 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1157,6 +1157,25 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def) {
+ auto fbb = absl::make_unique<flexbuffers::Builder>();
+
+ ::tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(tensorflow_node_def)) {
+ LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
+ return {};
+ }
+
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing flex op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+}
+
class TensorFlowUnsupported : public BaseOperator {
public:
TensorFlowUnsupported(const string& name, OperatorType type,
@@ -1192,6 +1211,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<flexbuffers::Builder> WriteOptions(
const TensorFlowUnsupportedOperator& op) const {
+ if (allow_flex_ops_) {
+ return WriteFlexOpOptions(op.tensorflow_node_def);
+ }
auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def;
@@ -1200,16 +1222,6 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
- if (allow_flex_ops_) {
- fbb->Vector([&]() {
- fbb->String(node_def.op());
- fbb->String(op.tensorflow_node_def);
- });
- fbb->Finish();
- LOG(INFO) << "Writing flex op: " << node_def.op();
- return std::unique_ptr<flexbuffers::Builder>(fbb.release());
- }
-
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 13d9f6c49a..6e4e0a16d1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -36,6 +37,11 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool allow_flex_ops = false);
+// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
+// for a Flex op.
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def);
+
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
using BuiltinOptions = void;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 4a1ae35cb5..e3f27e9e2a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -843,24 +843,43 @@ void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
}
void CheckNonExistentIOArrays(const Model& model) {
+ // "non-existent" is interpreted in the stronger sense of
+ // "not actually produced/consumed by an op".
+ // Rationale: we have to artificially fix up TensorFlow graphs by creating
+ // any array that it refers to, so just checking that arrays exist isn't
+ // sufficient. The real invariant here is whether arrays are produced/consumed
+ // by something.
if (model.flags.allow_nonexistent_arrays()) {
return;
}
+ static constexpr char general_comment[] =
+ "Is it a typo? To silence this message, pass this flag: "
+ "allow_nonexistent_arrays";
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.HasArray(input_array.name()))
- << "Input array not found: " << input_array.name();
+ QCHECK(GetOpWithInput(model, input_array.name()))
+ << "Specified input array \"" << input_array.name()
+ << "\" is not consumed by any op in this graph. " << general_comment;
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.HasArray(output_array))
- << "Output array not found: " << output_array;
+ QCHECK(GetOpWithOutput(model, output_array))
+ << "Specified output array \"" << output_array
+ << "\" is not produced by any op in this graph. " << general_comment;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.HasArray(rnn_state.state_array()));
- CHECK(model.HasArray(rnn_state.back_edge_source_array()));
+ // Check that all RNN states are consumed
+ QCHECK(GetOpWithInput(model, rnn_state.state_array()))
+ << "Specified RNN state \"" << rnn_state.state_array()
+ << "\" is not consumed by any op in this graph. " << general_comment;
+ // Check that all RNN back-edge source arrays are produced
+ QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
+ << "Specified RNN back-edge source array \""
+ << rnn_state.back_edge_source_array()
+ << "\" is not produced by any op in this graph. " << general_comment;
}
}
}
+
} // namespace
void CheckNoMissingArray(const Model& model) {
@@ -1597,6 +1616,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
input_array.GetOrCreateMinMax() = input_minmax;
}
}
+
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 502e181139..71bf61657e 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -40,7 +40,7 @@ cc_binary(
srcs = [
"benchmark_main.cc",
],
- copts = common_copts + ["-DTFLITE_FLEX"],
+ copts = common_copts,
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
@@ -49,8 +49,9 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
- ":benchmark_tflite_model_plus_flex_lib",
+ ":benchmark_tflite_model_lib",
":logging",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
],
)
@@ -111,25 +112,6 @@ cc_library(
)
cc_library(
- name = "benchmark_tflite_model_plus_flex_lib",
- srcs = [
- "benchmark_tflite_model.cc",
- "logging.h",
- ],
- hdrs = ["benchmark_tflite_model.h"],
- copts = common_copts + ["-DTFLITE_FLEX"],
- deps = [
- ":benchmark_model_lib",
- ":logging",
- "//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/delegates/flex:delegate",
- "//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/profiling:profile_summarizer",
- ],
-)
-
-cc_library(
name = "benchmark_params",
srcs = [
"benchmark_params.cc",
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 463d5993f4..2a3df7f289 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,9 +23,6 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#ifdef TFLITE_FLEX
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -305,15 +302,6 @@ void BenchmarkTfLiteModel::Init() {
interpreter->UseNNAPI(use_nnapi);
-#ifdef TFLITE_FLEX
- TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
- delegate_ = FlexDelegate::Create();
- if (delegate_) {
- interpreter->ModifyGraphWithDelegate(delegate_.get(),
- /*allow_dynamic_tensors=*/true);
- }
-#endif // TFLITE_FLEX
-
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index b091e18a29..25a302b2aa 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -20,9 +20,6 @@ limitations under the License.
#include <string>
#include <vector>
-#ifdef TFLITE_FLEX
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -73,9 +70,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
void PrepareInputsAndOutputs() override;
private:
-#ifdef TFLITE_FLEX
- std::unique_ptr<FlexDelegate> delegate_;
-#endif // TFLITE_FLEX
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 89b538d1ba..9e9345e875 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -23,8 +23,8 @@ import numpy as np
import six
from tensorflow.contrib import lookup
-from tensorflow.contrib.data.python.ops import counter
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 15d95896d9..b313024e28 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -62,6 +62,7 @@ The pruning library allows for specification of the following hyper parameters:
| sparsity_function_begin_step | integer | 0 | The global step at which the gradual sparsity function begins to take effect |
| sparsity_function_end_step | integer | 100 | The global step used as the end point for the gradual sparsity function |
| sparsity_function_exponent | float | 3.0 | exponent = 1 is linearly varying sparsity between initial and final. exponent > 1 varies more slowly towards the end than the beginning |
+| use_tpu | bool | False | Whether to train using TPUs |
The sparsity $$s_t$$ at global step $$t$$ is given by:
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 6a67c6295d..f4ac70eb1a 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -377,11 +377,6 @@ py_test(
size = "large",
srcs = ["python/training/shampoo_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "noasan", # b/116875897
- "nomsan",
- "notsan",
- ],
deps = [
":opt_py",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index f161521b97..e542f46892 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -108,7 +108,8 @@ class ShampooOptimizer(optimizer.Optimizer):
precond_update_interval: We should update the preconditioners after
this many steps. Default = 1. Usually less than
svd_interval.
- epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+    epsilon: epsilon * I_n is added to each mat_gbar_j for stability in the
+      non-diagonal version of Shampoo.
alpha: total power of the preconditioners.
use_iterative_root: should the optimizer use SVD (faster) or the
iterative root method (for TPU) for finding the
@@ -394,15 +395,20 @@ class ShampooOptimizer(optimizer.Optimizer):
assert self._mat_gbar_decay == 1.0
mat_g_updated = state_ops.scatter_add(mat_g, indices,
mat_gbar_weight_t * grad_outer)
- mat_h = math_ops.pow(
- array_ops.gather(mat_g_updated, indices) + self._epsilon,
- neg_alpha)
+ mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
+ mat_h = array_ops.where(
+ math_ops.greater(mat_g_updated_slice, 0),
+ math_ops.pow(mat_g_updated_slice, neg_alpha),
+ array_ops.zeros_like(mat_g_updated_slice))
else:
mat_g_updated = self._weighted_average(mat_g,
self._mat_gbar_decay,
mat_gbar_decay_t,
mat_gbar_weight_t * grad_outer)
- mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)
+ mat_h = array_ops.where(
+ math_ops.greater(mat_g_updated, 0),
+ math_ops.pow(mat_g_updated, neg_alpha),
+ array_ops.zeros_like(mat_g_updated))
# Need to do the transpose to ensure that the tensor becomes
# a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index a2fd8fbd87..e88c8221a0 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -279,7 +279,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# Update rule is var = var - lr * gg^{-0.5} * grad
# lr = 1
mat_g = (grad_np * grad_np)
- new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+ new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
self.assertAllCloseAccordingToType(
new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
@@ -288,7 +288,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
new_val = sess.run(var)
mat_g += (grad_np_2 * grad_np_2)
- new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+ new_val_np -= np.power(mat_g, -0.5) * grad_np_2
self.assertAllCloseAccordingToType(
new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
@@ -339,7 +339,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(
grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
- mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_left = np.power(mat_g1, -0.25)
mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
@@ -353,7 +353,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 += np.sum(
grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
- mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_left = np.power(mat_g1, -0.25)
mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD
index 3ba3ee29ec..2cf445a85e 100644
--- a/tensorflow/contrib/optimizer_v2/BUILD
+++ b/tensorflow/contrib/optimizer_v2/BUILD
@@ -47,15 +47,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
+ "//tensorflow/python:util",
+ "//tensorflow/python/keras:optimizer_v2",
],
)
diff --git a/tensorflow/contrib/optimizer_v2/adadelta.py b/tensorflow/contrib/optimizer_v2/adadelta.py
index b206f9f61b..9d73bddd1c 100644
--- a/tensorflow/contrib/optimizer_v2/adadelta.py
+++ b/tensorflow/contrib/optimizer_v2/adadelta.py
@@ -18,17 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.util import deprecation
-class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
+class AdadeltaOptimizer(adadelta.Adadelta):
"""Optimizer that implements the Adadelta algorithm.
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,
use_locking=False, name="Adadelta"):
"""Construct a new Adadelta optimizer.
@@ -48,66 +52,5 @@ class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
"""
- super(AdadeltaOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("rho", rho)
- self._set_hyper("epsilon", epsilon)
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "accum")
- state.zeros_slot(v, "accum_update")
-
- def _apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.sparse_apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_sparse_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdadeltaOptimizer, self).__init__(
+ learning_rate=learning_rate, rho=rho, epsilon=epsilon, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index dab1e02716..716361e29c 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -18,15 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.util import deprecation
-class AdagradOptimizer(optimizer_v2.OptimizerV2):
+class AdagradOptimizer(adagrad.Adagrad):
"""Optimizer that implements the Adagrad algorithm.
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
@@ -34,6 +30,10 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
[intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, initial_accumulator_value=0.1,
use_locking=False, name="Adagrad"):
"""Construct a new Adagrad optimizer.
@@ -54,64 +54,7 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
Raises:
ValueError: If the `initial_accumulator_value` is invalid.
"""
- if initial_accumulator_value <= 0.0:
- raise ValueError("initial_accumulator_value must be positive: %s" %
- initial_accumulator_value)
- super(AdagradOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- self._initial_accumulator_value = initial_accumulator_value
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- dtype = v.dtype.base_dtype
- if v.get_shape().is_fully_defined():
- init = init_ops.constant_initializer(self._initial_accumulator_value,
- dtype=dtype)
- else:
- def init(v=v, dtype=dtype):
- # Use a Tensor instead of initializer if variable does not have
- # static shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- return math_ops.cast(init_constant, dtype)
- state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
- "accumulator")
-
- def _apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.sparse_apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_sparse_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdagradOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ initial_accumulator_value=initial_accumulator_value,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad_test.py b/tensorflow/contrib/optimizer_v2/adagrad_test.py
index debaaaeeba..320e41567f 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad_test.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad_test.py
@@ -68,9 +68,6 @@ class AdagradOptimizerTest(test.TestCase):
def testBasicResource(self):
self.doTestBasic(use_locking=False, use_resource=True)
- def testBasicLocked(self):
- self.doTestBasic(use_locking=True)
-
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index 04b1552b61..363e020757 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -18,22 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.util import deprecation
-class AdamOptimizer(optimizer_v2.OptimizerV2):
+class AdamOptimizer(adam.Adam):
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
@@ -87,111 +86,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
"""
- super(AdamOptimizer, self).__init__(use_locking, name)
-
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("beta1", beta1)
- self._set_hyper("beta2", beta2)
- self._set_hyper("epsilon", epsilon)
-
- def _get_beta_accumulators(self, state=None):
- if state is None:
- state = self._get_per_graph_state()
- return (state.get_non_slot("beta1_power"),
- state.get_non_slot("beta2_power"))
-
- def _create_vars(self, var_list, state):
- # Non-slot variables end up on the same device(s).
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
- name="beta1_power")
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
- name="beta2_power")
-
- # Create slots for the first and second moments.
- for v in var_list:
- state.zeros_slot(v, "m")
- state.zeros_slot(v, "v")
-
- def _apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.apply_adam(
- var, m, v,
- math_ops.cast(beta1_power, var.dtype.base_dtype),
- math_ops.cast(beta2_power, var.dtype.base_dtype),
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("beta1", var.dtype.base_dtype),
- state.get_hyper("beta2", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad, use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.resource_apply_adam(
- var.handle, m.handle, v.handle,
- math_ops.cast(beta1_power, grad.dtype.base_dtype),
- math_ops.cast(beta2_power, grad.dtype.base_dtype),
- state.get_hyper("learning_rate", grad.dtype.base_dtype),
- state.get_hyper("beta1", grad.dtype.base_dtype),
- state.get_hyper("beta2", grad.dtype.base_dtype),
- state.get_hyper("epsilon", grad.dtype.base_dtype),
- grad, use_locking=self._use_locking)
-
- def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
- lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
- beta1_t = state.get_hyper("beta1", var.dtype.base_dtype)
- beta2_t = state.get_hyper("beta2", var.dtype.base_dtype)
- epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
- lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
- # m_t = beta1 * m + (1 - beta1) * g_t
- m = state.get_slot(var, "m")
- m_scaled_g_values = grad * (1 - beta1_t)
- m_t = state_ops.assign(m, m * beta1_t,
- use_locking=self._use_locking)
- with ops.control_dependencies([m_t]):
- m_t = scatter_add(m, indices, m_scaled_g_values)
- # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
- v = state.get_slot(var, "v")
- v_scaled_g_values = (grad * grad) * (1 - beta2_t)
- v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
- with ops.control_dependencies([v_t]):
- v_t = scatter_add(v, indices, v_scaled_g_values)
- v_sqrt = math_ops.sqrt(v_t)
- var_update = state_ops.assign_sub(var,
- lr * m_t / (v_sqrt + epsilon_t),
- use_locking=self._use_locking)
- return control_flow_ops.group(*[var_update, m_t, v_t])
-
- def _apply_sparse(self, grad, var, state):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
- x, i, v, use_locking=self._use_locking),
- state)
-
- def _resource_scatter_add(self, x, i, v):
- with ops.control_dependencies(
- [resource_variable_ops.resource_scatter_add(
- x.handle, i, v)]):
- return x.value()
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- return self._apply_sparse_shared(
- grad, var, indices, self._resource_scatter_add, state)
-
- def _finish(self, state):
- # Update the power accumulators.
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- update_beta1 = beta1_power.assign(
- beta1_power * state.get_hyper("beta1"),
- use_locking=self._use_locking)
- update_beta2 = beta2_power.assign(
- beta2_power * state.get_hyper("beta2"),
- use_locking=self._use_locking)
- return control_flow_ops.group(update_beta1, update_beta2)
+ super(AdamOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ beta_1=beta1,
+ beta_2=beta2,
+ epsilon=epsilon,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index e13b82d1d2..3c68ef995a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -130,8 +130,8 @@ class CheckpointingTests(test.TestCase):
# non-Layer dependency of the model
"model/_non_layer/a_variable",
# The optimizer creates two non-slot variables
- "optimizer/beta1_power",
- "optimizer/beta2_power",
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
# Slot variables
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
@@ -161,21 +161,20 @@ class CheckpointingTests(test.TestCase):
"my_model/dense/kernel",
named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power",
- named_variables["optimizer/beta1_power" + suffix].full_name)
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power",
- named_variables["optimizer/beta2_power" + suffix].full_name)
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
1].node_id]
- self.assertEqual("beta1_power",
- optimizer_node.children[0].local_name)
- self.assertEqual("beta1_power",
- serialized_graph.nodes[optimizer_node.children[0].node_id]
- .attributes[0].full_name)
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
self.assertEqual(
"my_model/dense/kernel",
serialized_graph.nodes[optimizer_node.slot_variables[0]
@@ -241,9 +240,10 @@ class CheckpointingTests(test.TestCase):
on_create_model = MyModel()
on_create_optimizer = adam.AdamOptimizer(
0.001,
- # Preserve beta1_power and beta2_power when appying gradients so we can
- # test that they've been restored correctly.
- beta1=1.0, beta2=1.0)
+      # Preserve beta_1_power and beta_2_power when applying gradients
+ # so we can test that they've been restored correctly.
+ beta1=1.0,
+ beta2=1.0)
on_create_root = util.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
@@ -263,9 +263,9 @@ class CheckpointingTests(test.TestCase):
dummy_var = resource_variable_ops.ResourceVariable([1.])
on_create_optimizer.minimize(loss=dummy_var.read_value)
status.assert_consumed()
- beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
- self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
- self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
# TODO(allenl): Debug garbage created by this test in python3.
def testDeferredRestorationUsageEager(self):
@@ -477,7 +477,7 @@ class CheckpointingTests(test.TestCase):
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
new_root.optimizer = adam.AdamOptimizer(0.1)
- with self.assertRaisesRegexp(AssertionError, "beta1_power"):
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
if context.executing_eagerly():
@@ -556,8 +556,8 @@ class CheckpointingTests(test.TestCase):
self.evaluate(first_variable.assign([1.]))
self.evaluate(optimizer.get_slot(
var=first_variable, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
# Save and load in a second graph
second_graph = ops.Graph()
@@ -571,29 +571,29 @@ class CheckpointingTests(test.TestCase):
self.evaluate(second_variable.assign([4.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([5.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(6.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
save_path = second_root_checkpointable.save(checkpoint_prefix)
self.evaluate(second_variable.assign([7.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([8.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
status = second_root_checkpointable.restore(save_path)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([4.], self.evaluate(second_variable))
self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
var=second_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
# Check that the first graph is unmolested
with first_graph.as_default(), first_session.as_default():
self.assertAllEqual([1.], self.evaluate(first_variable))
self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
var=first_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
class TemplateTests(test.TestCase):
@@ -659,8 +659,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.evaluate(model._named_dense.bias.assign([1.]))
self.evaluate(optimizer.get_slot(
var=model._named_dense.bias, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
return root_checkpointable
def _set_sentinels(self, root_checkpointable):
@@ -669,8 +669,8 @@ class CheckpointCompatibilityTests(test.TestCase):
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")
.assign([102.]))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(103.))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
def _check_sentinels(self, root_checkpointable):
self.assertAllEqual(
@@ -678,8 +678,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.assertAllEqual([2.], self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
def _write_name_based_checkpoint(self):
checkpoint_directory = self.get_temp_dir()
diff --git a/tensorflow/contrib/optimizer_v2/gradient_descent.py b/tensorflow/contrib/optimizer_v2/gradient_descent.py
index 945c8de559..8bdf408217 100644
--- a/tensorflow/contrib/optimizer_v2/gradient_descent.py
+++ b/tensorflow/contrib/optimizer_v2/gradient_descent.py
@@ -18,15 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
+class GradientDescentOptimizer(sgd.SGD):
"""Optimizer that implements the gradient descent algorithm."""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
"""Construct a new gradient descent optimizer.
@@ -41,29 +43,5 @@ class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "GradientDescent".
"""
- super(GradientDescentOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- def _apply_dense(self, grad, var, state):
- return training_ops.apply_gradient_descent(
- var,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, handle, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return training_ops.resource_apply_gradient_descent(
- handle.handle, lr, grad, use_locking=self._use_locking)
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return resource_variable_ops.resource_scatter_add(
- handle.handle, indices, -grad * lr)
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- delta = ops.IndexedSlices(
- grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.indices, grad.dense_shape)
- return var.scatter_sub(delta, use_locking=self._use_locking)
+ super(GradientDescentOptimizer, self).__init__(
+ learning_rate=learning_rate, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/momentum.py b/tensorflow/contrib/optimizer_v2/momentum.py
index 0a5aadc2d1..0636f7e356 100644
--- a/tensorflow/contrib/optimizer_v2/momentum.py
+++ b/tensorflow/contrib/optimizer_v2/momentum.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class MomentumOptimizer(optimizer_v2.OptimizerV2):
+class MomentumOptimizer(sgd.SGD):
"""Optimizer that implements the Momentum algorithm.
Computes (if `use_nesterov = False`):
@@ -39,6 +39,10 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
when that part of the variable was used in the forward pass.
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, momentum,
use_locking=False, name="Momentum", use_nesterov=False):
"""Construct a new Momentum optimizer.
@@ -68,57 +72,8 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
optimizer functions.
@end_compatibility
"""
- super(MomentumOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("momentum", momentum)
- self._use_nesterov = use_nesterov
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
-
- def _apply_sparse(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.sparse_apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_sparse_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
+ super(MomentumOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ momentum=momentum,
+ name=name,
+ nesterov=use_nesterov)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 6af59dcfbf..9c98dd93b4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -20,463 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import abc
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.util import deprecation
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import distribution_strategy_context
-from tensorflow.python.training import optimizer as optimizer_v1
-from tensorflow.python.training import slot_creator
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.util import nest
-
-class _OptimizableVariable(object):
- """Interface for abstracting over variables in the optimizers."""
-
- @abc.abstractmethod
- def target(self):
- """Returns the optimization target for this variable."""
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def update_op(self, optimizer, g, *args):
- """Returns the update ops for updating the variable."""
- raise NotImplementedError("Calling an abstract method.")
-
-
-class _RefVariableProcessor(_OptimizableVariable):
- """Processor for Variable."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v._ref() # pylint: disable=protected-access
-
- def update_op(self, optimizer, g, *args):
- if isinstance(g, ops.Tensor):
- update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
- else:
- assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
- "tensor nor IndexedSlices.")
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- # pylint: disable=protected-access
- return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
-
-
-class _DenseReadResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _DenseResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- if isinstance(g, ops.IndexedSlices):
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- return optimizer._resource_apply_sparse_duplicate_indices(
- g.values, self._v, g.indices, *args)
- update_op = optimizer._resource_apply_dense(g, self._v, *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _TensorProcessor(_OptimizableVariable):
- """Processor for ordinary Tensors.
-
- Even though a Tensor can't really be updated, sometimes it is useful to
- compute the gradients with respect to a Tensor using the optimizer. Updating
- the Tensor is, of course, unsupported.
- """
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- raise NotImplementedError("Trying to update a Tensor ", self._v)
-
-
-def _get_processor(v):
- """The processor of v."""
- if context.executing_eagerly():
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- else:
- return _DenseResourceVariableProcessor(v)
- if v.op.type == "VarHandleOp":
- return _DenseResourceVariableProcessor(v)
- if isinstance(v, variables.Variable):
- return _RefVariableProcessor(v)
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- raise NotImplementedError("Trying to optimize unsupported type ", v)
-
-
-def _var_key_v2(var):
- """Key for representing a primary variable, for looking up slots."""
- # pylint: disable=protected-access
- if hasattr(var, "_distributed_container"):
- distributed_container = var._distributed_container()
- assert distributed_container is not None
- if context.executing_eagerly():
- return distributed_container._unique_id
- return distributed_container._shared_name
- if context.executing_eagerly():
- return var._unique_id
- return var.op.name
-
-
-def _resolve(value, name):
- if callable(value):
- value = value()
- return ops.convert_to_tensor(value, name=name)
-
-
-def _is_dynamic(value):
- """Returns true if __init__ arg `value` should be re-evaluated each step."""
- if callable(value): return True
- # Don't need to do anything special in graph mode, since dynamic values
- # will propagate correctly automatically.
- # TODO(josh11b): Add per-device caching across steps using variables for
- # truly static values once we add distributed support.
- if context.executing_eagerly() and isinstance(
- value, resource_variable_ops.ResourceVariable):
- return True
- return False
-
-
-class _OptimizerV2State(object):
- """Holds per-graph and per-step optimizer state.
-
- Use _init_with_static_hyper() to create the state for a graph, and then
- _copy_with_dynamic_hyper() to convert that to state for a particular step.
- The difference between the two is that the former only has hyper
- parameter values that are static and the latter also has values that
- can change every step (according to _is_dynamic()).
- """
-
- def __init__(self, op_name):
- self._op_name = op_name
-
- def _init_with_static_hyper(self, hyper):
- """Initialize a fresh state object from hyper dict."""
- # self._hyper contains a dict from name to a dict with the Tensor values.
- # This dict starts with a single item with key "None" with the hyper
- # parameter value converted to a Tensor. Other items have dtype keys
- # with that Tensor cast to that dtype.
- with ops.init_scope():
- self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if not dynamic}
- self._slots = {}
- self._non_slot_dict = {}
- # Extra state to help Optimizers implement Checkpointable. Holds information
- # about variables which will be restored as soon as they're created.
- self._deferred_dependencies = {} # Non-slot variables
- self._deferred_slot_restorations = {} # Slot variables
-
- def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
- """Create a new state object for a particular step."""
- ret = _OptimizerV2State(self._op_name)
- # pylint: disable=protected-access
- ret._slots = self._slots
- ret._non_slot_dict = self._non_slot_dict
- ret._deferred_dependencies = self._deferred_dependencies
- ret._deferred_slot_restorations = self._deferred_slot_restorations
- ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if dynamic}
- ret._hyper.update(self._hyper)
- ret._non_slot_devices = non_slot_devices
- ret._distribution = distribution
- return ret
-
- def _variables(self):
- """Returns a list of all variables held by self."""
- optimizer_variables = list(self._non_slot_dict.values())
- for variable_dict in self._slots.values():
- for slot_for_variable in variable_dict.values():
- optimizer_variables.append(slot_for_variable)
- # Sort variables by name so that the return is deterministic.
- return sorted(optimizer_variables, key=lambda v: v.name)
-
- def _slot_dict(self, slot_name):
- """Returns a dict for caching slots created under the given name.
-
- Args:
- slot_name: Name for the slot.
-
- Returns:
- A dict that maps primary `Variable` objects to the slot created
- for that variable, under the given slot name.
- """
- named_slots = self._slots.get(slot_name, None)
- if named_slots is None:
- named_slots = {}
- self._slots[slot_name] = named_slots
- return named_slots
-
- def create_slot(self, var, val, slot_name, optional_op_name=None):
- """Find or create a slot for a variable.
-
- Args:
- var: A `Variable` object.
- val: A `Tensor`. The initial value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot(
- var, val, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def create_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, optional_op_name=None):
- """Find or create a slot for a variable, using an Initializer.
-
- Args:
- var: A `Variable` object.
- initializer: An `Initializer`. The initial value of the slot.
- shape: Shape of the initial value of the slot.
- dtype: Type of the value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot_with_initializer(
- var, initializer, shape, dtype, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def zeros_slot(self, var, slot_name, optional_op_name=None):
- """Find or create a slot initialized with 0.0.
-
- Args:
- var: A `Variable` object.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_zeros_slot(
- var, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable,
- optional_op_name=None):
- """Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored. When executing eagerly, we create the slot variable with a
- restoring initializer.
-
- No new variables are created when graph building. Instead,
- _restore_slot_variable catches these after normal creation and adds restore
- ops to the graph. This method is nonetheless important when graph building
- for the case when a slot variable has already been created but `variable`
- has just been added to a dependency graph (causing us to realize that the
- slot variable needs to be restored).
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
- """
- slot_variable = self.get_slot(var=variable, name=slot_name)
- if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()
- # Defer slot variable creation if there is an active variable creator
- # scope. Generally we'd like to eagerly create/restore slot variables
- # when possible, but this may mean that scopes intended to catch
- # `variable` also catch its eagerly created slot variable
- # unintentionally (specifically make_template would add a dependency on
- # a slot variable if not for this case). Deferring is mostly harmless
- # (aside from double initialization), and makes variable creator scopes
- # behave the same way they do when graph building.
- and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
- initializer = checkpointable.CheckpointInitialValue(
- checkpoint_position=slot_variable_position)
- slot_variable = self.create_slot(
- var=variable,
- val=initializer,
- slot_name=slot_name,
- optional_op_name=optional_op_name)
- # Optimizers do not have unconditional dependencies on their slot
- # variables (nor do any other objects). They are only saved if the
- # variables they were created for are also saved.
- if slot_variable is not None:
- # If we've either made this slot variable, or if we've pulled out an
- # existing slot variable, we should restore it.
- slot_variable_position.restore(slot_variable)
- else:
- # We didn't make the slot variable. Defer restoring until it gets created
- # normally. We keep a list rather than the one with the highest restore
- # UID in case slot variables have their own dependencies, in which case
- # those could differ between restores.
- variable_key = _var_key_v2(variable)
- self._deferred_slot_restorations.setdefault(
- slot_name, {}).setdefault(variable_key, []).append(
- slot_variable_position)
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- named_slots = self._slots.get(name, None)
- if not named_slots:
- return None
- return named_slots.get(_var_key_v2(var), None)
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- return sorted(self._slots.keys())
-
- def create_non_slot(self, initial_value, name, colocate_with=None):
- """Add an extra variable, not associated with a slot."""
- v = self._non_slot_dict.get(name, None)
- if v is None:
- if colocate_with is None: colocate_with = self._non_slot_devices
- with self._distribution.colocate_vars_with(colocate_with):
- # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
- v = variable_scope.variable(initial_value, name=name, trainable=False)
- self._non_slot_dict[name] = v
- deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
- for checkpoint_position in sorted(
- deferred_dependencies_list,
- key=lambda restore: restore.checkpoint.restore_uid,
- reverse=True):
- checkpoint_position.restore(v)
- return v
-
- def _restore_slot_variable(self, slot_name, variable, slot_variable):
- """Restore a newly created slot variable's value."""
- variable_key = _var_key_v2(variable)
- deferred_restorations = self._deferred_slot_restorations.get(
- slot_name, {}).pop(variable_key, [])
- # Iterate over restores, highest restore UID first to minimize the number
- # of assignments.
- deferred_restorations.sort(key=lambda position: position.restore_uid,
- reverse=True)
- for checkpoint_position in deferred_restorations:
- checkpoint_position.restore(slot_variable)
-
- def get_non_slot(self, name):
- """Returns the non-slot variable identified by `name`."""
- return self._non_slot_dict.get(name, None)
-
- def get_hyper(self, name, dtype=None):
- """Returns the `name` hyper parameter, optionally cast to `dtype`."""
- dtype_dict = self._hyper[name]
- # Do we have the value cast to dtype already cached? This should always
- # succeed when dtype is None.
- if dtype in dtype_dict:
- return dtype_dict[dtype]
- # Not cached, cast to dtype and save the result in the cache.
- result = math_ops.cast(dtype_dict[None], dtype)
- dtype_dict[dtype] = result
- return result
-
-
-class OptimizerV2(optimizer_v1.Optimizer):
+class OptimizerV2(optimizer_v2.OptimizerV2):
"""Updated base class for optimizers.
This class defines the API to add Ops to train a model. You never use this
@@ -587,6 +135,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
GATE_OP = 1
GATE_GRAPH = 2
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, use_locking, name):
"""Create a new Optimizer.
@@ -607,749 +159,4 @@ class OptimizerV2(optimizer_v1.Optimizer):
RuntimeError: If _create_slots has been overridden instead of
_create_vars.
"""
- # Note: We intentionally don't call parent __init__.
-
- # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
- if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
- OptimizerV2._create_slots.__code__):
- raise RuntimeError("Override _create_vars instead of _create_slots when "
- "descending from OptimizerV2 (class %s)" %
- self.__class__.__name__)
- if not name:
- raise ValueError("Must specify the optimizer name")
-
- self._use_locking = use_locking
- self._name = name
- # Map from graph_key to state for that graph. We use the graph_key
- # since it works in both eager and graph mode, and gives the outer
- # graph inside functions.
- tower_context = distribution_strategy_context.get_tower_context()
- if tower_context is None:
- # In a cross-tower context for a DistributionStrategy, which means
- # only one Optimizer will be created, not one per tower.
- self._per_graph_state = {}
- else:
- # We use get_tower_context().merge_call() to get a single dict
- # shared across all model replicas when running with a
- # DistributionStrategy.
- self._per_graph_state = tower_context.merge_call(lambda _: {})
-
- # Hyper parameters, and whether they should be re-evaluated every step.
- self._hyper = {}
-
- def _set_hyper(self, name, value):
- self._hyper[name] = (_is_dynamic(value), value)
-
- def minimize(self, loss, global_step=None, var_list=None,
- gate_gradients=GATE_OP, aggregation_method=None,
- colocate_gradients_with_ops=False, name=None,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Add operations to minimize `loss` by updating `var_list`.
-
- This method simply combines calls `compute_gradients()` and
- `apply_gradients()`. If you want to process the gradient before applying
- them call `compute_gradients()` and `apply_gradients()` explicitly instead
- of using this function.
-
- Args:
- loss: A `Tensor` containing the value to minimize.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- var_list: Optional list or tuple of `Variable` objects to update to
- minimize `loss`. Defaults to the list of variables collected in
- the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- name: Optional name for the returned operation.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- An Operation that updates the variables in `var_list`. If `global_step`
- was not `None`, that operation also increments `global_step`.
-
- Raises:
- ValueError: If some of the variables are not `Variable` objects.
-
- @compatibility(eager)
- When eager execution is enabled, `loss` should be a Python function that
- takes elements of `var_list` as arguments and computes the value to be
- minimized. If `var_list` is None, `loss` should take no arguments.
- Minimization (and gradient computation) is done with respect to the
- elements of `var_list` if not None, else with respect to any trainable
- variables created during the execution of the `loss` function.
- `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
- `grad_loss` are ignored when eager execution is enabled.
- @end_compatibility
- """
- grads_and_vars = self.compute_gradients(
- loss, var_list=var_list, gate_gradients=gate_gradients,
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- grad_loss=grad_loss, stop_gradients=stop_gradients,
- scale_loss_by_num_towers=scale_loss_by_num_towers)
-
- vars_with_grad = [v for g, v in grads_and_vars if g is not None]
- if not vars_with_grad:
- raise ValueError(
- "No gradients provided for any variable, check your graph for ops"
- " that do not support gradients, between variables %s and loss %s." %
- ([str(v) for _, v in grads_and_vars], loss))
-
- return self.apply_gradients(grads_and_vars, global_step=global_step,
- name=name)
-
- def compute_gradients(self, loss, var_list=None,
- gate_gradients=GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Compute gradients of `loss` for the variables in `var_list`.
-
- This is the first part of `minimize()`. It returns a list
- of (gradient, variable) pairs where "gradient" is the gradient
- for "variable". Note that "gradient" can be a `Tensor`, an
- `IndexedSlices`, or `None` if there is no gradient for the
- given variable.
-
- Args:
- loss: A Tensor containing the value to minimize or a callable taking
- no arguments which returns the value to minimize. When eager execution
- is enabled it must be a callable.
- var_list: Optional list or tuple of `tf.Variable` to update to minimize
- `loss`. Defaults to the list of variables collected in the graph
- under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- A list of (gradient, variable) pairs. Variable is always present, but
- gradient can be `None`.
-
- Raises:
- TypeError: If `var_list` contains anything else than `Variable` objects.
- ValueError: If some arguments are invalid.
- RuntimeError: If called with eager execution enabled and `loss` is
- not callable.
-
- @compatibility(eager)
- When eager execution is enabled, `gate_gradients`, `aggregation_method`,
- and `colocate_gradients_with_ops` are ignored.
- @end_compatibility
- """
- # TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if callable(loss):
- with backprop.GradientTape() as tape:
- if var_list is not None:
- tape.watch(var_list)
- loss_value = loss()
-
- # Scale loss for number of towers (callable-loss case). In this case,
- # we have to be careful to call distribute_lib.get_loss_reduction()
- # *after* loss() is evaluated, so we know what loss reduction it uses.
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss_value *= 1. / num_towers
-
- if var_list is None:
- var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
- return list(zip(grads, var_list))
- if context.executing_eagerly():
- raise RuntimeError(
- "`loss` passed to Optimizer.compute_gradients should "
- "be a function when eager execution is enabled.")
-
- # Scale loss for number of towers (non-callable-loss case).
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss *= 1. / num_towers
-
- if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
- optimizer_v1.Optimizer.GATE_OP,
- optimizer_v1.Optimizer.GATE_GRAPH]:
- raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
- "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
- gate_gradients)
- self._assert_valid_dtypes([loss])
- if grad_loss is not None:
- self._assert_valid_dtypes([grad_loss])
- if var_list is None:
- var_list = (
- variables.trainable_variables() +
- ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
- else:
- var_list = nest.flatten(var_list)
- # pylint: disable=protected-access
- var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
- # pylint: enable=protected-access
- processors = [_get_processor(v) for v in var_list]
- if not var_list:
- raise ValueError("No variables to optimize.")
- var_refs = [p.target() for p in processors]
- grads = gradients.gradients(
- loss, var_refs, grad_ys=grad_loss,
- gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- stop_gradients=stop_gradients)
- if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
- grads = control_flow_ops.tuple(grads)
- grads_and_vars = list(zip(grads, var_list))
- self._assert_valid_dtypes(
- [v for g, v in grads_and_vars
- if g is not None and v.dtype != dtypes.resource])
- return grads_and_vars
-
- def apply_gradients(self, grads_and_vars, global_step=None, name=None):
- """Apply gradients to variables.
-
- This is the second part of `minimize()`. It returns an `Operation` that
- applies gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs as returned by
- `compute_gradients()`.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- name: Optional name for the returned operation. Default to the
- name passed to the `Optimizer` constructor.
-
- Returns:
- An `Operation` that applies the specified gradients. If `global_step`
- was not None, that operation also increments `global_step`.
-
- Raises:
- TypeError: If `grads_and_vars` is malformed.
- ValueError: If none of the variables have gradients.
- """
- # This is a default implementation of apply_gradients() that can be shared
- # by most optimizers. It relies on the subclass implementing the following
- # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
-
- # Filter out variables with gradients of `None`.
- grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
- if not grads_and_vars:
- raise ValueError("No variables provided.")
- filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
- if not filtered:
- raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, v in grads_and_vars],))
- return distribution_strategy_context.get_tower_context().merge_call(
- self._distributed_apply, filtered, global_step=global_step, name=name)
-
- def _get_or_create_state(self, var_list=None):
- """Either looks up or creates `_OptimizerV2State`.
-
- If any variables are available, they should be passed via the `var_list`
- argument, and these will be used to determine the graph to create/retrieve
- state for. Otherwise the returned state is for the current default graph.
-
- Args:
- var_list: A list of variables to extract a graph from.
-
- Returns:
- An `_OptimizerV2State` object.
- """
- # Determine the graph_key from the current graph.
- eager_execution = context.executing_eagerly()
- if eager_execution or var_list is None:
- graph = ops.get_default_graph()
- else:
- graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
- assert graph is not None
- graph_key = graph._graph_key # pylint: disable=protected-access
-
- # Get the per graph state by looking up the graph_key.
- if graph_key in self._per_graph_state:
- per_graph_state = self._per_graph_state[graph_key]
- else:
- per_graph_state = _OptimizerV2State(self._name)
- per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
- self._per_graph_state[graph_key] = per_graph_state
- return per_graph_state
-
- def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
- """`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce(
- variable_scope.VariableAggregation.SUM, grads_and_vars)
- var_list = [v for _, v in grads_and_vars]
- grads_and_vars = zip(reduced_grads, var_list)
-
- unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
- eager_execution = context.executing_eagerly()
- if eager_execution:
- # Give a clear error in this case instead of "name not supported
- # for Eager Tensors" when we compute non_slot_devices.
- for v in unwrapped_var_list:
- if isinstance(v, ops.Tensor):
- raise NotImplementedError("Trying to update a Tensor ", v)
-
- with ops.name_scope(name, self._name) as name:
- per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
- # Include the current value of any dynamic hyper parameters in `state`.
- non_slot_devices = distribution.non_slot_devices(var_list)
- state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
- self._hyper, distribution, non_slot_devices)
-
- # Create any slot and non-slot variables we need in `state`.
- with ops.init_scope():
- self._create_vars(var_list, state)
-
- with ops.name_scope(name): # Re-enter name_scope created above
- # Give the child class a chance to do something before we start
- # applying gradients.
- self._prepare(state)
-
- def update(v, g):
- """Update variable `v` using gradient `g`."""
- assert v is not None
-
- # Convert the grad to Tensor or IndexedSlices if necessary, and
- # look up a processor for each variable's type.
- try:
- g = ops.convert_to_tensor_or_indexed_slices(g)
- except TypeError:
- raise TypeError(
- "Gradient must be convertible to a Tensor"
- " or IndexedSlices, or None: %s" % g)
- if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
- raise TypeError(
- "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
- processor = _get_processor(v)
-
- # We colocate all ops created in _apply_dense or _apply_sparse
- # on the same device as the variable.
- # TODO(apassos): figure out how to get the variable name here.
- scope_name = "" if eager_execution else v.op.name
- # device_policy is set because non-mirrored tensors will be read in
- # `update_op`.
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with ops.name_scope("update_" + scope_name), \
- context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return processor.update_op(self, g, state)
-
- # Use the processors to update the variables.
- update_ops = []
- for grad, var in grads_and_vars:
- update_ops.extend(distribution.unwrap(distribution.update(
- var, update, grad)))
-
- # Give the child class a chance to do something after applying
- # gradients
- def finish():
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return self._finish(state)
-
- update_ops = control_flow_ops.group(update_ops)
- with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(non_slot_devices, finish)
- if finish_updates is None:
- finish_updates = update_ops
-
- # Update `global_step` (if any).
- if global_step is None:
- apply_updates = distribution.group(finish_updates, name=name)
- else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
-
- def update_global_step(global_step):
- if isinstance(global_step, resource_variable_ops.ResourceVariable):
- return global_step.assign_add(
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- read_value=False)
- else:
- return state_ops.assign_add(global_step, 1)
-
- apply_updates = distribution.group(
- distribution.update(global_step, update_global_step), name=name)
-
- # Add the training op to the TRAIN_OP graph collection in graph mode.
- if not eager_execution:
- if isinstance(apply_updates, ops.Tensor):
- apply_updates = apply_updates.op
- train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
- if apply_updates not in train_op:
- train_op.append(apply_updates)
-
- return apply_updates
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- state = self._get_state_for_var(var)
- return state.get_slot(var, name) if state is not None else None
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- state = self._get_per_graph_state()
- return state.get_slot_names() if state is not None else []
-
- def variables(self):
- """A list of variables which encode the current state of `Optimizer`.
-
- Includes slot variables and additional global variables created by the
- optimizer in the current default graph.
-
- Returns:
- A list of variables.
- """
- state = self._get_per_graph_state()
- return state._variables() if state is not None else [] # pylint: disable=protected-access
-
- # --------------
- # Methods to be implemented by subclasses if they want to use the
- # inherited implementation of apply_gradients() or compute_gradients().
- # --------------
- def _create_vars(self, var_list, state):
- """Create all slots needed by the variables and any non-slot variables.
-
- Args:
- var_list: A list of `Variable` objects.
- state: An object with these methods:
- `create_slot(var, val, slot_name, optional_op_name)`,
- `create_slot_with_initializer(`
- `var, initializer, shape, dtype, slot_name, optional_op_name)`,
- `zeros_slot(var, slot_name, optional_op_name)`,
- `create_non_slot_variable(initial_value, name, colocate_with)`,
- `get_hyper(name)`
- """
- # No slots needed by default
- pass
-
- def _prepare(self, state):
- """Code to execute before applying gradients.
-
- Note that most uses of _prepare() in Optimizer have been subsumed
- by explicit support for hyper parameters in OptimizerV2
-
- Args:
- state: An object with a `get_hyper(name)` method.
-
- Returns:
- Return value will be ignored.
- """
- pass
-
- def _apply_dense(self, grad, var, state):
- """Add ops to apply dense gradients to `var`.
-
- Args:
- grad: A `Tensor`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _resource_apply_dense(self, grad, handle, state):
- """Add ops to apply dense gradients to the variable `handle`.
-
- Args:
- grad: a `Tensor` representing the gradient.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to `handle`, with repeated indices.
-
- Optimizers which override this method must deal with repeated indices. See
- the docstring of `_apply_sparse_duplicate_indices` for details. By default
- the correct behavior, to sum non-unique indices and their associated
- gradients, is enforced by first pre-processing `grad` and `indices` and
- passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
- with duplicate indices may instead override this method to avoid the
- overhead of summing.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices may be repeated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- # pylint: disable=protected-access
- summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad, indices=indices)
- # pylint: enable=protected-access
- return self._resource_apply_sparse(
- summed_grad, handle, unique_indices, state)
-
- def _resource_apply_sparse(self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to the variable `handle`.
-
- Similar to `_apply_sparse`, the `indices` argument to this method has been
- de-duplicated. Optimizers which deal correctly with non-unique indices may
- instead override `_resource_apply_sparse_duplicate_indices` to avoid this
- overhead.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices are unique.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
-
- Optimizers which override this method must deal with IndexedSlices objects
- such as the following:
-
- IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
-
- The correct interpretation is:
-
- IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
-
- Many optimizers deal incorrectly with repeated indices when updating based
- on sparse gradients (e.g. summing squares rather than squaring the sum, or
- applying momentum terms multiple times). Adding first is always the correct
- behavior, so this is enforced here by reconstructing the IndexedSlices to
- have only unique indices, then calling _apply_sparse.
-
- Optimizers which deal correctly with repeated indices may instead override
- this method to avoid the overhead of summing indices.
-
- Args:
- grad: `IndexedSlices`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- # pylint: disable=protected-access
- summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad.values, indices=grad.indices)
- # pylint: enable=protected-access
- gradient_no_duplicate_indices = ops.IndexedSlices(
- indices=unique_indices,
- values=summed_values,
- dense_shape=grad.dense_shape)
- return self._apply_sparse(gradient_no_duplicate_indices, var, state)
-
- def _apply_sparse(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`.
-
- The IndexedSlices object passed to `grad` in this function is by default
- pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
- indices (see its docstring for details). Optimizers which can tolerate or
- have correct special cases for duplicate sparse indices may override
- `_apply_sparse_duplicate_indices` instead of this function, avoiding that
- overhead.
-
- Args:
- grad: `IndexedSlices`, with no repeated indices.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _finish(self, state):
- """Do what is needed to finish the update.
-
- This is called inside a scope colocated with any non-slot variables.
-
- Args:
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- The operation to apply updates, or None if no updates.
- """
- return None
-
- # --------------
- # Utility methods for subclasses.
- # --------------
- def _get_per_graph_state(self):
- # pylint: disable=protected-access
- return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
-
- def _get_state_for_var(self, var):
- # pylint: disable=protected-access
- return self._per_graph_state.get(var._graph_key, None)
-
- # --------------
- # Overridden methods from Checkpointable.
- # --------------
-
- def _track_checkpointable(self, *args, **kwargs):
- """Optimizers may not track dependencies. Raises an error."""
- raise NotImplementedError(
- "Optimizers may not have dependencies. File a feature request if this "
- "limitation bothers you.")
-
- @property
- def _checkpoint_dependencies(self):
- """From Checkpointable. Gather graph-specific non-slot variables to save."""
- current_graph_non_slot_variables = []
- state = self._get_per_graph_state()
- if state is not None:
- for name, variable_object in sorted(
- state._non_slot_dict.items(), # pylint: disable=protected-access
- # Avoid comparing variables
- key=lambda item: item[0]):
- current_graph_non_slot_variables.append(
- checkpointable.CheckpointableReference(
- name=name, ref=variable_object))
- # Note: ignores super(); Optimizers may not have any dependencies outside of
- # state objects.
- return current_graph_non_slot_variables
-
- def _lookup_dependency(self, name):
- """From Checkpointable. Find a non-slot variable in the current graph."""
- state = self._get_per_graph_state()
- if state is None:
- return None
- else:
- return state.get_non_slot(name)
-
- @property
- def _deferred_dependencies(self):
- """Lets Checkpointable know where non-slot variables are created.
-
- If necessary, creates a new state object for the current default graph.
- Checkpointable will then add entries to that state's deferred dependency
- dictionary. The state object will check that dictionary when creating
- non-slot variables, restoring their value if an entry is found.
-
- Returns:
- A dictionary which holds deferred dependencies for the current default
- graph.
- """
- state = self._get_or_create_state()
- return state._deferred_dependencies # pylint: disable=protected-access
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable):
- """Checkpointable: Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored.
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- """
- state = self._get_or_create_state(var_list=[variable])
- state._create_or_restore_slot_variable( # pylint: disable=protected-access
- slot_variable_position=slot_variable_position,
- slot_name=slot_name,
- variable=variable,
- optional_op_name=self._name)
-
- # --------------
- # Unsupported parent methods
- # --------------
- def _slot_dict(self, slot_name):
- raise NotImplementedError(
- "_slot_dict() method unsupported in OptimizerV2")
-
- def _get_or_make_slot(self, var, val, slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot() method unsupported in OptimizerV2")
-
- def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot_with_initializer() method unsupported in "
- "OptimizerV2")
-
- def _create_non_slot_variable(self, initial_value, name, colocate_with):
- raise NotImplementedError(
- "_create_non_slot_variable() method unsupported in OptimizerV2")
-
- def _get_non_slot_variable(self, name, graph=None):
- raise NotImplementedError(
- "_get_non_slot_variable() method unsupported in OptimizerV2")
-
- def _non_slot_variables(self):
- raise NotImplementedError(
- "_non_slot_variables() method unsupported in OptimizerV2")
+ super(OptimizerV2, self).__init__(name)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 3de53405ec..090e257ddc 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -41,19 +41,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.util import deprecation
-from tensorflow.python.training import training_ops
-
-class RMSPropOptimizer(optimizer_v2.OptimizerV2):
+class RMSPropOptimizer(rmsprop.RMSProp):
"""Optimizer that implements the RMSProp algorithm.
See the
[paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self,
learning_rate,
decay=0.9,
@@ -96,138 +98,10 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "RMSProp".
"""
- super(RMSPropOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("decay", decay)
- self._set_hyper("momentum", momentum)
- self._set_hyper("epsilon", epsilon)
-
- self._centered = centered
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- init_rms = state.get_hyper(
- "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
- state.create_slot_with_initializer(v, init_rms, v.get_shape(),
- v.dtype.base_dtype, "rms")
- if self._centered:
- state.zeros_slot(v, "mg")
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- # epsilon is now the rms initial value and is not added to the
- # denominator anymore, hence calling the kernel op with epsilon=0.
- 0,
- grad,
- use_locking=self._use_locking).op
- else:
- return training_ops.apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.resource_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.sparse_apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
- else:
- return training_ops.sparse_apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = self.get_slot(var, "mg")
- return training_ops.resource_sparse_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_sparse_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
+ super(RMSPropOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ rho=decay,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered,
+ name=name)
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 23e3a25d71..94a2d9672d 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -138,7 +138,6 @@ py_library(
srcs = ["python/quant_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 27069444a4..d9dc7fa62e 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.framework.python.ops import add_arg_scope
-from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -29,7 +27,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
-@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
"""Adds a fake quantize layer with fixed quantization interval.
@@ -46,7 +43,21 @@ def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
inputs, min=init_min, max=init_max)
-@add_arg_scope
+def _ModelVariable(name,
+ shape=None,
+ initializer=None,
+ collections=None,
+ trainable=None):
+ collections = list(collections or [])
+ collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
+ return variable_scope.get_variable(
+ name,
+ shape=shape,
+ initializer=initializer,
+ collections=collections,
+ trainable=trainable)
+
+
def LastValueQuantize(inputs,
per_channel=False,
init_min=-6.0,
@@ -93,13 +104,13 @@ def LastValueQuantize(inputs,
else:
min_max_shape = []
- min_var = model_variable(
+ min_var = _ModelVariable(
'min',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_min),
collections=[vars_collection],
trainable=False)
- max_var = model_variable(
+ max_var = _ModelVariable(
'max',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_max),
@@ -153,7 +164,6 @@ def LastValueQuantize(inputs,
narrow_range=narrow_range)
-@add_arg_scope
def MovingAvgQuantize(inputs,
per_channel=False,
init_min=-6.0,
@@ -202,13 +212,13 @@ def MovingAvgQuantize(inputs,
else:
min_max_shape = []
- min_var = model_variable(
+ min_var = _ModelVariable(
'min',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_min),
collections=[vars_collection],
trainable=False)
- max_var = model_variable(
+ max_var = _ModelVariable(
'max',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_max),
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index afb9de8370..5e63d33db8 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -461,8 +461,8 @@ class _LayerMatch(object):
return self._bias_add_op
-def _GetFollowingFakeQuantOp(tensor):
- """Returns the following FakeQuant op if it exists else None."""
+def _FollowedByFakeQuant(tensor):
+ """Returns True if the tensor is followed by a FakeQuant."""
fake_quant_ops = set([
'FakeQuantWithMinMaxVars', 'FakeQuantWithMinMaxArgs',
'FakeQuantWithMinMaxVarsPerChannel'
@@ -472,11 +472,11 @@ def _GetFollowingFakeQuantOp(tensor):
while consumers:
c = consumers.pop()
if c.type in fake_quant_ops:
- return c
+ return True
elif c.type in pass_through_ops:
for output in c.outputs:
consumers.extend(output.consumers())
- return None
+ return False
def _InsertQuantOp(context,
@@ -559,77 +559,44 @@ def _InsertQuantOp(context,
# Prevent ops from being quantized multiple times. Bypass ops can sometimes
# overlap between multiple matches, so we need to ensure that we don't
# add duplicate FakeQuant operations.
- fake_quant_op = _GetFollowingFakeQuantOp(inputs)
-
- # If we find that we are attempting to insert a fake quant op following
- # a fake quant, we skip inserting a fake quant op
-
- if fake_quant_op is None:
- if moving_avg:
- quant = (
- quant_ops.MovingAvgQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- ema_decay=ema_decay,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
- else:
- quant = (
- quant_ops.LastValueQuantize(
- inputs,
- init_min=init_min,
- init_max=init_max,
- is_training=is_training,
- num_bits=bits,
- narrow_range=narrow_range,
- vars_collection=vars_collection,
- name_prefix=name_prefix))
-
- if quant_delay and quant_delay > 0:
- activate_quant = math_ops.greater_equal(
- common.CreateOrGetQuantizationStep(),
- quant_delay,
- name=name_prefix + '/activate_quant')
- quant = control_flow_ops.cond(
- activate_quant,
- lambda: quant,
- lambda: inputs,
- name=name_prefix + '/delayed_quant')
+ if _FollowedByFakeQuant(inputs):
+ return
+
+ if moving_avg:
+ quant = (
+ quant_ops.MovingAvgQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ ema_decay=ema_decay,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
else:
- # If a fake quant op is present already, make sure that
- # any downstream use of the tensor reroutes to the appropriate quantized
- # tensor. If there is no quant_delay, this is simply the output of the
- # fake quant op. If there is a quant delay, we reroute to the output
- # of the delayed quant operation, which inserts quantization only after
- # a specified quant_delay
-
- quant = fake_quant_op.outputs[0]
- if quant_delay and quant_delay > 0:
- name_prefix = '/'.join(quant.name.split('/')[:-1])
- quant = quant.graph.get_tensor_by_name(name_prefix +
- '/delayed_quant/Merge:0')
- pruned_consumer_set = set()
- for consumer in consumers:
- fake_quant_dest_op = _GetFollowingFakeQuantOp(consumer.outputs[0])
- if (fake_quant_dest_op is None or
- fake_quant_dest_op.name != fake_quant_op.name):
- pruned_consumer_set.add(consumer)
- consumers = pruned_consumer_set
-
- # If we have
- # input->pass_through->fake_quant
- # there is nothing to reroute.
- #
- # If we have
- # input-> pass_through->fake_quant
- # |-> consumer
- # Then we reroute such that:
- # input-> pass_through->fake_quant
- # |-> consumer
+ quant = (
+ quant_ops.LastValueQuantize(
+ inputs,
+ init_min=init_min,
+ init_max=init_max,
+ is_training=is_training,
+ num_bits=bits,
+ narrow_range=narrow_range,
+ vars_collection=vars_collection,
+ name_prefix=name_prefix))
+
+ if quant_delay and quant_delay > 0:
+ activate_quant = math_ops.greater_equal(
+ common.CreateOrGetQuantizationStep(),
+ quant_delay,
+ name=name_prefix + '/activate_quant')
+ quant = control_flow_ops.cond(
+ activate_quant,
+ lambda: quant,
+ lambda: inputs,
+ name=name_prefix + '/delayed_quant')
+
if consumers:
tensors_modified_count = common.RerouteTensor(
quant, inputs, can_modify=consumers)
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index a9fc6c3c61..e80d2183a6 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import template
from tensorflow.python.platform import googletest
@@ -307,42 +306,6 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# No ops should be inserted or removed.
self.assertEqual(op_names_before_rewrite, op_names_after_rewrite)
- def testWithSharedWeights(self):
-
- self._RunTestOverAllRewrites(self._TestWithSharedWeights)
- self._RunTestOverTrainingRewrites(self._TestRewriteWithSharedWeights)
-
- def _TestRewriteWithSharedWeights(self, rewrite_fn, quant_delay=1):
- self._TestWithSharedWeights(rewrite_fn, quant_delay)
-
- def _TestWithSharedWeights(self, rewrite_fn, quant_delay=None):
- with ops.Graph().as_default() as g:
- conv = template.make_template('shared_weights_conv', self._ConvLayer)
- conv()
- conv()
- if quant_delay is None:
- rewrite_fn()
- else:
- rewrite_fn(quant_delay=quant_delay)
-
- conv_ops = [op for op in g.get_operations() if op.type == 'Conv2D']
- weights_quants = [
- op for op in g.get_operations()
- if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars'
- ]
- # Check that the shared weights variable is not quantized multiple times
- self.assertTrue(len(weights_quants) == 1)
- weights_quant_tensor = weights_quants[0].outputs[0]
- if quant_delay:
- delayed_weights_quants = [
- op for op in g.get_operations()
- if 'weights_quant' in op.name and op.type == 'Merge'
- ]
- self.assertTrue(len(delayed_weights_quants) == 1)
- weights_quant_tensor = delayed_weights_quants[0].outputs[0]
- # Check that the Conv2D operations get the quantized weights
- self.assertTrue(all(weights_quant_tensor in op.inputs for op in conv_ops))
-
def _ConvLayer(
self, input_tensor=None, scope='test', pre_activation_bypass=False,
post_activation_bypass=False):
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index 4e67d80558..1385a9ddc1 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -108,7 +108,7 @@ cuda_py_tests(
cuda_py_tests(
name = "core_rnn_cell_test",
- size = "small",
+ size = "medium",
srcs = ["python/kernel_tests/core_rnn_cell_test.py"],
additional_deps = [
":rnn_py",
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 6689664fb9..aa1d7d2b01 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -29,6 +29,9 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -40,7 +43,9 @@ from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import training
from tensorflow.python.util import nest
@@ -1115,6 +1120,138 @@ class RNNCellTest(test.TestCase):
r"input size \(3\) must be divisible by number_of_groups \(2\)"):
gcell(glstm_input, gcell_zero_state)
+ def testCFNCell(self):
+ with self.cached_session() as sess:
+ with variable_scope.variable_scope("root"):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.CFNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.17188203, 0.17188203]])
+ with variable_scope.variable_scope("other"):
+ # Test CFN with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.CFNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.15535763, 0.15535763]])
+
+ def testCFNCellEndToEnd(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = utils.to_categorical(y_train)
+ cell = contrib_rnn_cell.CFNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
+ def testMinimalRNNCell(self):
+ with self.cached_session() as sess:
+ with variable_scope.variable_scope(
+ "root"):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.MinimalRNNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.18899589, 0.18899589]])
+ with variable_scope.variable_scope(
+ "other"):
+ # Test MinimalRNN with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.MinimalRNNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.19554167, 0.19554167]])
+
+ def testMinimalRNNCellEndToEnd(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = utils.to_categorical(y_train)
+ cell = contrib_rnn_cell.MinimalRNNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 06c481672c..78cea8feb4 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -28,6 +28,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -3394,3 +3396,246 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
+
+
+class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
+ """MinimalRNN cell.
+
+ The implementation is based on:
+
+ https://arxiv.org/pdf/1806.05394v2.pdf
+
+ Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz.
+ "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal
+ Propagation in Recurrent Neural Networks." ICML, 2018.
+
+  A MinimalRNN cell first projects the input to the hidden space. The new
+  hidden state is then calculated as a weighted sum of the projected input and
+  the previous hidden state, using a single update gate.
+ """
+
+ def __init__(self,
+ units,
+ activation="tanh",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="ones",
+ name=None,
+ dtype=None,
+ **kwargs):
+ """Initialize the parameters for a MinimalRNN cell.
+
+ Args:
+ units: int, The number of units in the MinimalRNN cell.
+ activation: Nonlinearity to use in the feedforward network. Default:
+ `tanh`.
+ kernel_initializer: The initializer to use for the weight in the update
+ gate and feedforward network. Default: `glorot_uniform`.
+ bias_initializer: The initializer to use for the bias in the update
+ gate. Default: `ones`.
+ name: String, the name of the cell.
+ dtype: Default dtype of the cell.
+ **kwargs: Dict, keyword named properties for common cell attributes.
+ """
+ super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self.units = units
+ self.activation = activations.get(activation)
+ self.kernel_initializer = initializers.get(kernel_initializer)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ @property
+ def state_size(self):
+ return self.units
+
+ @property
+ def output_size(self):
+ return self.units
+
+ def build(self, inputs_shape):
+ if inputs_shape[-1] is None:
+ raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+ % str(inputs_shape))
+
+ input_size = inputs_shape[-1]
+ # pylint: disable=protected-access
+ # self._kernel contains W_x, W, V
+ self.kernel = self.add_weight(
+ name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_size + 2 * self.units, self.units],
+ initializer=self.kernel_initializer)
+ self.bias = self.add_weight(
+ name=rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self.units],
+ initializer=self.bias_initializer)
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Run one step of MinimalRNN.
+
+ Args:
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
+ state: state Tensor, must be 2-D, `[batch, state_size]`.
+
+ Returns:
+ A tuple containing:
+
+ - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
+ - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ input_size = inputs.get_shape()[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ feedforward_weight, gate_weight = array_ops.split(
+ value=self.kernel,
+ num_or_size_splits=[input_size.value, 2 * self.units],
+ axis=0)
+
+ feedforward = math_ops.matmul(inputs, feedforward_weight)
+ feedforward = self.activation(feedforward)
+
+ gate_inputs = math_ops.matmul(
+ array_ops.concat([feedforward, state], 1), gate_weight)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self.bias)
+ u = math_ops.sigmoid(gate_inputs)
+
+ new_h = u * state + (1 - u) * feedforward
+ return new_h, new_h
+
+
+class CFNCell(rnn_cell_impl.LayerRNNCell):
+ """Chaos Free Network cell.
+
+ The implementation is based on:
+
+ https://openreview.net/pdf?id=S1dIzvclg
+
+ Thomas Laurent, James von Brecht.
+ "A recurrent neural network without chaos." ICLR, 2017.
+
+  A CFN cell first projects the input to the hidden space. The hidden state
+  goes through a contractive mapping. The new hidden state is then calculated
+  as a linear combination of the projected input and the contracted previous
+  hidden state, using decoupled input and forget gates.
+ """
+
+ def __init__(self,
+ units,
+ activation="tanh",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="ones",
+ name=None,
+ dtype=None,
+ **kwargs):
+ """Initialize the parameters for a CFN cell.
+
+ Args:
+ units: int, The number of units in the CFN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ kernel_initializer: Initializer for the `kernel` weights
+ matrix. Default: `glorot_uniform`.
+ bias_initializer: The initializer to use for the bias in the
+ gates. Default: `ones`.
+ name: String, the name of the cell.
+ dtype: Default dtype of the cell.
+ **kwargs: Dict, keyword named properties for common cell attributes.
+ """
+ super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self.units = units
+ self.activation = activations.get(activation)
+ self.kernel_initializer = initializers.get(kernel_initializer)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ @property
+ def state_size(self):
+ return self.units
+
+ @property
+ def output_size(self):
+ return self.units
+
+ def build(self, inputs_shape):
+ if inputs_shape[-1] is None:
+ raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+ % str(inputs_shape))
+
+ input_size = inputs_shape[-1]
+ # pylint: disable=protected-access
+ # `self.kernel` contains V_{\theta}, V_{\eta}, W.
+ # `self.recurrent_kernel` contains U_{\theta}, U_{\eta}.
+ # `self.bias` contains b_{\theta}, b_{\eta}.
+ self.kernel = self.add_weight(
+ shape=[input_size, 3 * self.units],
+ name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ initializer=self.kernel_initializer)
+ self.recurrent_kernel = self.add_weight(
+ shape=[self.units, 2 * self.units],
+ name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ initializer=self.kernel_initializer)
+ self.bias = self.add_weight(
+ shape=[2 * self.units],
+ name=rnn_cell_impl._BIAS_VARIABLE_NAME,
+ initializer=self.bias_initializer)
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Run one step of CFN.
+
+ Args:
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
+ state: state Tensor, must be 2-D, `[batch, state_size]`.
+
+ Returns:
+ A tuple containing:
+
+ - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
+ - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ input_size = inputs.get_shape()[-1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ # The variable names u, v, w, b are consistent with the notations in the
+ # original paper.
+ v, w = array_ops.split(
+ value=self.kernel,
+ num_or_size_splits=[2 * self.units, self.units],
+ axis=1)
+ u = self.recurrent_kernel
+ b = self.bias
+
+ gates = math_ops.matmul(state, u) + math_ops.matmul(inputs, v)
+ gates = nn_ops.bias_add(gates, b)
+ gates = math_ops.sigmoid(gates)
+ theta, eta = array_ops.split(value=gates,
+ num_or_size_splits=2,
+ axis=1)
+
+ proj_input = math_ops.matmul(inputs, w)
+
+  # The input gate is (1 - eta), which is different from the original paper.
+  # This is for the purpose of initialization. With the default
+  # bias_initializer `ones`, the input gate is initialized to a small number.
+ new_h = theta * self.activation(state) + (1 - eta) * self.activation(
+ proj_input)
+
+ return new_h, new_h
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
index 360e7dbe75..7743f5b4a7 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -109,6 +109,42 @@ class SparsemaxLossTest(test.TestCase):
np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_loss, tf_loss_op)
+ def _test_sparsemax_loss_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax-loss transfers nan"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_nan = np.asarray([[0, np.nan, 0], [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]]).astype(dtype)
+
+ _, tf_loss_nan = self._tf_sparsemax_loss(z_nan, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan], tf_loss_nan)
+
+ def _test_sparsemax_loss_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax-loss is infinity safe"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, 0],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf,
+ np.inf], [np.inf, np.inf, 0],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, 0], [-np.inf, np.inf,
+ -np.inf]]).astype(dtype)
+
+ _, tf_loss_neg = self._tf_sparsemax_loss(z_neg, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([0.25, np.inf, 0, np.nan], tf_loss_neg)
+
+ _, tf_loss_pos = self._tf_sparsemax_loss(z_pos, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_pos)
+
+ _, tf_loss_mix = self._tf_sparsemax_loss(z_mix, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_mix)
+
def _test_constant_add(self, dtype, random, use_gpu):
"""check sparsemax-loss proposition 3"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -198,6 +234,10 @@ class SparsemaxLossTest(test.TestCase):
self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)
+ self._test_sparsemax_loss_of_nan(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_of_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 259e62bd86..c95b9da1e4 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase):
p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+ def _test_sparsemax_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax transfers nan"""
+ z_nan = np.asarray([
+ [0, np.nan, 0],
+ [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan],
+ ]).astype(dtype)
+
+ _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_nan)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax is infinity safe"""
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, -np.inf]]).astype(dtype)
+
+ _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg)
+
+ _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_pos)
+
+ _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_mix)
+
+
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 1"""
z = np.zeros((1, 10))
@@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase):
self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
- def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 2"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase):
self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
- self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_nan(dtype, random, use_gpu=False)
self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_permutation(dtype, random, use_gpu=False)
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index e617af2ff1..f79c93f347 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -49,7 +49,14 @@ def sparsemax(logits, name=None):
obs = array_ops.shape(logits)[0]
dims = array_ops.shape(logits)[1]
- z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # The mean(logits) can be subtracted from logits to make the algorithm
+ # more numerically stable. The instability in this algorithm comes mostly
+ # from the z_cumsum. Subtracting the mean will cause z_cumsum to be close
+ # to zero. However, in practice the numerical instability issues are very
+ # minor and subtracting the mean causes extra issues with inf and nan
+ # input.
+ z = logits
# sort z
z_sorted, _ = nn.top_k(z, k=dims)
@@ -64,10 +71,24 @@ def sparsemax(logits, name=None):
k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
# calculate tau(z)
- indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ # If there are inf values or all values are -inf, the k_z will be zero,
+ # this is mathematically invalid and will also cause the gather_nd to fail.
+ # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then
+ # fixed later (see p_safe) by returning p = nan. This results in the same
+ # behavior as softmax.
+ k_z_safe = math_ops.maximum(k_z, 1)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1)
tau_sum = array_ops.gather_nd(z_cumsum, indices)
tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
# calculate p
- return math_ops.maximum(
+ p = math_ops.maximum(
math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis])
+ # If k_z = 0 or if z = nan, then the input is invalid
+ p_safe = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])),
+ array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)),
+ p)
+
+ return p_safe
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
index 582d1e6136..c0438f16bc 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None):
sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
labels = ops.convert_to_tensor(labels, name="labels")
- shifted_logits = logits - \
- math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # A constant can be subtracted from logits to make the algorithm
+ # more numerically stable in theory. However, there is really no major
+ # source of numerical instability in this algorithm.
+ z = logits
# sum over support
- support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
- sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+ # Use a conditional where instead of a multiplication to support z = -inf.
+ # If z = -inf, and there is no support (sparsemax = 0), a multiplication
+ # would cause 0 * -inf = nan, which is not correct in this case.
+ sum_s = array_ops.where(
+ math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)),
+ sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax))
# - z_k + ||q||^2
- q_part = labels * (0.5 * labels - shifted_logits)
+ q_part = labels * (0.5 * labels - z)
+ # Fix the case where labels = 0 and z = -inf, where q_part would
+ # otherwise be 0 * -inf = nan. But since labels = 0, no cost for
+ # z = -inf should be considered.
+ # The code below also covers the case where z = inf. However, in this
+ # case the sparsemax will be nan, which means the sum_s will also be nan,
+ # therefore this case doesn't need additional special treatment.
+ q_part_safe = array_ops.where(
+ math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)),
+ array_ops.zeros_like(z), q_part)
- return math_ops.reduce_sum(sum_s + q_part, axis=1)
+ return math_ops.reduce_sum(sum_s + q_part_safe, axis=1)
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index dcbef2881d..e9ddec8889 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -9,19 +9,16 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-tf_gen_op_wrapper_py(
- name = "stateless_random_ops",
- out = "gen_stateless_random_ops.py", # cmake chokes without this
- deps = ["//tensorflow/core:stateless_random_ops_op_lib"],
-)
-
py_library(
name = "stateless",
- srcs = ["__init__.py"],
+ srcs = [
+ "__init__.py",
+ "python/stateless_ops.py",
+ ],
srcs_version = "PY2AND3",
deps = [
- ":stateless_random_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index 0cca40f071..30d0a7ab6a 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -33,14 +33,8 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
-from tensorflow.contrib.stateless.gen_stateless_random_ops import *
+from tensorflow.contrib.stateless.python.stateless_ops import *
-from tensorflow.python.framework import ops
from tensorflow.python.util.all_util import remove_undocumented
-ops.NotDifferentiable("StatelessMultinomial")
-ops.NotDifferentiable("StatelessRandomNormal")
-ops.NotDifferentiable("StatelessRandomUniform")
-ops.NotDifferentiable("StatelessTruncatedNormal")
-
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
index d724a5c014..ec5a13b7c6 100644
--- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
+++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
import numpy as np
from tensorflow.contrib import stateless
from tensorflow.python.framework import constant_op
@@ -27,10 +29,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform),
- (stateless.stateless_random_normal, random_ops.random_normal),
- (stateless.stateless_truncated_normal, random_ops.truncated_normal)]
-
def invert_philox(key, value):
"""Invert the Philox bijection."""
@@ -51,90 +49,30 @@ def invert_philox(key, value):
class StatelessOpsTest(test.TestCase):
- def testMatchStateful(self):
+ def _test_match(self, cases):
# Stateless ops should be the same as stateful ops on the first call
# after seed scrambling.
+ cases = tuple(cases)
key = 0x3ec8f720, 0x02461e29
for seed in (7, 17), (11, 5), (2, 3):
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
preseed = preseed[::2] | preseed[1::2] << 32
random_seed.set_random_seed(seed[0])
with self.test_session(use_gpu=True):
- for stateless_op, stateful_op in CASES:
- for shape in (), (3,), (2, 5):
- stateful = stateful_op(shape, seed=seed[1])
- pure = stateless_op(shape, seed=preseed)
- self.assertAllEqual(stateful.eval(), pure.eval())
+ for stateless_op, stateful_op in cases:
+ stateful = stateful_op(seed=seed[1])
+ pure = stateless_op(seed=preseed)
+ self.assertAllEqual(stateful.eval(), pure.eval())
- def testDeterminism(self):
+ def _test_determinism(self, cases):
# Stateless values should be equal iff the seeds are equal (roughly)
+ cases = tuple(cases)
with self.test_session(use_gpu=True):
for seed_type in [dtypes.int32, dtypes.int64]:
seed_t = array_ops.placeholder(seed_type, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(shape, seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testShapeType(self):
- with self.test_session(use_gpu=True):
- for shape_dtype in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype),
- seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testMatchStatefulMultinomial(self):
- # Stateless ops should be the same as stateful ops on the first call
- # after seed scrambling.
- key = 0x3ec8f720, 0x02461e29
- num_samples = 4
- for logits_dtype in np.float16, np.float32, np.float64:
- for output_dtype in dtypes.int32, dtypes.int64:
- for seed in (7, 17), (11, 5), (2, 3):
- preseed = invert_philox(key,
- (seed[0], 0, seed[1], 0)).astype(np.uint64)
- preseed = preseed[::2] | preseed[1::2] << 32
- random_seed.set_random_seed(seed[0])
- with self.test_session(use_gpu=True):
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- logits_t = constant_op.constant(logits, dtype=logits_dtype)
- stateful = random_ops.multinomial(
- logits_t,
- num_samples,
- seed=seed[1],
- output_dtype=output_dtype)
- pure = stateless.stateless_multinomial(
- logits_t,
- num_samples,
- seed=preseed,
- output_dtype=output_dtype)
- self.assertAllEqual(stateful.eval(), pure.eval())
-
- def testDeterminismMultinomial(self):
- # Stateless values should be equal iff the seeds are equal (roughly)
- num_samples = 10
- with self.test_session(use_gpu=True):
- for seed_type in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(seed_type, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- pure = stateless.stateless_multinomial(
- logits, num_samples, seed=seed_t)
+ for stateless_op, _ in cases:
+ pure = stateless_op(seed=seed_t)
values = [
(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
]
@@ -142,6 +80,74 @@ class StatelessOpsTest(test.TestCase):
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
+ def _float_cases(self, shape_dtypes=(None,)):
+ float_cases = (
+ # Uniform distribution, with and without range
+ (stateless.stateless_random_uniform, random_ops.random_uniform, {}),
+ (stateless.stateless_random_uniform, random_ops.random_uniform,
+ dict(minval=2.2, maxval=7.1)),
+ # Normal distribution, with and without mean+stddev
+ (stateless.stateless_random_normal, random_ops.random_normal, {}),
+ (stateless.stateless_random_normal, random_ops.random_normal,
+ dict(mean=2, stddev=3)),
+ # Truncated normal distribution, with and without mean+stddev
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal,
+ dict(mean=3, stddev=4)),
+ )
+ for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for stateless_op, stateful_op, kwds in float_cases:
+ kwds = dict(shape=shape, dtype=dtype, **kwds)
+ yield (functools.partial(stateless_op, **kwds),
+ functools.partial(stateful_op, **kwds))
+
+ def _int_cases(self, shape_dtypes=(None,)):
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for dtype in dtypes.int32, dtypes.int64:
+ kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape)
+ yield (functools.partial(stateless.stateless_random_uniform, **kwds),
+ functools.partial(random_ops.random_uniform, **kwds))
+
+ def _multinomial_cases(self):
+ num_samples = 10
+ for logits_dtype in np.float16, np.float32, np.float64:
+ for output_dtype in dtypes.int32, dtypes.int64:
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ kwds = dict(
+ logits=constant_op.constant(logits, dtype=logits_dtype),
+ num_samples=num_samples,
+ output_dtype=output_dtype)
+ yield (functools.partial(stateless.stateless_multinomial, **kwds),
+ functools.partial(random_ops.multinomial, **kwds))
+
+ def testMatchFloat(self):
+ self._test_match(self._float_cases())
+
+ def testMatchInt(self):
+ self._test_match(self._int_cases())
+
+ def testMatchMultinomial(self):
+ self._test_match(self._multinomial_cases())
+
+ def testDeterminismFloat(self):
+ self._test_determinism(
+ self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismInt(self):
+ self._test_determinism(
+ self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismMultinomial(self):
+ self._test_determinism(self._multinomial_cases())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/stateless/python/stateless_ops.py b/tensorflow/contrib/stateless/python/stateless_ops.py
new file mode 100644
index 0000000000..1449825c83
--- /dev/null
+++ b/tensorflow/contrib/stateless/python/stateless_ops.py
@@ -0,0 +1,214 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless random ops which take seed as a tensor input."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_stateless_random_ops
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import math_ops
+
+ops.NotDifferentiable("StatelessMultinomial")
+ops.NotDifferentiable("StatelessRandomNormal")
+ops.NotDifferentiable("StatelessRandomUniform")
+ops.NotDifferentiable("StatelessRandomUniformInt")
+ops.NotDifferentiable("StatelessTruncatedNormal")
+
+
+def stateless_random_uniform(shape,
+ seed,
+ minval=0,
+ maxval=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a uniform distribution.
+
+ This is a stateless version of `tf.random_uniform`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a uniform distribution in the range
+ `[minval, maxval)`. The lower bound `minval` is included in the range, while
+ the upper bound `maxval` is excluded.
+
+ For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
+ be specified explicitly.
+
+ In the integer case, the random integers are slightly biased unless
+ `maxval - minval` is an exact power of two. The bias is small for values of
+ `maxval - minval` significantly smaller than the range of the output (either
+ `2**32` or `2**64`).
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
+ range of random values to generate. Defaults to 0.
+ maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
+ range of random values to generate. Defaults to 1 if `dtype` is floating
+ point.
+  dtype: The type of the output: `float16`, `bfloat16`, `float32`, `float64`,
+    `int32`, or `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random uniform values.
+
+ Raises:
+ ValueError: If `dtype` is integral and `maxval` is not specified.
+ """
+ dtype = dtypes.as_dtype(dtype)
+ if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
+ dtypes.float64, dtypes.int32, dtypes.int64):
+ raise ValueError("Invalid dtype %r" % dtype)
+ if maxval is None:
+ if dtype.is_integer:
+ raise ValueError("Must specify maxval for integer dtype %r" % dtype)
+ maxval = 1
+ with ops.name_scope(name, "stateless_random_uniform",
+ [shape, seed, minval, maxval]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+ maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
+ if dtype.is_integer:
+ return gen_stateless_random_ops.stateless_random_uniform_int(
+ shape, seed=seed, minval=minval, maxval=maxval, name=name)
+ else:
+ rnd = gen_stateless_random_ops.stateless_random_uniform(
+ shape, seed=seed, dtype=dtype)
+ return math_ops.add(rnd * (maxval - minval), minval, name=name)
+
+
+def stateless_random_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a normal distribution.
+
+ This is a stateless version of `tf.random_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
+ distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random normal values.
+ """
+ with ops.name_scope(name, "stateless_random_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_truncated_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values, truncated normally distributed.
+
+ This is a stateless version of `tf.truncated_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a normal distribution with specified mean and
+ standard deviation, except that values whose magnitude is more than 2 standard
+ deviations from the mean are dropped and re-picked.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution, before truncation.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random truncated normal values.
+ """
+ with ops.name_scope(name, "stateless_truncated_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_truncated_normal(
+ shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_multinomial(logits,
+ num_samples,
+ seed,
+ output_dtype=dtypes.int64,
+ name=None):
+ """Draws deterministic pseudorandom samples from a multinomial distribution.
+
+ This is a stateless version of `tf.multinomial`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Example:
+
+ ```python
+ # samples has shape [1, 5], where each value is either 0 or 1 with equal
+ # probability.
+ samples = tf.contrib.stateless.stateless_multinomial(
+ tf.log([[10., 10.]]), 5, seed=[7, 17])
+ ```
+
+ Args:
+ logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice
+ `[i, :]` represents the unnormalized log-probabilities for all classes.
+ num_samples: 0-D. Number of independent samples to draw for each row slice.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+    output_dtype: Integer type to use for the output. Defaults to int64.
+    name: Optional name for the operation.
+
+ Returns:
+ The drawn samples of shape `[batch_size, num_samples]`.
+ """
+ with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
+ logits = ops.convert_to_tensor(logits, name="logits")
+ return gen_stateless_random_ops.stateless_multinomial(
+ logits, num_samples, seed, output_dtype=output_dtype)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 9e8979bce4..5c16fcb760 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -455,7 +455,6 @@ cuda_py_tests(
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
"test/rank_two_test.py",
- "test/unary_test.py",
"test/vgg_block_nchw_test.py",
"test/vgg_block_test.py",
],
@@ -471,6 +470,25 @@ cuda_py_tests(
],
)
+cuda_py_tests(
+ name = "tf_trt_integration_test_no_oss",
+ srcs = [
+ "test/unary_test.py",
+ ],
+ additional_deps = [
+ ":tf_trt_integration_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_oss", # TODO(b/117274186): re-enable in OSS after crash fixed
+ "no_pip", # TODO(b/117274186): re-enable in OSS after crash fixed
+ "no_windows",
+ "nomac",
+ ],
+)
+
cc_library(
name = "utils",
srcs = ["convert/utils.cc"],
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index cb1f707028..c230919168 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -159,12 +159,7 @@ py_test(
],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = [
- "no_pip_gpu", # b/63391119
- "noasan", # b/116875897
- "nomsan",
- "notsan",
- ],
+ tags = ["no_pip_gpu"], # b/63391119
deps = [
":estimators",
":feature_keys",
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 0c4bdab191..10ed1c2891 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -135,6 +135,9 @@ tf_gen_op_wrapper_py(
name = "tpu_ops",
hidden = [
"SendTPUEmbeddingGradients",
+ "EnqueueTPUEmbeddingIntegerBatch",
+ "EnqueueTPUEmbeddingSparseBatch",
+ "EnqueueTPUEmbeddingSparseTensorBatch",
],
deps = [
":cross_replica_ops_op_lib",
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 6b0730b40c..0ef29bdf73 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -46,7 +46,7 @@ namespace tensorflow {
// 5. TPUEmbeddingActivations, when used with appropriate Python libraries,
// enables the automatic differentiation of models that use embeddings.
// 6. TPUEmbeddingSendGradients takes a list of Tensors (of the same shapes
-// as those returned by TPUEmbeddingReceivActivations) containing gradients
+// as those returned by TPUEmbeddingReceiveActivations) containing gradients
// to use in updating the embedding tables.
// 7. Before saving a checkpoint, use the TPUEmbeddingRetrieve Op to update
// the Graph's embedding table Variables from the updated tables in the
@@ -104,9 +104,18 @@ Status RegisterPerTableLoadOpsForAlgorithmBody(
}
}
{
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
+ }
+ {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -138,9 +147,11 @@ parameters that are loaded from a checkpoint before a training loop is
executed.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto.
+ TPUEmbeddingConfiguration proto (overrides table_id).
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -149,10 +160,14 @@ shard_id: Identifier of shard for this operation.
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- if (table_name.empty()) {
- return errors::InvalidArgument("table_name attribute must be set");
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
@@ -226,9 +241,18 @@ Status RegisterPerTableRetrieveOpsForAlgorithmBody(
}
}
{
+ auto* table_id_attr = op_def->add_attr();
+ table_id_attr->set_name("table_id");
+ table_id_attr->set_type("int");
+ table_id_attr->set_has_minimum(true);
+ table_id_attr->set_minimum(-1);
+ table_id_attr->mutable_default_value()->set_i(-1);
+ }
+ {
auto* table_name_attr = op_def->add_attr();
table_name_attr->set_name("table_name");
table_name_attr->set_type("string");
+ table_name_attr->mutable_default_value()->set_s("");
}
{
auto* num_shards_attr = op_def->add_attr();
@@ -259,9 +283,11 @@ the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto.
+ TPUEmbeddingConfiguration proto (overrides table_id).
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
+table_id: Index of this table in the EmbeddingLayerConfiguration proto
+ (deprecated).
)doc",
parameter_descriptions.c_str()));
op_def->set_is_commutative(false);
@@ -270,10 +296,14 @@ shard_id: Identifier of shard for this operation.
auto shape_inference_function =
[state_variable_specs,
is_debug_op](shape_inference::InferenceContext* c) -> Status {
+ int table_id;
+ TF_RETURN_IF_ERROR(c->GetAttr("table_id", &table_id));
string table_name;
TF_RETURN_IF_ERROR(c->GetAttr("table_name", &table_name));
- if (table_name.empty()) {
- return errors::InvalidArgument("table_name must be non-empty");
+ // Exactly one must be non-default.
+ if ((table_id >= 0) == (!table_name.empty())) {
+ return errors::InvalidArgument(
+ "exactly one of table_id or table_name must be non-default");
}
int num_shards;
TF_RETURN_IF_ERROR(c->GetAttr("num_shards", &num_shards));
@@ -305,7 +335,6 @@ void RegisterPerTableLoadAndRetrieveOps() {
tpu::GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
- // TODO(gkurian): Condition this on being used internally within Google.
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return RegisterPerTableLoadOpsForAlgorithmBody(alg, true,
@@ -323,7 +352,6 @@ void RegisterPerTableLoadAndRetrieveOps() {
tpu::GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
- // TODO(gkurian): Condition this on being used internally within Google.
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, true,
@@ -336,7 +364,7 @@ void RegisterPerTableLoadAndRetrieveOps() {
} // namespace
REGISTER_OP("RecvTPUEmbeddingActivations")
- .Output("outputs: num_outputs * float")
+ .Output("outputs: num_outputs * float32")
.Attr("num_outputs: int >= 1")
.Attr("config: string")
.SetIsStateful()
@@ -446,7 +474,8 @@ config: Serialized TPUEmbeddingConfiguration proto.
REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
.Input("batch: N * int32")
- .Attr("N: int")
+ .Input("mode_override: string")
+ .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
@@ -455,6 +484,10 @@ An op that enqueues a list of input batch tensors to TPUEmbedding.
batch: A list of 1D tensors, one for each embedding table, containing the
indices into the tables.
+mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+ 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
device_ordinal: The TPU device to use. Should be >= 0 and less than the number
of TPU cores in the task on which the node is placed.
)doc");
@@ -463,7 +496,8 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
.Input("sample_indices: N * int32")
.Input("embedding_indices: N * int32")
.Input("aggregation_weights: N * float32")
- .Attr("N: int")
+ .Input("mode_override: string")
+ .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")
.SetIsStateful()
@@ -493,14 +527,18 @@ The tensors at corresponding positions in the three input lists
must have the same shape, i.e. rank 1 with dim_size() equal to the total
number of lookups into the table described by the corresponding table_id.
-sample_indices: A list of Rank 1 Tensors specifying the training example and
+sample_indices: A list of rank 1 Tensors specifying the training example and
feature to which the corresponding embedding_indices and aggregation_weights
values belong. sample_indices[i] must equal b * nf + f, where nf is the
number of features from the corresponding table, f is in [0, nf), and
b is in [0, batch size).
-embedding_indices: A list of Rank 1 Tensors, indices into the embedding tables.
-aggregation_weights: A list of Rank 1 Tensors containing per sample -- i.e. per
+embedding_indices: A list of rank 1 Tensors, indices into the embedding tables.
+aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e. per
(training example, feature) -- aggregation weights.
+mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+ 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
device_ordinal: The TPU device to use. Should be >= 0 and less than the number
of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that specify
@@ -515,7 +553,8 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
.Input("sample_indices: N * int32")
.Input("embedding_indices: N * int32")
.Input("aggregation_weights: N * float32")
- .Attr("N: int")
+ .Input("mode_override: string")
+ .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")
.Attr("table_ids: list(int)")
@@ -525,7 +564,7 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
This Op eases the porting of code that uses tf.nn.embedding_lookup_sparse().
sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond
-to ith feature. table_ids[i] indicates which embedding table to look up ith
+to the ith feature. table_ids[i] indicates which embedding table to look up ith
feature.
The tensors at corresponding positions in the three input lists (sample_indices,
@@ -533,12 +572,18 @@ embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1
with dim_size() equal to the total number of lookups into the table described by
the corresponding feature.
-sample_indices: A list of Rank 1 Tensors, corresponds to sp_ids.indices[:,0] in
+sample_indices: A list of rank 1 Tensors specifying the training example to
+ which the corresponding embedding_indices and aggregation_weights values
+ belong. It corresponds to sp_ids.indices[:,0] in embedding_lookup_sparse().
+embedding_indices: A list of rank 1 Tensors, indices into the embedding tables.
+ It corresponds to sp_ids.values in embedding_lookup_sparse().
+aggregation_weights: A list of rank 1 Tensors containing per training example
+ aggregation weights. It corresponds to sp_weights.values in
embedding_lookup_sparse().
-embedding_indices: A list of Rank 1 Tensors, corresponds to sp_ids.values
- in embedding_lookup_sparse().
-aggregation_weights: A list of Rank 1 Tensors, corresponds to sp_weights.values
- in embedding_lookup_sparse().
+mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+ 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
device_ordinal: The TPU device to use. Should be >= 0 and less than the number
of TPU cores in the task on which the node is placed.
combiners: A list of string scalars, one for each embedding table that specify
@@ -547,8 +592,11 @@ combiners: A list of string scalars, one for each embedding table that specify
the sum of the weights be 0 for 'mean' or the sum of the squared weights be
0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
all tables.
-table_ids: A list of int. table_ids[i] indicates which embedding table to look
- up ith feature in the list.
+table_ids: A list of integers specifying the identifier of the embedding table
+ (offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the
+ corresponding input. The ith input is looked up using table_ids[i]. The size
+ of the table_ids list must be equal to that of sample_indices,
+ embedding_indices and aggregation_weights.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index 8e6e9aa0cd..1c5ea2d997 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -237,7 +237,8 @@ void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
MonitorResponse response;
TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response)));
- std::cout << "Xprof Monitoring Results (Sample " << query + 1 << "):\n\n"
+ std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
+ << "):\n\n"
<< response.data() << std::flush;
}
}
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 2415c46718..f27ae38e04 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.11.0'
+_VERSION = '1.12.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index f88dc51636..1e66801efd 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -168,6 +168,12 @@ message RunEnvironmentResult {
optional HostIndependentJobInfoResult host_independent_job_info = 5;
// Host-dependent job information.
repeated HostDependentJobInfoResult host_dependent_job_info = 6;
+ // The number of replicas, corresponds to input parallelism.
+ // If there is no model parallelism, replica_count = tpu_core_count
+ optional int32 replica_count = 7;
+ // The number of cores used for a single replica, e.g. model parallelism.
+ // If there is no model parallelism, then num_cores_per_replica = 1
+ optional int32 num_cores_per_replica = 8;
}
// The types of host operations that are tracked.
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index 90d34b5ef1..4b6d1b2b07 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.11.0"
+#define TPU_PROFILER_VERSION "1.12.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index a43f45554f..8529b48c15 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -62,7 +62,10 @@ message FtrlParameters {
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
// order to get correct results; a warning will be printed otherwise (which may
-// change to an error in the future).
+// change to an error in the future). If use_max_with_epsilon is set, the Adam
+// variable update formula will be changed from m / (sqrt(v) + epsilon) to
+// m / max(sqrt(v), abs(epsilon)); this option improves the performance of TPU
+// training and is not expected to harm model quality.
message AdamParameters {
float beta1 = 3;
float beta2 = 4;
@@ -70,6 +73,7 @@ message AdamParameters {
float initial_m = 6;
float initial_v = 7;
bool use_non_lazy_adam = 8;
+ bool use_max_with_epsilon = 9;
}
// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index e2e4acadab..968adccf2b 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -227,6 +227,154 @@ if platform.system() != "Windows":
inputs=inputs, learning_rates=learning_rates, config=config, name=name)
+ send_tpu_embedding_gradients.__doc__ = (
+ gen_tpu_ops._send_tpu_embedding_gradients.__doc__)
+
+ # pylint: disable=protected-access
+ def enqueue_tpu_embedding_integer_batch(batch,
+ device_ordinal,
+ mode_override=None,
+ name=None):
+ """A placeholder op for enqueueing embedding IDs to the TPU.
+
+ Args:
+ batch: A list of 1D tensors, one for each embedding table, containing the
+ indices into the tables.
+ device_ordinal: The TPU device to use. Should be >= 0 and less than the
+ number of TPU cores in the task on which the node is placed.
+ mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified',
+ 'inference', 'training', 'backward_pass_only'}. When set to
+ 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+ otherwise mode_override is used (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ An EnqueueTPUEmbeddingIntegerBatch operation.
+ """
+ if mode_override is None:
+ mode_override = "unspecified"
+ return gen_tpu_ops._enqueue_tpu_embedding_integer_batch(
+ batch=batch,
+ device_ordinal=device_ordinal,
+ mode_override=mode_override,
+ name=name)
+
+ enqueue_tpu_embedding_integer_batch.__doc__ = (
+ gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__)
+
+ # pylint: disable=protected-access
+ def enqueue_tpu_embedding_sparse_batch(sample_indices,
+ embedding_indices,
+ aggregation_weights,
+ device_ordinal,
+ combiners=None,
+ mode_override=None,
+ name=None):
+ """A placeholder op for enqueueing embedding IDs to the TPU.
+
+ Args:
+ sample_indices: A list of rank 1 Tensors specifying the training example
+ and feature to which the corresponding embedding_indices and
+ aggregation_weights values belong. sample_indices[i] must equal b * nf +
+ f, where nf is the number of features from the corresponding table, f is
+ in [0, nf), and b is in [0, batch size).
+ embedding_indices: A list of rank 1 Tensors, indices into the embedding
+ tables.
+ aggregation_weights: A list of rank 1 Tensors containing per sample --
+ i.e. per (training example, feature) -- aggregation weights.
+ device_ordinal: The TPU device to use. Should be >= 0 and less than the
+ number of TPU cores in the task on which the node is placed.
+ combiners: A list of string scalars, one for each embedding table that
+ specify how to normalize the embedding activations after weighted
+ summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
+ invalid to have the sum of the weights be 0 for 'mean' or the sum of the
+ squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
+ is to use 'sum' for all tables (optional).
+ mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified',
+ 'inference', 'training', 'backward_pass_only'}. When set to
+ 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+ otherwise mode_override is used (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ An EnqueueTPUEmbeddingSparseBatch operation.
+ """
+ if mode_override is None:
+ mode_override = "unspecified"
+ return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch(
+ sample_indices=sample_indices,
+ embedding_indices=embedding_indices,
+ aggregation_weights=aggregation_weights,
+ device_ordinal=device_ordinal,
+ combiners=combiners,
+ mode_override=mode_override,
+ name=name)
+
+ enqueue_tpu_embedding_sparse_batch.__doc__ = (
+ gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__)
+
+ # pylint: disable=protected-access
+ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
+ embedding_indices,
+ aggregation_weights,
+ table_ids,
+ device_ordinal,
+ combiners=None,
+ mode_override=None,
+ name=None):
+ """A placeholder op for enqueueing embedding IDs to the TPU.
+
+ Args:
+ sample_indices: A list of rank 1 Tensors specifying the training example
+ to which the corresponding embedding_indices and aggregation_weights
+ values
+ belong. It corresponds to sp_ids.indices[:,0] in
+ embedding_lookup_sparse().
+ embedding_indices: A list of rank 1 Tensors, indices into the embedding
+ tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
+ aggregation_weights: A list of rank 1 Tensors containing per training
+ example aggregation weights. It corresponds to sp_weights.values in
+ embedding_lookup_sparse().
+ table_ids: A list of integers specifying the identifier of the embedding
+ table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
+ lookup the corresponding input. The ith input is looked up using
+ table_ids[i]. The size of the table_ids list must be equal to that of
+ sample_indices, embedding_indices and aggregation_weights.
+ device_ordinal: The TPU device to use. Should be >= 0 and less than the
+ number of TPU cores in the task on which the node is placed.
+ combiners: A list of string scalars, one for each embedding table that
+ specify how to normalize the embedding activations after weighted
+ summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
+ invalid to have the sum of the weights be 0 for 'mean' or the sum of the
+ squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
+ is to use 'sum' for all tables (optional).
+ mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified',
+ 'inference', 'training', 'backward_pass_only'}. When set to
+ 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+ otherwise mode_override is used (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ An EnqueueTPUEmbeddingSparseTensorBatch operation.
+ """
+ if mode_override is None:
+ mode_override = "unspecified"
+ return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch(
+ sample_indices=sample_indices,
+ embedding_indices=embedding_indices,
+ aggregation_weights=aggregation_weights,
+ table_ids=table_ids,
+ device_ordinal=device_ordinal,
+ combiners=combiners,
+ mode_override=mode_override,
+ name=name)
+
+ enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
+ gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
+
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
index d879170b68..c694e9c1bc 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 696656e840..af183b3232 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -46,6 +46,7 @@ from __future__ import print_function
import abc
import collections
+import contextlib
import re
import sys
import time
@@ -94,21 +95,56 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+# TODO(b/114775106): temporary shim to optionally initialize the TPU
+# This increases the odds our session is initialized, but shouldn't be needed.
+def _maybe_initialize_tpu(session):
+ """Initialize the TPU if it has not already been initialized."""
+ try:
+
+ def test_op():
+ return constant_op.constant(1) + constant_op.constant(1)
+
+ session.run(tpu.rewrite(test_op))
+ except errors.FailedPreconditionError as _:
+ session.run(tpu.initialize_system())
+
+
+@contextlib.contextmanager
+def _tpu_session_context():
+ """Initialize the TPU and cleans cache entries for bad sessions."""
+ try:
+ _maybe_initialize_tpu(K.get_session())
+ yield
+ except (errors.FailedPreconditionError, errors.AbortedError) as e:
+ K.clear_session()
+ raise Exception("""
+An error occurred connecting or initializing your TPU.
+
+The session has been reset. re-run keras_to_tpu_model to create a new session.
+""" + e)
+
+
def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
master = cluster_resolver.master()
# Use the existing session if we're already connected to this TPU
- if (K.get_session()._target == master and
- getattr(K.get_session(), '_tpu_initialized', None)):
- return
+  # N.B. K.get_session() is a non-trivial operation, and may fail if the remote
+ # session has been reset.
+ try:
+ default_session = K.get_session()
+ if (default_session._target == master and
+ getattr(default_session, '_tpu_initialized', None)):
+ return
+ except errors.AbortedError as _:
+ # We lost the remote session and need to re-initialize.
+ logging.warning('Lost remote session: creating a new session.')
cluster_spec = cluster_resolver.cluster_spec()
config = config_pb2.ConfigProto(isolate_session_state=True)
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Initialize')
tpu_session = tf_session.Session(target=master, config=config)
tpu_session.run(tpu.initialize_system())
tpu_session._tpu_initialized = True
@@ -1391,97 +1427,74 @@ class KerasTPUModel(models.Model):
raise EnvironmentError('KerasTPUModel currently does not support eager '
'mode.')
- assert not self._numpy_to_infeed_manager_list # Ensure empty.
-
- infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(x):
- with ops.device('/job:%s/device:CPU:0' %
- self._tpu_assignment.worker_name):
- dataset = x()
- if steps_per_epoch is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps_per_epoch argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
+ with _tpu_session_context():
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ dataset = x()
+ if steps_per_epoch is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps_per_epoch argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must '
+ 'be None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ if isinstance(validation_data, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(validation_data):
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ model_fn_lib.ModeKeys.EVAL)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
- if isinstance(validation_data, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(validation_data):
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
-
- self._numpy_to_infeed_manager_list = infeed_managers
- try:
- if not kwargs.get('_pipeline', True):
- logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
- kwargs.pop('_pipeline')
- return super(KerasTPUModel, self).fit(
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs)
- return self._pipeline_fit(
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs)
- finally:
- self._numpy_to_infeed_manager_list = []
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ if not kwargs.get('_pipeline', True):
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
+ kwargs.pop('_pipeline')
+ return super(KerasTPUModel, self).fit(
+ x, y, batch_size, epochs, verbose, callbacks, validation_split,
+ validation_data, shuffle, class_weight, sample_weight,
+ initial_epoch, steps_per_epoch, validation_steps, **kwargs)
+ return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle,
+ class_weight, sample_weight, initial_epoch,
+ steps_per_epoch, validation_steps, **kwargs)
+ finally:
+ self._numpy_to_infeed_manager_list = []
def evaluate(self,
x=None,
@@ -1492,37 +1505,38 @@ class KerasTPUModel(models.Model):
steps=None):
assert not self._numpy_to_infeed_manager_list # Ensure empty.
- infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(x):
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
-
- self._numpy_to_infeed_manager_list = infeed_managers
- try:
- return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
- sample_weight, steps)
- finally:
- self._numpy_to_infeed_manager_list = []
+ with _tpu_session_context():
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
+ finally:
+ self._numpy_to_infeed_manager_list = []
def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
validation_split, validation_data, shuffle, class_weight,
@@ -1910,6 +1924,24 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ def predict(self,
+ x,
+ batch_size=None,
+ verbose=0,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
+ with _tpu_session_context():
+ return super(KerasTPUModel, self).predict(
+ x,
+ batch_size=batch_size,
+ verbose=verbose,
+ steps=steps,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
@property
def optimizer(self):
if self._tpu_model:
@@ -1966,6 +1998,9 @@ class KerasTPUModel(models.Model):
logging.info('Setting weights on TPU model.')
cloned_model.set_weights(weights)
+ if self._tpu_model.optimizer is None:
+ # tpu_model may not be compiled, e.g., loading weights and then predict.
+ return
for k, v in six.iteritems(cpu_optimizer_config):
opt_var = getattr(self._tpu_model.optimizer, k)
if isinstance(opt_var, variables.Variable):
@@ -2020,6 +2055,10 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
+ def load_weights(self, filepath, by_name=False):
+ self._cpu_model.load_weights(filepath, by_name)
+ self._tpu_weights_initialized = False
+
# pylint: disable=bad-continuation
def _validate_shapes(model):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 7cfb6c38fa..da6bdf67d6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -154,6 +154,20 @@ class TPUContext(object):
# as far as model is replicated to all cores in the system.
return self._internal_ctx.device_for_replica(replica_id)
+ @property
+ def tpu_host_placement_function(self):
+ """Returns the TPU host place function.
+
+ The place function takes host_id as the input and returns the TF device
+  for the corresponding host.
+ """
+
+ def _placement_function(host_id):
+ """Return the host device given host_id."""
+ return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
+
+ return _placement_function
+
class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 3aa5b6efa1..8d15c857f8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -177,14 +177,29 @@ def _create_or_get_iterations_per_loop():
use_resource=True)
-def _sync_variables_ops():
- # Gets the variables back from TPU nodes. This means the variables updated
- # by TPU will now be *synced* to host memory.
- return [
- array_ops.check_numerics(v.read_value(),
- 'Gradient for %s is NaN' % v.name).op
- for v in variables.trainable_variables()
- ]
+def _sync_variables_ops(ctx):
+ """Create varriables synchronization ops.
+
+ Gets the variables back from TPU nodes. This means the variables updated
+ by TPU will now be *synced* to host memory.
+  In BROADCAST mode, we skip this sync since the variables are usually too
+ big to transmit via RPC.
+
+ Args:
+ ctx: A `_InternalTPUContext` instance with mode.
+
+ Returns:
+ A list of sync ops.
+ """
+
+ if not ctx.is_input_broadcast_with_iterators():
+ return [
+ array_ops.check_numerics(v.read_value(),
+ 'Gradient for %s is NaN' % v.name).op
+ for v in variables.trainable_variables()
+ ]
+ else:
+ return [control_flow_ops.no_op()]
def _increase_eval_step_op(iterations_per_loop):
@@ -2567,7 +2582,7 @@ class TPUEstimator(estimator_lib.Estimator):
summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
with ops.control_dependencies([loss]):
- update_ops = _sync_variables_ops()
+ update_ops = _sync_variables_ops(ctx)
# Validate the TPU training graph to catch basic errors
_validate_tpu_training_graph()
@@ -2600,7 +2615,7 @@ class TPUEstimator(estimator_lib.Estimator):
# After TPU evaluation computation is done (the mean_loss tensor),
# reads all variables back from TPU and updates the eval step
# counter properly
- internal_ops_to_run = _sync_variables_ops()
+ internal_ops_to_run = _sync_variables_ops(ctx)
internal_ops_to_run.append(
_increase_eval_step_op(iterations_per_loop_var))
with ops.control_dependencies(internal_ops_to_run):
@@ -2645,7 +2660,7 @@ class TPUEstimator(estimator_lib.Estimator):
scaffold, prediction_hooks) = _predict_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
with ops.control_dependencies([dummy_predict_op]):
- internal_ops_to_run = _sync_variables_ops()
+ internal_ops_to_run = _sync_variables_ops(ctx)
with ops.control_dependencies(internal_ops_to_run):
dummy_predict_op = control_flow_ops.no_op()
diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md
index 639e708169..b6514e19dc 100644
--- a/tensorflow/contrib/tpu/tpu_estimator.md
+++ b/tensorflow/contrib/tpu/tpu_estimator.md
@@ -87,7 +87,7 @@ handle training:
label = tf.cast(features["label"], tf.int32)
return image, label
- dataset = tf.contrib.data.TFRecordDataset(
+ dataset = tf.data.TFRecordDataset(
filename, buffer_size=FLAGS.dataset_reader_buffer_size)
dataset = dataset.map(parser).cache().repeat().batch(batch_size)
images, labels = dataset.make_one_shot_iterator().get_next()
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index b565ebd073..00295f57f6 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -295,7 +295,6 @@ py_test(
tags = ["notsan"],
deps = [
":training_py",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
@@ -305,6 +304,7 @@ py_test(
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index d9b0511a98..c1657fec7b 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors