aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--configure.py21
-rw-r--r--tensorflow/c/eager/c_api_test_util.cc2
-rw-r--r--tensorflow/cc/BUILD1
-rw-r--r--tensorflow/cc/gradients/nn_grad.cc27
-rw-r--r--tensorflow/cc/gradients/nn_grad_test.cc27
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc77
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc15
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h3
-rw-r--r--tensorflow/compiler/tests/BUILD29
-rw-r--r--tensorflow/compiler/tests/nullary_ops_test.py43
-rw-r--r--tensorflow/compiler/tests/permute_test.py80
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc2
-rw-r--r--tensorflow/compiler/tests/sort_ops_test.py18
-rw-r--r--tensorflow/compiler/tests/tensor_list_ops_test.py96
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc157
-rw-r--r--tensorflow/compiler/tf2xla/kernels/BUILD4
-rw-r--r--tensorflow/compiler/tf2xla/kernels/binary_ops.cc10
-rw-r--r--tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc63
-rw-r--r--tensorflow/compiler/tf2xla/kernels/const_op.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.cc57
-rw-r--r--tensorflow/compiler/tf2xla/kernels/cwise_ops.h3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/permute_op.cc98
-rw-r--r--tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc21
-rw-r--r--tensorflow/compiler/tf2xla/kernels/scan_ops.cc3
-rw-r--r--tensorflow/compiler/tf2xla/kernels/sort_ops.cc17
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc226
-rw-r--r--tensorflow/compiler/tf2xla/lib/BUILD16
-rw-r--r--tensorflow/compiler/tf2xla/lib/broadcast.cc93
-rw-r--r--tensorflow/compiler/tf2xla/lib/broadcast.h32
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.cc213
-rw-r--r--tensorflow/compiler/tf2xla/lib/scatter.h6
-rw-r--r--tensorflow/compiler/tf2xla/ops/xla_ops.cc25
-rw-r--r--tensorflow/compiler/tf2xla/python/xla.py18
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc11
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc40
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h5
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc18
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h6
-rw-r--r--tensorflow/compiler/xla/literal.cc38
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.cc5
-rw-r--r--tensorflow/compiler/xla/python/local_computation_builder.h2
-rw-r--r--tensorflow/compiler/xla/python/xla_client.py25
-rw-r--r--tensorflow/compiler/xla/service/BUILD79
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc6
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc13
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc27
-rw-r--r--tensorflow/compiler/xla/service/fusion_queue.h53
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc2
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD12
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc40
-rw-r--r--tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h11
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc26
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc11
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc23
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_test.cc52
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc12
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.cc39
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc487
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h25
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier_test.cc67
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h28
-rw-r--r--tensorflow/compiler/xla/service/interpreter/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/interpreter/compiler.cc2
-rw-r--r--tensorflow/compiler/xla/service/map_inliner.cc (renamed from tensorflow/compiler/xla/service/inliner.cc)51
-rw-r--r--tensorflow/compiler/xla/service/map_inliner.h (renamed from tensorflow/compiler/xla/service/inliner.h)22
-rw-r--r--tensorflow/compiler/xla/service/map_inliner_test.cc (renamed from tensorflow/compiler/xla/service/inliner_test.cc)46
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h16
-rw-r--r--tensorflow/compiler/xla/shape_util.cc7
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc9
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.cc14
-rw-r--r--tensorflow/compiler/xla/tests/hlo_test_base.h8
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc12
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py1
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt14
-rw-r--r--tensorflow/contrib/cmake/README.md345
-rw-r--r--tensorflow/contrib/distribute/python/BUILD18
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py16
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py2
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py5
-rw-r--r--tensorflow/contrib/distribute/python/moving_averages_test.py141
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py4
-rw-r--r--tensorflow/contrib/distribute/python/values.py4
-rw-r--r--tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/boosted_trees.py6
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn.py54
-rw-r--r--tensorflow/contrib/feature_column/BUILD21
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py72
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py280
-rw-r--r--tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py912
-rw-r--r--tensorflow/contrib/gdr/gdr_memory_manager.cc156
-rw-r--r--tensorflow/contrib/lite/BUILD26
-rw-r--r--tensorflow/contrib/lite/c/builtin_op_data.h7
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc15
-rw-r--r--tensorflow/contrib/lite/delegates/flex/BUILD4
-rw-r--r--tensorflow/contrib/lite/delegates/flex/delegate.cc9
-rw-r--r--tensorflow/contrib/lite/experimental/micro/BUILD76
-rw-r--r--tensorflow/contrib/lite/experimental/micro/README.md114
-rw-r--r--tensorflow/contrib/lite/experimental/micro/compatibility.h32
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD31
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc55
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc1672
-rw-r--r--tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h27
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/BUILD107
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc43
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h34
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc208
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc406
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc184
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc643
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc213
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc220
-rw-r--r--tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h170
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc78
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h34
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc (renamed from tensorflow/compiler/xla/service/gpu/gpu_options.h)28
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc310
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_interpreter.h71
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc197
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc80
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h46
-rw-r--r--tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc83
-rw-r--r--tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc149
-rw-r--r--tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h51
-rw-r--r--tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc144
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/BUILD17
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill21
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc36
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl67
-rw-r--r--tensorflow/contrib/lite/experimental/micro/testing/micro_test.h138
-rwxr-xr-xtensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh54
-rwxr-xr-xtensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh39
-rw-r--r--tensorflow/contrib/lite/experimental/micro/tools/make/Makefile166
-rwxr-xr-xtensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh73
-rw-r--r--tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc65
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/android_build.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/index.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md18
-rw-r--r--tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md18
-rw-r--r--tensorflow/contrib/lite/interpreter.h15
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc6
-rw-r--r--tensorflow/contrib/lite/java/BUILD95
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl5
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java20
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java46
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java14
-rw-r--r--tensorflow/contrib/lite/kernels/BUILD32
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc333
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc1
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc11
-rw-r--r--tensorflow/contrib/lite/kernels/internal/BUILD16
-rw-r--r--tensorflow/contrib/lite/kernels/internal/compatibility.h23
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.cc598
-rw-r--r--tensorflow/contrib/lite/kernels/internal/kernel_utils.h184
-rw-r--r--tensorflow/contrib/lite/kernels/internal/legacy_types.h (renamed from tensorflow/compiler/xla/service/gpu/gpu_options.cc)18
-rw-r--r--tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc2
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h7
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h5
-rw-r--r--tensorflow/contrib/lite/kernels/internal/types.h6
-rw-r--r--tensorflow/contrib/lite/kernels/lstm.cc300
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_eval.cc912
-rw-r--r--tensorflow/contrib/lite/kernels/lstm_eval.h79
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc235
-rw-r--r--tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc158
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc320
-rw-r--r--tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc11
-rw-r--r--tensorflow/contrib/lite/model.cc35
-rw-r--r--tensorflow/contrib/lite/model_flex_test.cc45
-rw-r--r--tensorflow/contrib/lite/model_test.cc22
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs8
-rwxr-xr-xtensorflow/contrib/lite/schema/schema_generated.h162
-rw-r--r--tensorflow/contrib/lite/testdata/multi_add_flex.binbin0 -> 1052 bytes
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc101
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc140
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc37
-rw-r--r--tensorflow/contrib/lite/toco/model.h9
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.cc189
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export.h19
-rw-r--r--tensorflow/contrib/lite/toco/tflite/export_test.cc111
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc32
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.h6
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/BUILD24
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc12
-rw-r--r--tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h6
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo.py16
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py8
-rw-r--r--tensorflow/contrib/optimizer_v2/BUILD11
-rw-r--r--tensorflow/contrib/optimizer_v2/adadelta.py75
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad.py79
-rw-r--r--tensorflow/contrib/optimizer_v2/adagrad_test.py3
-rw-r--r--tensorflow/contrib/optimizer_v2/adam.py129
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py68
-rw-r--r--tensorflow/contrib/optimizer_v2/gradient_descent.py40
-rw-r--r--tensorflow/contrib/optimizer_v2/momentum.py69
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py1205
-rw-r--r--tensorflow/contrib/optimizer_v2/rmsprop.py154
-rw-r--r--tensorflow/contrib/optimizer_v2/rmsprop_test.py7
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py65
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py129
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py40
-rw-r--r--tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py48
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax.py27
-rw-r--r--tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py28
-rw-r--r--tensorflow/contrib/stateless/BUILD5
-rw-r--r--tensorflow/contrib/stateless/__init__.py9
-rw-r--r--tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py154
-rw-r--r--tensorflow/contrib/stateless/python/stateless_ops.py214
-rw-r--r--tensorflow/contrib/tensorrt/BUILD20
-rw-r--r--tensorflow/contrib/tpu/BUILD1
-rw-r--r--tensorflow/contrib/tpu/profiler/tf_op_stats.proto6
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto6
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py7
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py37
-rw-r--r--tensorflow/core/BUILD1
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt46
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt10
-rw-r--r--tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Substr.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc1
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc28
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h6
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc4
-rw-r--r--tensorflow/core/common_runtime/eager/context.h2
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc67
-rw-r--r--tensorflow/core/common_runtime/eval_const_tensor.cc18
-rw-r--r--tensorflow/core/common_runtime/executor.cc4
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc9
-rw-r--r--tensorflow/core/common_runtime/shape_refiner.cc5
-rw-r--r--tensorflow/core/framework/shape_inference.cc9
-rw-r--r--tensorflow/core/framework/shape_inference.h9
-rw-r--r--tensorflow/core/graph/graph.cc13
-rw-r--r--tensorflow/core/graph/graph.h5
-rw-r--r--tensorflow/core/graph/node_builder.cc8
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc6
-rw-r--r--tensorflow/core/grappler/op_types.cc22
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc130
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc3
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD10
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h19
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h15
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h44
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc283
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc254
-rw-r--r--tensorflow/core/grappler/optimizers/graph_optimizer.h21
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc75
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h15
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc62
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc162
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h4
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc76
-rw-r--r--tensorflow/core/kernels/BUILD7
-rw-r--r--tensorflow/core/kernels/data/BUILD14
-rw-r--r--tensorflow/core/kernels/data/concatenate_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc47
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h20
-rw-r--r--tensorflow/core/kernels/data/dataset_utils_test.cc46
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc162
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc187
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc62
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc68
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc79
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc17
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.h2
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc2
-rw-r--r--tensorflow/core/kernels/dequantize_op.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc7
-rw-r--r--tensorflow/core/kernels/random_op.cc34
-rw-r--r--tensorflow/core/kernels/relu_op.cc153
-rw-r--r--tensorflow/core/kernels/relu_op.h61
-rw-r--r--tensorflow/core/kernels/relu_op_functor.h30
-rw-r--r--tensorflow/core/kernels/relu_op_gpu.cu.cc18
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc6
-rw-r--r--tensorflow/core/kernels/stateless_random_ops.cc155
-rw-r--r--tensorflow/core/kernels/string_util.cc4
-rw-r--r--tensorflow/core/kernels/string_util.h44
-rw-r--r--tensorflow/core/kernels/substr_op.cc162
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc100
-rw-r--r--tensorflow/core/kernels/unique_op.cc15
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt380
-rw-r--r--tensorflow/core/ops/dataset_ops.cc11
-rw-r--r--tensorflow/core/ops/math_ops.cc19
-rw-r--r--tensorflow/core/ops/math_ops_test.cc12
-rw-r--r--tensorflow/core/ops/nn_ops.cc15
-rw-r--r--tensorflow/core/ops/ops.pbtxt153
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc3
-rw-r--r--tensorflow/core/ops/stateless_random_grad.cc23
-rw-r--r--tensorflow/core/ops/stateless_random_ops.cc53
-rw-r--r--tensorflow/core/ops/string_ops.cc1
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto4
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md2426
-rw-r--r--tensorflow/go/op/wrappers.go852
-rw-r--r--tensorflow/python/BUILD17
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py8
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py1
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py38
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py19
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py4
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/experimental/benchmarks/BUILD25
-rw-r--r--tensorflow/python/data/experimental/benchmarks/map_benchmark.py (renamed from tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py)114
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/BUILD545
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py686
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py322
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/bucketing_test.py824
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py)417
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/counter_test.py51
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py692
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py124
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py)26
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py247
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py199
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py367
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py115
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py239
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py)425
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py243
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py368
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py12
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py234
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/resample_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/scan_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/BUILD22
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py)0
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py85
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py (renamed from tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py)3
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/unbatch_test.py300
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/unique_test.py (renamed from tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py)4
-rw-r--r--tensorflow/python/data/experimental/ops/map_defun.py8
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py2
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py126
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py29
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/benchmarks_test.py4
-rw-r--r--tensorflow/python/eager/core_test.py28
-rw-r--r--tensorflow/python/eager/function.py196
-rw-r--r--tensorflow/python/eager/function_test.py99
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc56
-rw-r--r--tensorflow/python/feature_column/feature_column.py53
-rw-r--r--tensorflow/python/framework/importer.py3
-rw-r--r--tensorflow/python/framework/op_def_library.py3
-rw-r--r--tensorflow/python/framework/ops.py28
-rw-r--r--tensorflow/python/framework/test_util.py15
-rwxr-xr-xtensorflow/python/keras/BUILD155
-rw-r--r--tensorflow/python/keras/activations.py5
-rw-r--r--tensorflow/python/keras/activations_test.py10
-rw-r--r--tensorflow/python/keras/backend.py84
-rw-r--r--tensorflow/python/keras/backend_test.py44
-rw-r--r--tensorflow/python/keras/callbacks.py4
-rw-r--r--tensorflow/python/keras/engine/input_layer.py1
-rw-r--r--tensorflow/python/keras/engine/network.py19
-rw-r--r--tensorflow/python/keras/engine/topology_test.py31
-rw-r--r--tensorflow/python/keras/engine/training.py6
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py4
-rw-r--r--tensorflow/python/keras/engine/training_test.py4
-rw-r--r--tensorflow/python/keras/layers/convolutional.py177
-rw-r--r--tensorflow/python/keras/layers/convolutional_test.py31
-rw-r--r--tensorflow/python/keras/layers/pooling.py185
-rw-r--r--tensorflow/python/keras/layers/pooling_test.py30
-rw-r--r--tensorflow/python/keras/layers/wrappers.py3
-rw-r--r--tensorflow/python/keras/metrics.py14
-rw-r--r--tensorflow/python/keras/metrics_test.py26
-rw-r--r--tensorflow/python/keras/optimizer_v2/adadelta.py116
-rw-r--r--tensorflow/python/keras/optimizer_v2/adadelta_test.py166
-rw-r--r--tensorflow/python/keras/optimizer_v2/adagrad.py119
-rw-r--r--tensorflow/python/keras/optimizer_v2/adagrad_test.py276
-rw-r--r--tensorflow/python/keras/optimizer_v2/adam.py203
-rw-r--r--tensorflow/python/keras/optimizer_v2/adam_test.py333
-rw-r--r--tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py761
-rw-r--r--tensorflow/python/keras/optimizer_v2/optimizer_v2.py1349
-rw-r--r--tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py277
-rw-r--r--tensorflow/python/keras/optimizer_v2/rmsprop.py239
-rw-r--r--tensorflow/python/keras/optimizer_v2/rmsprop_test.py444
-rw-r--r--tensorflow/python/keras/optimizer_v2/sgd.py170
-rw-r--r--tensorflow/python/keras/optimizer_v2/sgd_test.py759
-rw-r--r--tensorflow/python/keras/testing_utils.py5
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py45
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils.py17
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py26
-rw-r--r--tensorflow/python/keras/utils/np_utils.py5
-rw-r--r--tensorflow/python/kernel_tests/BUILD4
-rw-r--r--tensorflow/python/kernel_tests/batch_gather_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/benchmark_test.py2
-rw-r--r--tensorflow/python/kernel_tests/bincount_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/cholesky_op_test.py7
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py50
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py14
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/determinant_op_test.py9
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/matrix_band_part_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_exponential_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_inverse_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_logarithm_op_test.py3
-rw-r--r--tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/matrix_solve_op_test.py5
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py120
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py9
-rw-r--r--tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py3
-rw-r--r--tensorflow/python/kernel_tests/substr_op_test.py503
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py5
-rw-r--r--tensorflow/python/ops/array_ops.py23
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py48
-rw-r--r--tensorflow/python/ops/control_flow_ops.py3
-rw-r--r--tensorflow/python/ops/control_flow_ops_benchmark.py122
-rw-r--r--tensorflow/python/ops/custom_gradient.py44
-rw-r--r--tensorflow/python/ops/gradients.py1
-rw-r--r--tensorflow/python/ops/gradients_impl.py97
-rw-r--r--tensorflow/python/ops/gradients_test.py34
-rw-r--r--tensorflow/python/ops/image_ops_test.py62
-rw-r--r--tensorflow/python/ops/nn_grad.py15
-rw-r--r--tensorflow/python/ops/nn_ops.py3
-rw-r--r--tensorflow/python/ops/parsing_ops.py13
-rw-r--r--tensorflow/python/ops/string_ops.py16
-rw-r--r--tensorflow/python/ops/while_v2.py3
-rw-r--r--tensorflow/python/platform/benchmark.py14
-rwxr-xr-xtensorflow/python/pywrap_tfe.i1
-rw-r--r--tensorflow/python/training/monitored_session.py8
-rw-r--r--tensorflow/python/training/moving_averages.py49
-rw-r--r--tensorflow/python/util/protobuf/compare.py18
-rw-r--r--tensorflow/python/util/util.cc8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt4
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.android5
-rwxr-xr-xtensorflow/tools/ci_build/ci_parameterized_build.sh29
-rwxr-xr-xtensorflow/tools/ci_build/install/install_auditwheel.sh4
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh4
-rw-r--r--tensorflow/tools/pip_package/setup.py17
-rwxr-xr-xtensorflow/workspace.bzl47
-rw-r--r--third_party/gpus/cuda_configure.bzl1979
-rw-r--r--third_party/jpeg/BUILD2
-rw-r--r--third_party/jpeg/BUILD.bazel (renamed from third_party/jpeg/jpeg.BUILD)11
-rw-r--r--third_party/jpeg/BUILD.system (renamed from third_party/systemlibs/jpeg.BUILD)0
-rw-r--r--third_party/jpeg/jpeg_helpers.BUILD.bazel1
-rw-r--r--third_party/jpeg/workspace.bzl16
-rw-r--r--third_party/nasm/BUILD1
-rw-r--r--third_party/nasm/BUILD.bazel (renamed from third_party/nasm.BUILD)12
-rw-r--r--third_party/nasm/BUILD.system (renamed from third_party/systemlibs/nasm.BUILD)0
-rw-r--r--third_party/nasm/workspace.bzl17
-rw-r--r--third_party/nccl/LICENSE231
-rw-r--r--third_party/nccl/archive.BUILD179
-rw-r--r--third_party/nccl/build_defs.bzl.tpl210
-rw-r--r--third_party/nccl/nccl_archive.BUILD68
-rw-r--r--third_party/nccl/nccl_configure.bzl214
550 files changed, 29244 insertions, 14656 deletions
diff --git a/configure.py b/configure.py
index a88fdb3555..89dc79b6b6 100644
--- a/configure.py
+++ b/configure.py
@@ -35,7 +35,6 @@ except ImportError:
_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_CUDNN_VERSION = '7'
-_DEFAULT_NCCL_VERSION = '2.2'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,7.0'
_DEFAULT_CUDA_PATH = '/usr/local/cuda'
_DEFAULT_CUDA_PATH_LINUX = '/opt/cuda'
@@ -384,7 +383,9 @@ def set_build_var(environ_cp,
var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default)))
environ_cp[var_name] = var
if var == '1':
- write_to_bazelrc('build --define %s=true' % option_name)
+ write_to_bazelrc(
+ 'build:%s --define %s=true' % (bazel_config_name, option_name))
+ write_to_bazelrc('build --config=%s' % bazel_config_name)
elif bazel_config_name is not None:
# TODO(mikecase): Migrate all users of configure.py to use --config Bazel
# options and not to set build configs through environment variables.
@@ -1109,18 +1110,17 @@ def set_tf_nccl_install_path(environ_cp):
raise ValueError('Currently NCCL is only supported on Linux platforms.')
ask_nccl_version = (
- 'Please specify the NCCL version you want to use. If NCCL %s is not '
- 'installed, then you can use version 1.3 that can be fetched '
- 'automatically but it may have worse performance with multiple GPUs. '
- '[Default is %s]: ') % (_DEFAULT_NCCL_VERSION, _DEFAULT_NCCL_VERSION)
+ 'Please specify the locally installed NCCL version you want to use. '
+ '[Default is to use https://github.com/nvidia/nccl]: ')
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_nccl_version = get_from_env_or_user_or_default(
- environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, _DEFAULT_NCCL_VERSION)
- tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
+ environ_cp, 'TF_NCCL_VERSION', ask_nccl_version, '')
- if tf_nccl_version == '1':
- break # No need to get install path, NCCL 1 is a GitHub repo.
+ if not tf_nccl_version:
+ break # No need to get install path, building the open source code.
+
+ tf_nccl_version = reformat_version_sequence(str(tf_nccl_version), 1)
# Look with ldconfig first if we can find the library in paths
# like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
@@ -1232,7 +1232,6 @@ def set_tf_nccl_install_path(environ_cp):
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
-
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc
index 5607c9dcb0..008f088c2d 100644
--- a/tensorflow/c/eager/c_api_test_util.cc
+++ b/tensorflow/c/eager/c_api_test_util.cc
@@ -99,8 +99,6 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
TFE_OpAddInput(op, b, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteStatus(status);
- TFE_OpSetAttrBool(op, "transpose_a", 0);
- TFE_OpSetAttrBool(op, "transpose_b", 0);
TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
return op;
diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD
index b587e63227..9d2208d84d 100644
--- a/tensorflow/cc/BUILD
+++ b/tensorflow/cc/BUILD
@@ -411,6 +411,7 @@ tf_cc_test(
srcs = ["gradients/nn_grad_test.cc"],
deps = [
":cc_ops",
+ ":cc_ops_internal",
":grad_op_registry",
":grad_testutil",
":gradient_checker",
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc
index 588e96cb19..2a32a2ed6f 100644
--- a/tensorflow/cc/gradients/nn_grad.cc
+++ b/tensorflow/cc/gradients/nn_grad.cc
@@ -143,6 +143,33 @@ Status Relu6GradHelper(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Relu6", Relu6GradHelper);
+Status LeakyReluGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(0),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyRelu", LeakyReluGradHelper);
+
+Status LeakyReluGradGradHelper(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ float alpha;
+ TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "alpha", &alpha));
+ internal::LeakyReluGrad::Attrs attrs;
+ auto dx = internal::LeakyReluGrad(scope, grad_inputs[0], op.input(1),
+ attrs.Alpha(alpha));
+ grad_outputs->push_back(dx);
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("LeakyReluGrad", LeakyReluGradGradHelper);
+
Status EluGradHelper(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc
index aa72cf7ba2..f5a09e09dc 100644
--- a/tensorflow/cc/gradients/nn_grad_test.cc
+++ b/tensorflow/cc/gradients/nn_grad_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/gradient_checker.h"
#include "tensorflow/cc/framework/testutil.h"
#include "tensorflow/cc/gradients/grad_testutil.h"
+#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -160,6 +161,32 @@ TEST_F(NNGradTest, Relu6Grad) {
RunTest(x, x_init_value, y, shape);
}
+TEST_F(NNGradTest, LeakyReluGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ auto y = ops::internal::LeakyRelu(scope_, x);
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ RunTest(x, x_init_value, y, shape);
+}
+
+TEST_F(NNGradTest, LeakyReluGradGrad) {
+ TensorShape shape({5, 2});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
+ // Avoid input values where Leaky ReLU gradient is not well defined (around
+ // zero).
+ Tensor x_init_value = test::AsTensor<float>(
+ {2.3f, 1.9f, 1.5f, 1.1f, 0.7f, 0.3f, -0.1f, -0.5f, -0.9f, -1.3f}, {5, 2});
+ Tensor features = test::AsTensor<float>(
+ {-0.9f, -0.7f, -0.5f, -0.3f, -0.1f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f},
+ {5, 2});
+ auto y = ops::internal::LeakyReluGrad(scope_, x, features);
+ RunTest(x, x_init_value, y, shape);
+}
+
TEST_F(NNGradTest, EluGrad) {
TensorShape shape({5, 2});
auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index e0b9932d80..b7ae7fbeb3 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
+#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -579,7 +580,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
- std::vector<Predicate*> GetIncomingPreds(Node* n, EdgeKind edge_kind);
+ Status GetInputPreds(Node* n, EdgeKind edge_kind,
+ std::vector<Predicate*>* result);
// Sets the predicate for output `output_idx` of `n` to `pred`. Sets the i'th
// bit of `should_revisit` if `pred` is different from the current predicate
@@ -625,9 +627,10 @@ TensorId InputEdgeToTensorId(const Edge* e) {
return TensorId(e->src()->name(), e->src_output());
}
-std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
- Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind) {
- std::vector<Predicate*> incoming_preds;
+Status DeadnessAnalysisImpl::GetInputPreds(
+ Node* n, DeadnessAnalysisImpl::EdgeKind edge_kind,
+ std::vector<Predicate*>* result) {
+ result->clear();
for (const Edge* in_edge : n->in_edges()) {
bool should_process =
edge_kind == EdgeKind::kDataAndControl ||
@@ -636,17 +639,27 @@ std::vector<Predicate*> DeadnessAnalysisImpl::GetIncomingPreds(
if (should_process) {
auto it = predicate_map_.find(InputEdgeToTensorId(in_edge));
- CHECK(it != predicate_map_.end()) << n->name();
- incoming_preds.push_back(it->second);
+ if (it == predicate_map_.end()) {
+ GraphCycles graph_cycles;
+ TF_RETURN_IF_ERROR(CreateCycleDetectionGraph(&graph_, &graph_cycles));
+
+ // If we didn't return with an error above then the graph is probably
+ // fine and we have a bug in deadness analysis.
+ return errors::Internal("Could not find input ", in_edge->DebugString(),
+ " to ", n->name(),
+ " when visiting the graph in post-order. Most "
+ "likely indicates a bug in deadness analysis.");
+ }
+ result->push_back(it->second);
}
}
- return incoming_preds;
+ return Status::OK();
}
Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
std::vector<bool>* should_revisit) {
- std::vector<Predicate*> input_preds =
- GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
const Edge* pred_edge;
TF_RETURN_IF_ERROR(n->input_edge(1, &pred_edge));
Predicate* true_switch = predicate_factory_.MakeSymbolPredicate(
@@ -675,17 +688,31 @@ Status DeadnessAnalysisImpl::HandleSwitch(Node* n,
}
namespace {
-const Edge* FindUniqueBackedge(Node* merge) {
+Status CreateMultipleNextIterationInputsError(Node* merge) {
+ std::vector<string> backedges;
+ for (const Edge* backedge : merge->in_edges()) {
+ if (backedge->src()->IsNextIteration()) {
+ backedges.push_back(absl::StrCat(" ", SummarizeNode(*backedge->src())));
+ }
+ }
+ return errors::InvalidArgument(
+ "Multiple NextIteration inputs to merge node ", SummarizeNode(*merge),
+ ": \n", absl::StrJoin(backedges, "\n"),
+ "\nMerge nodes can have at most one incoming NextIteration edge.");
+}
+
+Status FindUniqueBackedge(Node* merge, const Edge** result) {
+ *result = nullptr;
CHECK(merge->IsMerge());
- const Edge* result = nullptr;
for (const Edge* e : merge->in_edges()) {
if (e->src()->IsNextIteration()) {
- CHECK_EQ(result, nullptr)
- << "Multiple backedges to " << merge->DebugString();
- result = e;
+ if (*result != nullptr) {
+ return CreateMultipleNextIterationInputsError(merge);
+ }
+ *result = e;
}
}
- return result;
+ return Status::OK();
}
// If `backedge_predicate` is equal to `symbolic_predicate` & Step where Step
@@ -764,9 +791,12 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
return Status::OK();
}
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds));
+
// We're visiting this merge for the first time and it is a acyclic merge.
- Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(
- GetIncomingPreds(n, EdgeKind::kDataOnly));
+ Predicate* input_data_pred =
+ predicate_factory_.MakeOrPredicate(input_preds);
SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred,
should_revisit);
return Status::OK();
@@ -777,7 +807,9 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// of an unvisited backedge. Try to pattern match the predicate expression
// for that backedge (which should be visited now) into an and recurrence
// for the merge node.
- if (const Edge* unique_backedge = FindUniqueBackedge(n)) {
+ const Edge* unique_backedge;
+ TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &unique_backedge));
+ if (unique_backedge) {
if (Predicate* step = DeduceStepPredicate(
&predicate_factory_, it->second,
predicate_map_[InputEdgeToTensorId(unique_backedge)])) {
@@ -808,8 +840,8 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
std::vector<bool>* should_revisit) {
// In addition to being alive or dead based on the inputs, a _Recv can also
// acquire a dead signal from a _Send.
- std::vector<Predicate*> input_preds =
- GetIncomingPreds(n, EdgeKind::kDataAndControl);
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
input_preds.push_back(predicate_factory_.MakeSymbolPredicate(
TensorId(n->name(), 0), /*must_be_true=*/false));
SetPredicate(n, {0, Graph::kControlSlot},
@@ -821,8 +853,9 @@ Status DeadnessAnalysisImpl::HandleRecv(Node* n,
Status DeadnessAnalysisImpl::HandleGeneric(Node* n,
std::vector<bool>* should_revisit) {
// Generally nodes are alive iff all their inputs are alive.
- Predicate* pred = predicate_factory_.MakeAndPredicate(
- GetIncomingPreds(n, EdgeKind::kDataAndControl));
+ std::vector<Predicate*> input_preds;
+ TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataAndControl, &input_preds));
+ Predicate* pred = predicate_factory_.MakeAndPredicate(input_preds);
for (int output_idx = 0; output_idx < n->num_outputs(); output_idx++) {
SetPredicate(n, output_idx, pred, should_revisit);
}
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index e083652978..af83c792e5 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -75,9 +75,8 @@ XlaTransferManager::XlaTransferManager(
}
}
-Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
- Tensor* device_tensor,
- bool buffer_is_fresh) const {
+Status XlaTransferManager::TransferLiteralToDevice(
+ const Tensor& host_tensor, Tensor* device_tensor) const {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
@@ -98,11 +97,8 @@ Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
}
- xla::TransferManager::TransferToDeviceHint hint =
- buffer_is_fresh ? xla::TransferManager::kBufferUndefined
- : xla::TransferManager::kNoHint;
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_.get(), *literal, shaped_buffer, hint));
+ host_to_device_stream_.get(), *literal, shaped_buffer));
if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
@@ -169,7 +165,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
- bool buffer_is_fresh = false;
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
@@ -178,7 +173,6 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
done(s);
return;
}
- buffer_is_fresh = true;
}
Status status;
@@ -189,8 +183,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
- status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
- buffer_is_fresh);
+ status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index a4c0c296fc..df82421294 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -67,8 +67,7 @@ class XlaTransferManager {
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
- Tensor* device_tensor,
- bool buffer_is_fresh) const;
+ Tensor* device_tensor) const;
void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor,
const StatusCallback& done) const;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 822fedf121..ba2401ed26 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -895,6 +895,22 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "tensor_list_ops_test",
+ size = "small",
+ srcs = ["tensor_list_ops_test.py"],
+ # TensorList ops are not implemented in the on-demand compilation model yet.
+ disabled_backends = "cpu_ondemand",
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:list_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python/eager:function",
+ ],
+)
+
+tf_xla_py_test(
name = "ternary_ops_test",
size = "small",
srcs = ["ternary_ops_test.py"],
@@ -1029,6 +1045,19 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "permute_test",
+ size = "small",
+ srcs = ["permute_test.py"],
+ deps = [
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:nn_ops",
+ ],
+)
+
+tf_xla_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
diff --git a/tensorflow/compiler/tests/nullary_ops_test.py b/tensorflow/compiler/tests/nullary_ops_test.py
index f985c5d2d9..38cb2f83ef 100644
--- a/tensorflow/compiler/tests/nullary_ops_test.py
+++ b/tensorflow/compiler/tests/nullary_ops_test.py
@@ -43,18 +43,37 @@ class NullaryOpsTest(xla_test.XLATestCase):
output.run()
def testConstants(self):
- constants = [
- np.float32(42),
- np.array([], dtype=np.float32),
- np.array([1, 2], dtype=np.float32),
- np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32),
- np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
- dtype=np.float32),
- np.array([[[]], [[]]], dtype=np.float32),
- np.array([[[[1]]]], dtype=np.float32),
- ]
- for c in constants:
- self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+ for dtype in self.numeric_types:
+ constants = [
+ dtype(42),
+ np.array([], dtype=dtype),
+ np.array([1, 2], dtype=dtype),
+ np.array([7, 7, 7, 7, 7], dtype=dtype),
+ np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4], [5, 6]], [[10, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
+
+ def testComplexConstants(self):
+ for dtype in self.complex_types:
+ constants = [
+ dtype(42 + 3j),
+ np.array([], dtype=dtype),
+ np.ones([50], dtype=dtype) * (3 + 4j),
+ np.array([1j, 2 + 1j], dtype=dtype),
+ np.array([[1, 2j, 7j], [4, 5, 6]], dtype=dtype),
+ np.array([[[1, 2], [3, 4 + 6j], [5, 6]],
+ [[10 + 7j, 20], [30, 40], [50, 60]]],
+ dtype=dtype),
+ np.array([[[]], [[]]], dtype=dtype),
+ np.array([[[[1 + 3j]]]], dtype=dtype),
+ ]
+ for c in constants:
+ self._testNullary(lambda c=c: constant_op.constant(c), expected=c)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tests/permute_test.py b/tensorflow/compiler/tests/permute_test.py
new file mode 100644
index 0000000000..dbb9274df4
--- /dev/null
+++ b/tensorflow/compiler/tests/permute_test.py
@@ -0,0 +1,80 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the DataFormatVecPermute operator."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_ops
+from tensorflow.python.platform import test
+
+
+class XlaPermuteOpTest(xla_test.XLATestCase):
+
+ def _runPermuteAndCompare(self, x, src_format, dst_format, expected):
+ with self.cached_session() as session:
+ with self.test_scope():
+ placeholder = array_ops.placeholder(dtypes.as_dtype(x.dtype), x.shape)
+ param = {placeholder: x}
+ output = nn_ops.data_format_vec_permute(
+ placeholder, src_format=src_format, dst_format=dst_format)
+ result = session.run(output, param)
+ self.assertAllEqual(result, expected)
+
+ def testNHWCToNCHW(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "NCHW", [7, 3, 4, 9])
+
+ def testNCHWToNHWC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NCHW", "NHWC", [7, 9, 3, 4])
+
+ def testNHWCToHWNC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "HWNC", [4, 9, 7, 3])
+
+ def testHWNCToNHWC(self):
+ x = np.array([7, 4, 9, 3], dtype=np.int32)
+ self._runPermuteAndCompare(x, "HWNC", "NHWC", [9, 7, 4, 3])
+
+ def testNHWCToNCHW2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "NCHW",
+ [[7, 4], [5, 1], [9, 3], [4, 5]])
+
+ def testNHWCToHWNC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NHWC", "HWNC",
+ [[9, 3], [4, 5], [7, 4], [5, 1]])
+
+ def testHWNCToNHWC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "HWNC", "NHWC",
+ [[4, 5], [7, 4], [9, 3], [5, 1]])
+
+ def testNCHWToNHWC2D(self):
+ x = np.array([[7, 4], [9, 3], [4, 5], [5, 1]], dtype=np.int32)
+ self._runPermuteAndCompare(x, "NCHW", "NHWC",
+ [[7, 4], [4, 5], [5, 1], [9, 3]])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index 7a96f4c25c..dc119fb0f8 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -1820,7 +1820,7 @@ TEST_F(OpTest, Diag) {
do {
dims = RandomDims(1);
size = TensorShape(dims).num_elements();
- } while (size * size < tf_xla_max_tensor_size);
+ } while (size * size > tf_xla_max_tensor_size);
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("Diag").RandomInput(type, dims).Attr("T", type));
});
diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py
index dbf4beb693..57f0ab7a9e 100644
--- a/tensorflow/compiler/tests/sort_ops_test.py
+++ b/tensorflow/compiler/tests/sort_ops_test.py
@@ -48,13 +48,29 @@ class XlaSortOpTest(xla_test.XLATestCase):
self.assertAllClose(v, result, rtol=1e-3)
def testSort(self):
- supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
for dtype in supported_types.intersection(self.numeric_types):
x = np.arange(101, dtype=dtype)
np.random.shuffle(x)
self._assertOpOutputMatchesExpected(
xla.sort, [x], expected=[np.arange(101, dtype=dtype)])
+ def testKeyValueSort(self):
+ supported_types = set(
+ [dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
+ for key_type in supported_types.intersection(self.numeric_types):
+ for value_type in supported_types.intersection(self.numeric_types):
+ x = np.arange(101, dtype=key_type)
+ np.random.shuffle(x)
+ y = (-x).astype(value_type)
+ self._assertOpOutputMatchesExpected(
+ xla.key_value_sort, [x, y],
+ expected=[
+ np.arange(101, dtype=key_type),
+ -np.arange(101, dtype=value_type)
+ ])
+
def testTopK(self):
supported_types = set(
[dtypes.bfloat16.as_numpy_dtype, np.float32, np.int32, np.uint32])
diff --git a/tensorflow/compiler/tests/tensor_list_ops_test.py b/tensorflow/compiler/tests/tensor_list_ops_test.py
new file mode 100644
index 0000000000..5c079d595c
--- /dev/null
+++ b/tensorflow/compiler/tests/tensor_list_ops_test.py
@@ -0,0 +1,96 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for ops which manipulate lists of tensors via bridge."""
+
+# pylint: disable=g-bad-name
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+def scalar_shape():
+ return ops.convert_to_tensor([], dtype=dtypes.int32)
+
+
+class ListOpsTest(xla_test.XLATestCase):
+
+ def testElementShape(self):
+ with self.cached_session() as sess, self.test_scope():
+ dim = array_ops.placeholder(dtypes.int32)
+ l = list_ops.tensor_list_reserve(
+ element_shape=(dim, 15), num_elements=20,
+ element_dtype=dtypes.float32)
+ e32 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int32)
+ e64 = list_ops.tensor_list_element_shape(l, shape_type=dtypes.int64)
+ self.assertAllEqual(sess.run(e32, {dim: 10}), (10, 15))
+ self.assertAllEqual(sess.run(e64, {dim: 7}), (7, 15))
+
+ def testPushPop(self):
+ with self.cached_session() as sess, self.test_scope():
+ num = array_ops.placeholder(dtypes.int32)
+ l = list_ops.tensor_list_reserve(
+ element_shape=(7, 15), num_elements=num, element_dtype=dtypes.float32)
+ l = list_ops.tensor_list_push_back(
+ l, constant_op.constant(1.0, shape=(7, 15)))
+ l = list_ops.tensor_list_push_back(
+ l, constant_op.constant(2.0, shape=(7, 15)))
+ l, e2 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ _, e1 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(sess.run(e2, {num: 10}), 2.0 * np.ones((7, 15)))
+ self.assertAllEqual(sess.run(e1, {num: 10}), 1.0 * np.ones((7, 15)))
+
+ def testPushPopSeparateLists(self):
+ with self.cached_session() as sess, self.test_scope():
+ num = array_ops.placeholder(dtypes.int32)
+ l = list_ops.tensor_list_reserve(
+ element_shape=scalar_shape(),
+ num_elements=num,
+ element_dtype=dtypes.float32)
+ l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
+ l2 = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
+ l3 = list_ops.tensor_list_push_back(l, constant_op.constant(3.0))
+ _, e11 = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ l2, e21 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
+ l2, e22 = list_ops.tensor_list_pop_back(l2, element_dtype=dtypes.float32)
+ l3, e31 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
+ l3, e32 = list_ops.tensor_list_pop_back(l3, element_dtype=dtypes.float32)
+ result = sess.run([e11, [e21, e22], [e31, e32]], {num: 20})
+ self.assertEqual(result, [1.0, [2.0, 1.0], [3.0, 1.0]])
+
+ def testEmptyTensorList(self):
+ dim = 7
+ with self.cached_session() as sess, self.test_scope():
+ p = array_ops.placeholder(dtypes.int32)
+ l = list_ops.empty_tensor_list(
+ element_shape=(p, 15), element_dtype=dtypes.float32)
+ l = list_ops.tensor_list_push_back(
+ l, constant_op.constant(1.0, shape=(dim, 15)))
+ _, e = list_ops.tensor_list_pop_back(l, element_dtype=dtypes.float32)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "Use TensorListReserve instead"):
+ self.assertEqual(sess.run(e, {p: dim}), 1.0 * np.ones((dim, 15)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 36c6f5d316..0362682bd6 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -79,7 +79,10 @@ Status FunctionalizeControlFlowForFunction(
const string& func_name, const string& new_func_name,
const protobuf::Map<string, tensorflow::AttrValue>& attrs,
FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr,
- std::map<string, string>* canonicalized_name_to_new_name) {
+ std::map<string, absl::optional<string>>* canonicalized_name_to_new_name,
+ bool* modified) {
+ *modified = false;
+
// Convert the function to Graph.
FunctionLibraryRuntime::Handle handle;
TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle));
@@ -91,44 +94,20 @@ Status FunctionalizeControlFlowForFunction(
}
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ Graph* g = body->graph;
- // Call graph optimizer. The most important optimization we need is constant
- // folding, which will replace ops like Shape/BroadcastGradientArgs with
- // constant shape input. Without this optimization, those ops might become
- // dynamic input for then/else body function and XLA will complain that input
- // is not compile time constant. We enable function inlining as well, because
- // otherwise we won't be able to infer shape for any node depending on
- // function call nodes.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_opt_", func_name),
- *body->graph, fld);
- }
- // Optimizer accepts std::unique_ptr<Graph>* as input and might change
- // underlying pointer, thus we create a new Graph and copy from body->graph.
- std::unique_ptr<Graph> optimized_graph(new Graph(fld));
- CopyGraph(*body->graph, optimized_graph.get());
- OptimizerOptions opts;
- opts.set_opt_level(OptimizerOptions::L0);
- opts.set_do_function_inlining(true);
- opts.set_do_constant_folding(true);
- GraphOptimizer optimizer(opts);
- auto cf_consider_fn = [](const Node* n) {
- // Skip SymbolicGradient op when doing constant folding.
- // Enabling SymbolicGradient op in constant folding requires
- // flr->device() to be non-null, and here we have not constructed
- // proper Device object yet (it will be constructed in XlaCompiler).
- return n->type_string() != FunctionLibraryDefinition::kGradientOp;
- };
- optimizer.Optimize(flr, flr->env(),
- /*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
- cf_consider_fn);
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_opt_", func_name),
- *optimized_graph, fld);
+ // Check if the graph has Switch or Merge node.
+ bool has_switch_or_merge = false;
+ for (Node* n : body->graph->nodes()) {
+ if (n->type_string() == "Switch" || n->type_string() == "Merge") {
+ has_switch_or_merge = true;
+ break;
+ }
}
+ // We cannot return here directly if the graph has no Switch/Merge.
+ // It might contain function call nodes, or If/While nodes with Switch/Merge
+ // in function body. We still need to rewrite those functions and modify
+ // corresponding nodes.
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
@@ -136,7 +115,7 @@ Status FunctionalizeControlFlowForFunction(
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : optimized_graph->nodes()) {
+ for (auto* n : g->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -151,10 +130,15 @@ Status FunctionalizeControlFlowForFunction(
Canonicalize(name, AttrSlice(&associated_function.attrs()));
auto iter = canonicalized_name_to_new_name->find(canonicalized_name);
string new_name;
+ bool function_modified;
if (iter != canonicalized_name_to_new_name->end()) {
- // If we already functionalized this function, skip functionalization
- // but still rewrite the node.
- new_name = iter->second;
+ // If we already processed this function, check if it was rewritten. If
+ // the function was rewritten, the entry will be non-empty. Otherwise
+ // the entry will be empty.
+ function_modified = iter->second.has_value();
+ if (function_modified) {
+ new_name = iter->second.value();
+ }
} else {
if (associated_function.type() ==
AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
@@ -166,42 +150,62 @@ Status FunctionalizeControlFlowForFunction(
}
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
name, new_name, associated_function.attrs(), fld, flr,
- canonicalized_name_to_new_name));
- (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ canonicalized_name_to_new_name, &function_modified));
+ if (function_modified) {
+ // If the function was rewritten, add an non-empty entry. So later we
+ // know we have processed this function, and it was rewritten into
+ // another function.
+ (*canonicalized_name_to_new_name)[canonicalized_name] = new_name;
+ } else {
+ // If the function was not rewritten, add an empty entry. So later
+ // we know we have processed this function, and it does not need to be
+ // rewritten.
+ (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt;
+ }
+ }
+ if (function_modified) {
+ *modified = true;
+
+ // Notice that if "n" is a function call, RewriteAssociatedFunction()
+ // will delete it and create a new node instead, making "n" an invalid
+ // pointer. That's fine because in that case, associated_functions will
+ // only have one member and the loop will only run once.
+ TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
+ g, n, fld, associated_function, new_name));
}
- // Notice that if "n" is a function call, RewriteAssociatedFunction() will
- // delete it and create a new node instead, making "n" an invalid pointer.
- // That's fine because in that case, associated_functions will only have
- // one member and the loop will only run once.
- TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- optimized_graph.get(), n, fld, associated_function, new_name));
}
}
- // Functionalize the function body.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
- *optimized_graph, fld);
- }
- TF_RETURN_IF_ERROR(FunctionalizeControlFlow(optimized_graph.get(), fld));
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_after_fdef_", func_name),
- *optimized_graph, fld);
+ if (has_switch_or_merge) {
+ *modified = true;
+
+ // Functionalize the function body.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_fdef_", func_name),
+ *g, fld);
+ }
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld));
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g,
+ fld);
+ }
}
- FunctionDef functionalized_fdef;
- TF_RETURN_IF_ERROR(GraphToFunctionDef(*optimized_graph, new_func_name,
- &functionalized_fdef));
- // Add rewritten FunctionDef into library.
- if (func_name == new_func_name) {
- VLOG(2) << "Replacing function " << func_name;
+ if (*modified) {
+ // Add rewritten FunctionDef into library.
+ FunctionDef functionalized_fdef;
TF_RETURN_IF_ERROR(
- fld->ReplaceFunction(new_func_name, functionalized_fdef));
- } else {
- VLOG(2) << "Adding function " << new_func_name;
- TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ GraphToFunctionDef(*g, new_func_name, &functionalized_fdef));
+ if (func_name == new_func_name) {
+ VLOG(2) << "Replacing function " << func_name;
+ TF_RETURN_IF_ERROR(
+ fld->ReplaceFunction(new_func_name, functionalized_fdef));
+ } else {
+ VLOG(2) << "Adding function " << new_func_name;
+ TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef));
+ }
}
return ret_status;
@@ -227,7 +231,7 @@ Status FunctionalizeControlFlowPass::Run(
{"TPUCompile", "function"},
{"XlaLaunch", "function"},
};
- std::map<string, string> canonicalized_name_to_new_name;
+ std::map<string, absl::optional<string>> canonicalized_name_to_new_name;
for (Node* n : graph->nodes()) {
auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string());
if (it == kNodeTypeToFunctionAttrMapping->end()) {
@@ -242,12 +246,15 @@ Status FunctionalizeControlFlowPass::Run(
<< ". Corresponding function: " << func.name();
string new_func_name = options.flib_def->UniqueFunctionName(
absl::StrCat(func.name(), "_f15n_"));
+ bool modified;
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
func.name(), new_func_name, func.attr(), options.flib_def, flr,
- &canonicalized_name_to_new_name));
- n->ClearAttr(func_attr);
- func.set_name(new_func_name);
- n->AddAttr(func_attr, func);
+ &canonicalized_name_to_new_name, &modified));
+ if (modified) {
+ n->ClearAttr(func_attr);
+ func.set_name(new_func_name);
+ n->AddAttr(func_attr, func);
+ }
}
}
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 3e823254d3..224e5ea123 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -62,6 +62,7 @@ tf_kernel_library(
"one_hot_op.cc",
"pack_op.cc",
"pad_op.cc",
+ "permute_op.cc",
"pooling_ops.cc",
"qr_op.cc",
"quantize_and_dequantize_op.cc",
@@ -94,6 +95,7 @@ tf_kernel_library(
"stateless_random_ops.cc",
"strided_slice_op.cc",
"tensor_array_ops.cc",
+ "tensor_list_ops.cc",
"tile_ops.cc",
"topk_op.cc",
"training_ops.cc",
@@ -119,6 +121,7 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
+ "//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random",
@@ -157,6 +160,7 @@ tf_kernel_library(
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:list_kernels",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:pooling_ops",
diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
index a988d3c33e..47e517a657 100644
--- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc
@@ -64,7 +64,7 @@ XLA_MAKE_BINARY(Complex, xla::Complex(lhs, rhs, extend_dimensions));
// }
static xla::XlaOp DivNoNanImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto y_equals_0 = xla::Eq(y, zero);
auto zeros = xla::ZerosLike(x);
@@ -84,7 +84,7 @@ XLA_MAKE_BINARY(DivNoNan,
// }
static xla::XlaOp FloorDivImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
if (DataTypeIsUnsigned(dtype)) {
return xla::Div(x, y);
}
@@ -105,7 +105,7 @@ XLA_MAKE_BINARY(FloorDiv,
static xla::XlaOp XlogyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto is_zero = xla::Eq(x, zero);
return xla::Select(is_zero, zero, xla::Mul(x, xla::Log(y)));
@@ -114,7 +114,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
static xla::XlaOp XdivyImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto is_zero = xla::Eq(x, zero);
return xla::Select(is_zero, zero, xla::Div(x, y));
@@ -126,7 +126,7 @@ XLA_MAKE_BINARY(Xdivy, XdivyImpl(b, input_type(0), lhs, rhs, broadcast_helper));
// return (x < T(0)) == (y < T(0)) ? trunc_mod : std::fmod(trunc_mod + y, y);
static xla::XlaOp FloorModImpl(xla::XlaBuilder* b, DataType dtype, xla::XlaOp x,
xla::XlaOp y, const BCast& broadcast_helper) {
- std::tie(x, y) = XlaBinaryOp::Broadcast(b, x, y, broadcast_helper);
+ std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
auto zero = XlaHelpers::Zero(b, dtype);
auto same_sign = xla::Eq(xla::Lt(x, zero), xla::Lt(y, zero));
auto trunc_mod = xla::Rem(x, y);
diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
index 696c1c39be..9bb11fb67e 100644
--- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc
@@ -13,16 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "absl/algorithm/container.h"
-#include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/constants.h"
-#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
namespace {
@@ -37,59 +32,9 @@ class BroadcastToOp : public XlaOpKernel {
TensorShape output_shape;
OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape));
- OP_REQUIRES(context, input_shape.dims() <= output_shape.dims(),
- errors::InvalidArgument(
- "Input rank (", input_shape.dims(),
- ") must be less than or equal to the output rank (",
- output_shape.dims(), ")"));
-
- auto input_dims = input_shape.dim_sizes();
- auto output_dims = output_shape.dim_sizes();
-
- // Broadcasting is done right-to-left on right-aligned dimensions; reverse
- // the two vectors so elements to be broadcast are aligned.
- absl::c_reverse(input_dims);
- absl::c_reverse(output_dims);
-
- std::vector<int64> broadcast_dims;
- std::vector<int64> broadcast_shape;
- for (int i = 0; i < output_shape.dims(); ++i) {
- if (i < input_shape.dims()) {
- OP_REQUIRES(
- context,
- (output_dims[i] == 0 && input_dims[i] == 0) ||
- (input_dims[i] != 0 && output_dims[i] % input_dims[i] == 0),
- errors::InvalidArgument("invalid shape to broadcast from ",
- input_shape.DebugString(), " to ",
- output_shape.DebugString()));
-
- broadcast_dims.push_back(broadcast_shape.size());
- if (output_dims[i] == input_dims[i]) {
- broadcast_shape.push_back(output_dims[i]);
- } else if (output_dims[i] != input_dims[i]) {
- // Add dimensions [I, O/I], which we will later flatten to just
- // [O]. We must do this in two phases since XLA broadcasting does not
- // support tiling.
- broadcast_shape.push_back(input_dims[i]);
- broadcast_shape.push_back(output_dims[i] / input_dims[i]);
- }
- } else {
- broadcast_shape.push_back(output_dims[i]);
- }
- }
- absl::c_reverse(broadcast_dims);
- int broadcast_shape_size = broadcast_shape.size();
- for (int64& broadcast_dim : broadcast_dims) {
- broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
- }
- absl::c_reverse(broadcast_shape);
- xla::XlaOp output = xla::Reshape(
- xla::BroadcastInDim(context->Input(0),
- xla::ShapeUtil::MakeShape(
- context->input_xla_type(0), broadcast_shape),
- broadcast_dims),
- output_shape.dim_sizes());
- context->SetOutput(0, output);
+ auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes());
+ OP_REQUIRES_OK(context, output.status());
+ context->SetOutput(0, output.ValueOrDie());
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index da8cf3fc6f..2628ef8e24 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
namespace tensorflow {
namespace {
@@ -76,6 +77,17 @@ class ConstOp : public XlaOpKernel {
return;
}
break;
+ case DT_COMPLEX64:
+ if (proto_.scomplex_val_size() == 2) {
+ ctx->SetOutput(
+ 0,
+ xla::Broadcast(xla::ConstantR0<xla::complex64>(
+ b, xla::complex64(proto_.scomplex_val(0),
+ proto_.scomplex_val(1))),
+ shape.dim_sizes()));
+ return;
+ }
+ break;
case DT_INT32:
if (proto_.int_val_size() == 1) {
ctx->SetOutput(
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
index ef1015552d..234f7b4a01 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -39,7 +40,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
// compute valid broadcast shapes, but rely below on XLA to
// automatically perform the broadcast assuming its valid shapes are
// a superset of TensorFlow's valid shapes.
- BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape));
+ BCast bcast(BCast::FromShape(lhs_shape), BCast::FromShape(rhs_shape),
+ /*fewer_dims_optimization=*/false);
if (!bcast.IsValid()) {
ctx->SetStatus(errors::InvalidArgument("Incompatible shapes: ",
lhs_shape.DebugString(), " vs. ",
@@ -86,51 +88,18 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) {
}
/* static */ std::pair<xla::XlaOp, xla::XlaOp> XlaBinaryOp::Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper) {
- // Manually construct the broadcasting since MapN does not do
- // automatic broadcasting. The bcast helper ensures that
- // lhs.reshape(bcast.x_reshape()).broadcast(bcast.x_bcast()) and
- // rhs.reshape(bcast.y_reshape()).broadcast(bcast.y_bcast()) have
- // the same shape, so can be operated on by MapN.
-
- // First reshape the inputs, which should be a metadata-only
- // operation since we are flattening the dimensions in order.
- auto lhs_shaped = xla::Reshape(lhs, broadcast_helper.x_reshape());
- auto rhs_shaped = xla::Reshape(rhs, broadcast_helper.y_reshape());
-
- // Next broadcast the necessary input dimensions. We rely on the
- // XLA optimizer to be smart about the fact that we are asking
- // it to broadcast size 1 on some of these dimensions, to avoid
- // adding complexity to this code.
- auto lhs_broadcast = xla::Broadcast(lhs_shaped, broadcast_helper.x_bcast());
- int lhs_size = broadcast_helper.x_bcast().size();
- auto rhs_broadcast = xla::Broadcast(rhs_shaped, broadcast_helper.y_bcast());
- int rhs_size = broadcast_helper.y_bcast().size();
-
- // Now reshape them to the correct output shape. After the
- // broadcast each side is twice as wide as it should be, since the
- // broadcast dimensions were prepended to the shape. Reshape
- // flattening each original dimension with the prepended broadcast
- // dimension. E.g. if we started out with lhs_shaped with shape
- // [5,2,3] and x_bcast was [2,1,7] then lhs_broadcast would have
- // shape [2,1,7,5,2,3] and we want to reshape it to [10,2,21].
- std::vector<int64> lhs_reorder;
- for (int i = 0; i < lhs_size; ++i) {
- lhs_reorder.push_back(i);
- lhs_reorder.push_back(i + lhs_size);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper) {
+ auto lhs_output = BroadcastTo(lhs, broadcast_helper.output_shape());
+ if (!lhs_output.ok()) {
+ xla::XlaOp error = lhs.builder()->ReportError(lhs_output.status());
+ return {error, error};
}
- auto lhs_output =
- xla::Reshape(lhs_broadcast, lhs_reorder, broadcast_helper.output_shape());
- std::vector<int64> rhs_reorder;
- for (int i = 0; i < rhs_size; ++i) {
- rhs_reorder.push_back(i);
- rhs_reorder.push_back(i + rhs_size);
+ auto rhs_output = BroadcastTo(rhs, broadcast_helper.output_shape());
+ if (!rhs_output.ok()) {
+ xla::XlaOp error = rhs.builder()->ReportError(rhs_output.status());
+ return {error, error};
}
- auto rhs_output =
- xla::Reshape(rhs_broadcast, rhs_reorder, broadcast_helper.output_shape());
-
- return {lhs_output, rhs_output};
+ return {lhs_output.ValueOrDie(), rhs_output.ValueOrDie()};
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
index 6653944a91..516ead4bfe 100644
--- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
+++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.h
@@ -67,8 +67,7 @@ class XlaBinaryOp : public XlaOpKernel {
// 'broadcast_helper', yielding arguments 'lhs' and 'rhs' that have the same
// shape.
static std::pair<xla::XlaOp, xla::XlaOp> Broadcast(
- xla::XlaBuilder* builder, const xla::XlaOp& lhs, const xla::XlaOp& rhs,
- const BCast& broadcast_helper);
+ xla::XlaOp lhs, xla::XlaOp rhs, const BCast& broadcast_helper);
};
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/permute_op.cc b/tensorflow/compiler/tf2xla/kernels/permute_op.cc
new file mode 100644
index 0000000000..0764e5503d
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/permute_op.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+namespace {
+
+class DataFormatVecPermuteOp : public XlaOpKernel {
+ public:
+ explicit DataFormatVecPermuteOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("src_format", &src_format_));
+ OP_REQUIRES(
+ ctx, src_format_.size() == 4,
+ errors::InvalidArgument("Data format should have 4 characters"));
+ TensorFormat data_format;
+ OP_REQUIRES(ctx, FormatFromString(src_format_, &data_format),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dst_format", &dst_format_));
+ OP_REQUIRES(
+ ctx, dst_format_.size() == 4,
+ errors::InvalidArgument("Data format should have 4 characters"));
+ OP_REQUIRES(ctx, FormatFromString(dst_format_, &data_format),
+ errors::InvalidArgument("Invalid data format"));
+ }
+ void Compile(XlaOpKernelContext* ctx) override {
+ auto builder = ctx->builder();
+ const TensorShape input_tensor_shape = ctx->InputShape(0);
+ int input_rank = input_tensor_shape.dims();
+ OP_REQUIRES(ctx, input_rank == 1 || input_rank == 2,
+ errors::InvalidArgument(
+ "Input must be a vector or matrix, but got shape ",
+ input_tensor_shape.DebugString()));
+ OP_REQUIRES(
+ ctx, input_tensor_shape.dim_size(0) == 4,
+ errors::InvalidArgument(
+ "First dimension of input must be of size 4, but got shape ",
+ input_tensor_shape.DebugString()));
+ if (input_rank == 2) {
+ OP_REQUIRES(
+ ctx, input_tensor_shape.dim_size(1) == 2,
+ errors::InvalidArgument(
+ "Second dimension of 2D input must be of size 2, but got shape ",
+ input_tensor_shape.DebugString()));
+ }
+ std::vector<int32> dst_indices(4, 0);
+ for (int i = 0; i < 4; ++i) {
+ for (int j = 0; j < 4; ++j) {
+ if (src_format_[i] == dst_format_[j]) {
+ dst_indices[i] = j;
+ break;
+ }
+ }
+ }
+ auto keys = xla::ConstantR1(builder, absl::Span<const int32>(dst_indices));
+ if (input_rank == 2) {
+ keys = xla::BroadcastInDim(
+ keys, xla::ShapeUtil::MakeShape(xla::S32, {4, 2}), {0});
+ }
+ auto sorted = xla::Sort(keys, ctx->Input(0), 0);
+ auto output = xla::GetTupleElement(sorted, 1);
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ string src_format_;
+ string dst_format_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DataFormatVecPermuteOp);
+};
+
+// TODO(b/115384656): Support DT_INT64.
+REGISTER_XLA_OP(Name("DataFormatVecPermute").TypeConstraint("T", DT_INT32),
+ DataFormatVecPermuteOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
index 8102faad28..8eee5b1299 100644
--- a/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/reduce_window_op.cc
@@ -40,10 +40,16 @@ class ReduceWindowOp : public XlaOpKernel {
std::vector<int64> window_dimensions;
std::vector<int64> window_strides;
+ std::vector<int64> base_dilations;
+ std::vector<int64> window_dilations;
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
"window_dimensions", &window_dimensions));
OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("window_strides",
&window_strides));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector("base_dilations",
+ &base_dilations));
+ OP_REQUIRES_OK(context, context->ConstantInputAsIntVector(
+ "window_dilations", &window_dilations));
const int rank = input_shape.dims();
OP_REQUIRES(context, rank == window_dimensions.size(),
@@ -56,6 +62,16 @@ class ReduceWindowOp : public XlaOpKernel {
"The size of window_strides must be equal to the input "
"rank (",
window_strides.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == base_dilations.size(),
+ errors::InvalidArgument(
+ "The size of base_dilations must be equal to the input "
+ "rank (",
+ base_dilations.size(), " vs. ", rank, ")"));
+ OP_REQUIRES(context, rank == window_dilations.size(),
+ errors::InvalidArgument(
+ "The size of window_dilations must be equal to the input "
+ "rank (",
+ window_dilations.size(), " vs. ", rank, ")"));
// Build the reducer function.
XlaCompiler::Argument reducer_arg;
@@ -102,7 +118,8 @@ class ReduceWindowOp : public XlaOpKernel {
xla::XlaOp output = xla::ReduceWindowWithGeneralPadding(
context->Input(0), context->Input(1), *reducer.computation,
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
context->SetOutput(0, output);
}
@@ -115,6 +132,8 @@ class ReduceWindowOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("XlaReduceWindow")
.CompileTimeConstInput("window_dimensions")
.CompileTimeConstInput("window_strides")
+ .CompileTimeConstInput("base_dilations")
+ .CompileTimeConstInput("window_dilations")
.CompileTimeConstInput("padding"),
ReduceWindowOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
index ab094d7dd1..57afd608de 100644
--- a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -104,7 +104,8 @@ class ScanOp : public XlaOpKernel {
}
auto output = xla::ReduceWindowWithGeneralPadding(
XlaHelpers::ConvertElementType(builder, ctx->Input(0), dtype), init,
- *reducer, window_dims, window_strides, padding);
+ *reducer, window_dims, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding);
output =
XlaHelpers::ConvertElementType(builder, output, ctx->input_type(0));
diff --git a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
index aaeeae01cc..45f03d8c21 100644
--- a/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/sort_ops.cc
@@ -25,11 +25,26 @@ class XlaSortOp : public XlaOpKernel {
explicit XlaSortOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
- context->SetOutput(0, xla::Sort(context->Input(0)));
+ context->SetOutput(0, xla::Sort(context->Input("input")));
}
};
REGISTER_XLA_OP(Name("XlaSort"), XlaSortOp);
+class XlaKeyValueSortOp : public XlaOpKernel {
+ public:
+ explicit XlaKeyValueSortOp(OpKernelConstruction* context)
+ : XlaOpKernel(context) {}
+
+ void Compile(XlaOpKernelContext* context) override {
+ xla::XlaOp result =
+ xla::Sort(context->Input("keys"), context->Input("values"));
+ context->SetOutput(0, xla::GetTupleElement(result, 0));
+ context->SetOutput(1, xla::GetTupleElement(result, 1));
+ }
+};
+
+REGISTER_XLA_OP(Name("XlaKeyValueSort"), XlaKeyValueSortOp);
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
new file mode 100644
index 0000000000..74d4fcc425
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc
@@ -0,0 +1,226 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// XLA TensorList operators.
+
+#include <limits>
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+Status GetTensorListShape(xla::XlaBuilder* builder, xla::XlaOp op,
+ TensorShape* tensor_list_shape) {
+ auto shape_or_status = builder->GetShape(op);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ xla::Shape shape = shape_or_status.ValueOrDie();
+ TF_RET_CHECK(xla::ShapeUtil::IsTuple(shape));
+ return XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(shape, 0),
+ tensor_list_shape);
+}
+
+class TensorListReserveOp : public XlaOpKernel {
+ public:
+ explicit TensorListReserveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ TensorShape element_shape;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &element_shape));
+ int64 num_elements;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_elements));
+
+ TensorShape tensor_shape;
+ tensor_shape.AddDim(num_elements);
+ tensor_shape.AppendShape(element_shape);
+
+ xla::XlaBuilder* b = ctx->builder();
+ ctx->SetOutput(0, xla::Tuple(b, {xla::Broadcast(XlaHelpers::Zero(b, dtype_),
+ tensor_shape.dim_sizes()),
+ xla::ConstantR0<int32>(b, 0)}));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListReserveOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListReserve")
+ .CompileTimeConstInput("element_shape")
+ .CompileTimeConstInput("num_elements"),
+ TensorListReserveOp);
+
+class EmptyTensorListOp : public XlaOpKernel {
+ public:
+ explicit EmptyTensorListOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ ctx->CtxFailure(
+ errors::InvalidArgument("XLA compilation requires a fixed tensor list "
+ "size. Use TensorListReserve instead."));
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(EmptyTensorListOp);
+};
+
+REGISTER_XLA_OP(Name("EmptyTensorList"), EmptyTensorListOp);
+
+class TensorListElementShapeOp : public XlaOpKernel {
+ public:
+ explicit TensorListElementShapeOp(OpKernelConstruction* ctx)
+ : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape_type", &shape_type_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, GetTensorListShape(b, ctx->Input(0), &shape));
+ shape.RemoveDim(0);
+
+ switch (shape_type_) {
+ case DT_INT64:
+ ctx->SetOutput(0, xla::ConstantR1<int64>(b, shape.dim_sizes()));
+ break;
+ case DT_INT32: {
+ std::vector<int32> size;
+ for (int64 s : shape.dim_sizes()) {
+ size.push_back(s);
+ }
+ ctx->SetOutput(0, xla::ConstantR1<int32>(b, size));
+ break;
+ }
+ default:
+ ctx->CtxFailure(
+ errors::InvalidArgument("Unsupported shape type requested"));
+ return;
+ }
+ }
+
+ private:
+ DataType shape_type_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListElementShapeOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListElementShape"), TensorListElementShapeOp);
+
+class TensorListPushBackOp : public XlaOpKernel {
+ public:
+ explicit TensorListPushBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp list = ctx->Input(0);
+ TensorShape elem_shape = ctx->InputShape(1);
+
+ xla::XlaOp ta = xla::GetTupleElement(list, 0);
+ xla::XlaOp index = xla::GetTupleElement(list, 1);
+ xla::XlaOp value = ctx->Input(1);
+
+ // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
+ auto start_indices =
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
+
+ TensorShape slice_shape = elem_shape;
+ slice_shape.InsertDim(0, 1LL);
+ auto update = xla::Reshape(value, slice_shape.dim_sizes());
+
+ // TODO(phawkins): We don't check the index is in bounds --- there is no
+ // error mechanism in XLA.
+ ctx->SetOutput(
+ 0, xla::Tuple(b, {xla::DynamicUpdateSlice(ta, update, start_indices),
+ index + xla::ConstantR0<int32>(b, 1)}));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListPushBackOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListPushBack"), TensorListPushBackOp);
+
+class TensorListPopBackOp : public XlaOpKernel {
+ public:
+ explicit TensorListPopBackOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("element_dtype", &dtype_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ xla::XlaBuilder* b = ctx->builder();
+ xla::XlaOp state = ctx->Input(0);
+
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, GetTensorListShape(b, state, &shape));
+
+ xla::XlaOp ta = xla::GetTupleElement(state, 0);
+ xla::XlaOp index = xla::GetTupleElement(state, 1);
+
+ index = index - xla::ConstantR0<int32>(b, 1);
+
+ // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
+ auto start_indices =
+ xla::Pad(xla::Reshape(index, {1}), xla::ConstantR0<int32>(b, 0),
+ xla::MakeEdgePaddingConfig({{0, shape.dims() - 1}}));
+
+ auto slice_shape = shape.dim_sizes();
+ slice_shape[0] = 1LL;
+
+ // TODO(phawkins): We don't check the index is in bounds --- there is no
+ // error mechanism in XLA.
+ xla::XlaOp read = xla::DynamicSlice(ta, start_indices, slice_shape);
+ // Remove the leading '1' dimension.
+ std::vector<int64> value_shape(slice_shape.begin() + 1, slice_shape.end());
+
+ ctx->SetOutput(0, xla::Tuple(b, {ta, index}));
+ ctx->SetOutput(1, xla::Reshape(read, value_shape));
+ }
+
+ private:
+ DataType dtype_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TensorListPopBackOp);
+};
+
+REGISTER_XLA_OP(Name("TensorListPopBack"), TensorListPopBackOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD
index 8597e7f139..1ce3930fd1 100644
--- a/tensorflow/compiler/tf2xla/lib/BUILD
+++ b/tensorflow/compiler/tf2xla/lib/BUILD
@@ -32,6 +32,22 @@ cc_library(
)
cc_library(
+ name = "broadcast",
+ srcs = ["broadcast.cc"],
+ hdrs = ["broadcast.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.cc b/tensorflow/compiler/tf2xla/lib/broadcast.cc
new file mode 100644
index 0000000000..3e402ef855
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.cc
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/lib/broadcast.h"
+
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
+
+namespace tensorflow {
+
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims) {
+ xla::XlaBuilder* builder = input.builder();
+ TF_ASSIGN_OR_RETURN(xla::Shape input_shape, builder->GetShape(input));
+ absl::Span<int64 const> input_dims =
+ xla::AsInt64Slice(input_shape.dimensions());
+
+ if (input_dims == output_dims) {
+ return input;
+ }
+
+ if (input_dims.size() > output_dims.size()) {
+ return errors::InvalidArgument(
+ "Input shape (", xla::ShapeUtil::HumanString(input_shape),
+ ") must have rank less than or equal to the output shape [",
+ absl::StrJoin(output_dims, ","), "]");
+ }
+
+ std::vector<int64> broadcast_dims;
+ std::vector<int64> broadcast_shape;
+ auto input_it = input_dims.rbegin();
+ for (auto output_it = output_dims.rbegin(); output_it != output_dims.rend();
+ ++output_it) {
+ if (input_it != input_dims.rend()) {
+ if (!(*output_it == 0 && *input_it == 0) &&
+ !(*input_it != 0 && *output_it % *input_it == 0)) {
+ return errors::InvalidArgument("Invalid shape broadcast from ",
+ xla::ShapeUtil::HumanString(input_shape),
+ " to [", absl::StrJoin(output_dims, ","),
+ "]");
+ }
+
+ broadcast_dims.push_back(broadcast_shape.size());
+ if (*output_it == *input_it) {
+ broadcast_shape.push_back(*output_it);
+ } else if (*output_it != *input_it) {
+ // Add dimensions [I, O/I], which we will later flatten to just
+ // [O]. We must do this in two phases since XLA broadcasting does not
+ // support tiling.
+ broadcast_shape.push_back(*input_it);
+ broadcast_shape.push_back(*output_it / *input_it);
+ }
+ ++input_it;
+ } else {
+ broadcast_shape.push_back(*output_it);
+ }
+ }
+ TF_RET_CHECK(input_it == input_dims.rend());
+
+ absl::c_reverse(broadcast_dims);
+ int broadcast_shape_size = broadcast_shape.size();
+ for (int64& broadcast_dim : broadcast_dims) {
+ broadcast_dim = broadcast_shape_size - broadcast_dim - 1;
+ }
+ absl::c_reverse(broadcast_shape);
+ xla::XlaOp output = xla::BroadcastInDim(
+ input,
+ xla::ShapeUtil::MakeShape(input_shape.element_type(), broadcast_shape),
+ broadcast_dims);
+ if (broadcast_shape != output_dims) {
+ output = xla::Reshape(output, output_dims);
+ }
+ return output;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/broadcast.h b/tensorflow/compiler/tf2xla/lib/broadcast.h
new file mode 100644
index 0000000000..591e696f06
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/lib/broadcast.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+#define TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace tensorflow {
+
+// Broadcasts 'input' up to shape 'output_dims', using TensorFlow broadcasting
+// rules. Supports broadcasting a dimension of size x to size x*y, i.e., tiling.
+xla::StatusOr<xla::XlaOp> BroadcastTo(xla::XlaOp input,
+ absl::Span<int64 const> output_dims);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_BROADCAST_H_
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc
index 38dfde165d..2b1c2ced92 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.cc
+++ b/tensorflow/compiler/tf2xla/lib/scatter.cc
@@ -38,12 +38,10 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
combiner,
xla::XlaBuilder* builder) {
TF_ASSIGN_OR_RETURN(xla::Shape buffer_shape, builder->GetShape(buffer));
- TF_RETURN_IF_ERROR(builder->GetShape(updates).status());
+ TF_ASSIGN_OR_RETURN(xla::Shape updates_shape, builder->GetShape(updates));
TF_ASSIGN_OR_RETURN(xla::Shape indices_shape, builder->GetShape(indices));
absl::Span<const int64> indices_dims =
xla::AsInt64Slice(indices_shape.dimensions());
- absl::Span<const int64> buffer_dims =
- xla::AsInt64Slice(buffer_shape.dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
@@ -81,104 +79,129 @@ xla::StatusOr<xla::XlaOp> XlaScatter(
}
}
- // Shape of the non-indexed dimensions of the buffer.
- std::vector<int64> buffer_shape_post_axes(
- buffer_dims.begin() + num_index_dims, buffer_dims.end());
-
- // Flatten the major dimensions of indices and updates into a single dimension
- // for ease of iteration.
- std::vector<int64> flat_indices_shape({num_indices});
- if (indices_are_vectors) {
- flat_indices_shape.push_back(num_index_dims);
+ // Example of a 1-D scatter that updates two [3,1] tensors in a tensor of
+ // shape [3,3]:
+ // NOTE: ***This case will not be generated by any of the tf.scatter ops.***
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // updates = s32[3,2] parameter(2)
+ // scatter = s32[3,3] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={0},
+ // inserted_window_dims={1},
+ // scatter_dims_to_operand_dims={1},
+ // index_vector_dim=1
+ //
+ //
+ // Example of a 1-D scatter that updates two [1,3] tensors in a tensor of
+ // shape [3,3]:
+ //
+ // operand = s32[3,3] parameter(0)
+ // indices = s32[2] parameter(1)
+ // updates = s32[2,3] parameter(2)
+ // scatter = s32[3,3] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={1},
+ // inserted_window_dims={0},
+ // scatter_dims_to_operand_dims={0},
+ // index_vector_dim=1
+ //
+ //
+ // Example of an N-D scatter updating slices of shape [1,1,2] in a tensor of
+ // shape [3,3,2]
+ //
+ // operand = s32[3,3,2] parameter(0)
+ // indices = s32[2,2] parameter(1)
+ // updates = s32[2,2] parameter(2)
+ // scatter = s32[3,3,2] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={1},
+ // inserted_window_dims={0,1},
+ // scatter_dims_to_operand_dims={0,1},
+ // index_vector_dim=1
+ //
+ //
+ // Example of a scatter updating slices of shape [] in a tensor of shape [1,1]
+ //
+ // operand = s32[1,1] parameter(0)
+ // indices = s32[1] parameter(1)
+ // updates = s32[1] parameter(2)
+ // scatter = s32[1,1] scatter(operand, indices, updates),
+ // to_apply=update_computation,
+ // update_window_dims={},
+ // inserted_window_dims={0,1},
+ // scatter_dims_to_operand_dims={0},
+ // index_vector_dim=1
+ // Note that updates operand would be broadcasted into [1] in this case.
+ //
+
+ xla::ScatterDimensionNumbers dim_numbers;
+ dim_numbers.set_index_vector_dim(indices_are_vectors
+ ? indices_shape.dimensions_size() - 1
+ : indices_shape.dimensions_size());
+
+ int64 updates_rank = xla::ShapeUtil::Rank(updates_shape);
+ int64 buffer_rank = xla::ShapeUtil::Rank(buffer_shape);
+ int64 num_window_dims_in_updates = buffer_rank - num_index_dims;
+
+ // If the rank of `updates` is 0 and does not match the expected rank of
+ // updates, broadcast `updates` to the expected shape of updates.
+ auto new_updates = updates;
+ std::vector<int64> expected_updates_dims(indices_dims.begin(),
+ indices_dims.end());
+ for (int64 dim = num_index_dims; dim < buffer_rank; ++dim) {
+ expected_updates_dims.push_back(buffer_shape.dimensions(dim));
+ }
+ int64 expected_updates_rank = expected_updates_dims.size();
+ if (updates_rank == 0 && expected_updates_rank != 0) {
+ new_updates = xla::Broadcast(updates, expected_updates_dims);
+ TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates));
+ updates_rank = xla::ShapeUtil::Rank(updates_shape);
}
- std::vector<int64> flat_updates_shape({num_indices});
- flat_updates_shape.insert(flat_updates_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
-
- // Construct the initial values of the loop-carried Tensors.
- auto flat_indices = xla::Reshape(indices, flat_indices_shape);
- auto flat_updates = xla::Reshape(updates, flat_updates_shape);
- auto init = {flat_indices, flat_updates, buffer};
-
- // Constructs the loop body. The implementation of scatter is essentially:
- // for i in range(num_indices):
- // index = dynamic-slice(indices, i)
- // update = dynamic-slice(updates, i)
- // buffer = dynamic-update-slice(buffer, update, index)
- auto body_fn = [&](xla::XlaOp i, absl::Span<const xla::XlaOp> loop_vars,
- xla::XlaBuilder* body_builder) {
- auto indices = loop_vars[0];
- auto updates = loop_vars[1];
- auto buffer = loop_vars[2];
-
- auto zero_index = xla::ConstantLiteral(
- body_builder, xla::LiteralUtil::Zero(indices_shape.element_type()));
-
- // Slice the i-th index from the indices array.
- xla::XlaOp index;
- auto indices_offset = xla::Reshape(i, {1});
- if (indices_are_vectors) {
- indices_offset = xla::Pad(indices_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, 1}}));
-
- index = xla::DynamicSlice(indices, indices_offset, {1, num_index_dims});
- index = xla::Collapse(index, {0, 1});
- } else {
- index = xla::DynamicSlice(indices, indices_offset, {1});
+ if (updates_rank > 0) {
+ for (int64 i = (updates_rank - num_window_dims_in_updates);
+ i < updates_rank; ++i) {
+ dim_numbers.add_update_window_dims(i);
}
+ }
- // Discard updates with negative indices, since some users expect this.
- auto index_in_range = xla::ReduceAll(
- xla::Le(zero_index, index), xla::ConstantR0<bool>(body_builder, true),
- xla::CreateScalarAndComputation(xla::PRED, body_builder));
-
- // Make the index in bounds to prevent implementation defined behavior.
- index = xla::Max(index, zero_index);
- index = xla::Pad(
- index, zero_index,
- xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
-
- // Slice the i-th index from the updates array.
- auto updates_offset = xla::Reshape(i, {1});
- updates_offset = xla::Pad(
- updates_offset, zero_index,
- xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
- std::vector<int64> flat_updates_slice_shape({1});
- flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
- auto update =
- xla::DynamicSlice(updates, updates_offset, flat_updates_slice_shape);
-
- // Unflatten the major (iteration) dimensions of the slice to their
- // original shape.
- std::vector<int64> updates_slice_shape(num_index_dims, 1);
- updates_slice_shape.insert(updates_slice_shape.end(),
- buffer_shape_post_axes.begin(),
- buffer_shape_post_axes.end());
- update = xla::Reshape(update, updates_slice_shape);
-
- // Apply the update to the buffer. If there is a combiner, use it to merge
- // the current values with the update.
- auto current_value = xla::DynamicSlice(buffer, index, updates_slice_shape);
+ for (int64 i = 0; i < num_index_dims; ++i) {
+ dim_numbers.add_inserted_window_dims(i);
+ dim_numbers.add_scatter_dims_to_operand_dims(i);
+ }
+
+ // Build the combiner computation.
+ xla::XlaComputation combiner_computation;
+ {
+ xla::XlaBuilder cb("scatter-combiner");
+ auto xla_scalar_shape =
+ xla::ShapeUtil::MakeShape(buffer_shape.element_type(), {});
+ auto p0 = xla::Parameter(&cb, 0, xla_scalar_shape, "p0");
+ auto p1 = xla::Parameter(&cb, 1, xla_scalar_shape, "p1");
if (combiner) {
- update = combiner(current_value, update, body_builder);
+ combiner(p0, p1, &cb);
}
- // Use the current value instead of the update if the index is out of
- // bounds.
- update = xla::Select(index_in_range, update, current_value);
- // Apply the update.
- buffer = xla::DynamicUpdateSlice(buffer, update, index);
-
- return std::vector<xla::XlaOp>{indices, updates, buffer};
- };
-
- TF_ASSIGN_OR_RETURN(auto outputs,
- XlaForEachIndex(num_indices, indices_shape.element_type(),
- body_fn, init, "scatter", builder));
- return outputs[2];
+ combiner_computation = cb.Build().ConsumeValueOrDie();
+ }
+
+ VLOG(3) << "Scatter op:";
+ VLOG(3) << " Input: " << xla::ShapeUtil::HumanString(buffer_shape);
+ VLOG(3) << " Indices: " << xla::ShapeUtil::HumanString(indices_shape);
+ VLOG(3) << " Updates: " << xla::ShapeUtil::HumanString(updates_shape);
+ VLOG(3) << " Scatter Dimension Numbers: ";
+ VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim();
+ VLOG(3) << " update_window_dims: ["
+ << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
+ VLOG(3) << " inserted_window_dims: ["
+ << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
+ VLOG(3) << " scatter_dims_to_operand_dims: ["
+ << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
+ << "]";
+
+ return xla::Scatter(buffer, indices, new_updates, combiner_computation,
+ dim_numbers);
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h
index 13a5f1b850..4cf478c4b9 100644
--- a/tensorflow/compiler/tf2xla/lib/scatter.h
+++ b/tensorflow/compiler/tf2xla/lib/scatter.h
@@ -34,7 +34,11 @@ namespace tensorflow {
// Otherwise, `indices_are_vectors`, then indices are multidimensional and the
// minor dimension of `indices` represents a vector of indices.
//
-// If any indices are negative, the corresponding update is discarded.
+// If `updates` is a scalar, then it will be broadcasted into the expected shape
+// of updates.
+//
+// If any part of the update region is out-of-bounds, the corresponding update
+// is discarded.
//
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
index 733eeed3c6..bd2c0a5ee8 100644
--- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc
+++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc
@@ -283,6 +283,8 @@ REGISTER_OP("XlaReduceWindow")
.Input("init_value: T")
.Input("window_dimensions: Tindices")
.Input("window_strides: Tindices")
+ .Input("base_dilations: Tindices")
+ .Input("window_dilations: Tindices")
.Input("padding: Tindices")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
@@ -354,12 +356,33 @@ Wraps the XLA Sort operator, documented at
https://www.tensorflow.org/performance/xla/operation_semantics#sort
.
-Sorts a tensor. Currently only rank 1 sorts in ascending order are supported.
+Sorts a tensor. Currently only sorts in ascending order are supported.
input: A `Tensor` of type T.
output: A `Tensor` of type T.
)doc");
+REGISTER_OP("XlaKeyValueSort")
+ .Input("keys: K")
+ .Input("values: V")
+ .Output("sorted_keys: K")
+ .Output("sorted_values: V")
+ .Attr("K: realnumbertype")
+ .Attr("V: type")
+ .SetShapeFn(shape_inference::UnchangedShape)
+ .Doc(R"doc(
+Wraps the XLA Sort operator, documented at
+ https://www.tensorflow.org/performance/xla/operation_semantics#sort
+.
+
+Sorts a tensor. Currently only sorts in ascending order are supported.
+
+keys: A `Tensor` of type K.
+values: A `Tensor` of type V.
+sorted_keys: A `Tensor` of type K.
+sorted_values: A `Tensor` of type V.
+)doc");
+
// TODO(b/37549631) setting the While Op to always be stateful is too
// conservative.
REGISTER_OP("XlaWhile")
diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py
index 27dd18a9bb..5e86b5d8ec 100644
--- a/tensorflow/compiler/tf2xla/python/xla.py
+++ b/tensorflow/compiler/tf2xla/python/xla.py
@@ -212,9 +212,9 @@ bitcast_convert_type = array_ops.bitcast
def broadcast(x, dims, name=None):
x = ops.convert_to_tensor(x)
- shape = array_ops.concat(
- [constant_op.constant(dims),
- array_ops.shape(x)], axis=0)
+ shape = array_ops.concat([constant_op.constant(dims),
+ array_ops.shape(x)],
+ axis=0)
return array_ops.broadcast_to(x, shape, name=name)
@@ -320,6 +320,8 @@ def reduce_window(operand,
reducer,
window_dimensions,
window_strides=None,
+ base_dilations=None,
+ window_dilations=None,
padding=None,
name=None):
"""Wraps the XLA ReduceWindow operator.
@@ -332,22 +334,27 @@ def reduce_window(operand,
init: a scalar tensor representing the initial value for the reduction
reducer: a reduction function that combines a pair of scalars.
window_dimensions: shape of the window, as a list of integers
- window_strides: inter-window strides, as a list of integers. Optional;
- if omitted, defaults to strides of 1.
+ window_strides: inter-window strides, as a list of integers. Optional; if
+ omitted, defaults to strides of 1.
padding: padding to apply to 'operand'. List of (low, high) pairs of
integers that specify the padding to apply before and after each
dimension. Optional; if omitted, defaults to no padding.
name: the operator name, or None.
+
Returns:
A tensor that represents the output of the reduce_window operator.
"""
window_strides = window_strides or [1] * len(window_dimensions)
+ base_dilations = base_dilations or [1] * len(window_dimensions)
+ window_dilations = window_dilations or [1] * len(window_dimensions)
padding = padding or [(0, 0)] * len(window_dimensions)
return gen_xla_ops.xla_reduce_window(
input=operand,
init_value=init,
window_dimensions=window_dimensions,
window_strides=window_strides,
+ base_dilations=base_dilations,
+ window_dilations=window_dilations,
padding=padding,
computation=reducer,
name=name)
@@ -377,4 +384,5 @@ def slice(x, start_dims, limit_dims, strides):
sort = gen_xla_ops.xla_sort
+key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d5094e8ec5..b2c57e8880 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -194,6 +194,17 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
std::unique_ptr<Graph> graph = GetGraph(fbody);
+ // Clear the "_kernel" attribute if it is set to "host". This is used to
+ // indicate that a computation should happen on the host instead of the
+ // accelerator, but doesn't make sense in XLA.
+ const char* const kKernelAttr = "_kernel";
+ for (Node* n : graph->nodes()) {
+ string value;
+ if (GetNodeAttrSimple(n->attrs(), kKernelAttr, &value) && value == "host") {
+ n->ClearAttr(kKernelAttr);
+ }
+ }
+
// _Arg and _Retval nodes don't exist in the stored subgraph for the function;
// they are added by the function body looked up. Therefore, they don't have
// core assignments here.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 2a9eaeee14..dd3498ef7a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
+Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
+ Tensor** output) {
+ // The step's default allocator is the dummy XlaCompilationAllocator which
+ // simply allocates a metadata buffer to hold the expression to which it
+ // corresponds.
+ if (expected_output_dtype(index) == DT_VARIANT) {
+ // tensor_data() is not supported for variant Tensor (i.e.,
+ // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
+ // XlaExpression inside the Tensor's tensor_data() does not work for
+ // variant. Instead construct a uint8 tensor and store the expression in its
+ // value.
+ // TODO(jpienaar): This should be refactored to stop masquerading
+ // XlaExpressions as Tensors.
+ *output = new Tensor();
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(
+ context_->allocate_temp(DT_UINT8, tensor_shape, *output));
+ context_->set_output(index, **output);
+ } else {
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
+ TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
+ }
+ return Status::OK();
+}
+
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
- auto shape = builder()->GetShape(handle);
- if (!shape.ok()) {
- SetStatus(shape.status());
+ auto shape_or = builder()->GetShape(handle);
+ if (!shape_or.ok()) {
+ SetStatus(shape_or.status());
return;
}
- // The step's default allocator is the dummy XlaCompilationAllocator which
- // simply allocates a metadata buffer to hold the expression to which it
- // corresponds.
- TensorShape tensor_shape;
- OP_REQUIRES_OK(context_,
- XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
OP_REQUIRES_OK(context_,
- context_->allocate_output(index, tensor_shape, &output));
+ allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index a3a0d10cc0..aa00a45496 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -255,6 +255,11 @@ class XlaOpKernelContext {
// Returns the tensor of input `name`.
const Tensor& GetInputTensorByName(absl::string_view name);
+ // Wraps OpKernelContext's allocate_output method while providing special
+ // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the
+ // type to allow mapping for variant to more generic types.
+ Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
+
OpKernelContext* const context_;
};
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index e0ec91dba1..6b31831010 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -208,6 +208,9 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
case HloOpcode::kWhile:
// TODO(b/32495713): We aren't checking the condition and body
// computations themselves.
+ case HloOpcode::kScatter:
+ // TODO(b/32495713): We aren't checking the embedded computation in
+ // Scatter.
case HloOpcode::kSend:
case HloOpcode::kRecv:
case HloOpcode::kParameter:
@@ -1786,9 +1789,9 @@ XlaOp XlaBuilder::ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(operand_shape.dimensions()), window_dimensions,
window_strides, padding);
- return ReduceWindowWithGeneralPadding(operand, init_value, computation,
- window_dimensions, window_strides,
- padding_values);
+ return ReduceWindowWithGeneralPadding(
+ operand, init_value, computation, window_dimensions, window_strides,
+ /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
});
}
@@ -1797,6 +1800,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1807,7 +1812,8 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
MakeWindow(window_dimensions, window_strides, padding,
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
+ /*lhs_dilation=*/base_dilations,
+ /*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
@@ -2797,10 +2803,12 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
- padding);
+ base_dilations, window_dilations, padding);
}
XlaOp CrossReplicaSum(const XlaOp& operand,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index cd0d5ca5d3..2e14e47a35 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -671,6 +671,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
@@ -1245,6 +1247,8 @@ class XlaBuilder {
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
friend XlaOp CrossReplicaSum(const XlaOp& operand,
absl::Span<const ReplicaGroup> replica_groups);
@@ -1818,6 +1822,8 @@ XlaOp ReduceWindowWithGeneralPadding(
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
// Returns the sum of the operand value within each subgroup of replicas. All
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index deeb140b8f..656ce720a1 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -727,16 +727,34 @@ Literal LiteralBase::Slice(absl::Span<const int64> start_indices,
ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions,
LayoutUtil::MinorToMajor(shape()));
switch (result_shape.element_type()) {
- case F32:
- return SliceInternal<float>(result_shape, start_indices);
+ case PRED:
+ return SliceInternal<bool>(result_shape, start_indices);
+ case U8:
+ return SliceInternal<uint8>(result_shape, start_indices);
+ case U16:
+ return SliceInternal<uint16>(result_shape, start_indices);
+ case U32:
+ return SliceInternal<uint32>(result_shape, start_indices);
+ case U64:
+ return SliceInternal<uint64>(result_shape, start_indices);
+ case S8:
+ return SliceInternal<int8>(result_shape, start_indices);
+ case S16:
+ return SliceInternal<int16>(result_shape, start_indices);
+ case S32:
+ return SliceInternal<int32>(result_shape, start_indices);
+ case S64:
+ return SliceInternal<int64>(result_shape, start_indices);
+ case F16:
+ return SliceInternal<half>(result_shape, start_indices);
case BF16:
return SliceInternal<bfloat16>(result_shape, start_indices);
+ case F32:
+ return SliceInternal<float>(result_shape, start_indices);
+ case F64:
+ return SliceInternal<double>(result_shape, start_indices);
case C64:
return SliceInternal<complex64>(result_shape, start_indices);
- case S32:
- return SliceInternal<int32>(result_shape, start_indices);
- case U32:
- return SliceInternal<uint32>(result_shape, start_indices);
default:
LOG(FATAL) << "not yet implemented: "
<< PrimitiveType_Name(result_shape.element_type());
@@ -1927,11 +1945,11 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
}
} break;
case TUPLE:
- LOG(FATAL) << "Should not be called on tuple shapes: "
- << ShapeUtil::HumanString(subshape());
- break;
+ return InvalidArgument("Should not be called on tuple shapes: %s",
+ ShapeUtil::HumanString(subshape()));
default:
- LOG(FATAL) << "Unhandled primitive type " << subshape().element_type();
+ return InvalidArgument("Is called on unsupported shape: %s",
+ ShapeUtil::HumanString(subshape()));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index cd5fd33029..ffa336f304 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -532,10 +532,13 @@ LocalOp LocalComputationBuilder::ReduceWindowWithGeneralPadding(
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return xla::ReduceWindowWithGeneralPadding(
operand.op(), init_value.op(), local_computation.computation(),
- window_dimensions, window_strides, padding);
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding);
}
LocalOp LocalComputationBuilder::RngNormal(const LocalOp& mu,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index 2166bb6721..43332e0abd 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -278,6 +278,8 @@ class LocalComputationBuilder {
const LocalComputation& local_computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
+ absl::Span<const int64> base_dilations,
+ absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64> > padding);
LocalOp RngNormal(const LocalOp& mu, const LocalOp& sigma,
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index bb303c5678..f8197488fb 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -995,7 +995,30 @@ class ComputationBuilder(object):
window_strides)
return self._client.ReduceWindowWithGeneralPadding(
operand, init_value, computation_to_apply.c_local_computation,
- window_dimensions, window_strides, pads)
+ window_dimensions, window_strides, (), (), pads)
+
+ def ReduceWindowWithGeneralPadding(
+ self, operand, init_value, computation_to_apply, window_dimensions,
+ window_strides, base_dilations, window_dilations, padding):
+ """Enqueues a windowed reduction operation onto the computation.
+
+ Args:
+ operand: reduction operand (LocalOp).
+ init_value: reduction initial value (LocalOp).
+ computation_to_apply: a binary reduction function (Computation).
+ window_dimensions: dimensions of window (sequence of integers).
+ window_strides: strides for window (sequence of integers).
+ base_dilations: dilations for the base (sequence of integers).
+ window_dilations: dilations for window (sequence of integers).
+ padding: length-N array-like of pairs of integers of (low, high) padding.
+
+ Returns:
+ A LocalOp representing the added ReduceWindow op.
+ """
+ return self._client.ReduceWindowWithGeneralPadding(
+ operand, init_value, computation_to_apply.c_local_computation,
+ window_dimensions, window_strides, base_dilations, window_dilations,
+ padding)
def RngNormal(self, mu, sigma, dims):
"""Enqueues an RngNormal operation onto the computation.
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index f329a27e14..2b292ed053 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1324,10 +1324,19 @@ cc_library(
)
cc_library(
+ name = "fusion_queue",
+ hdrs = ["fusion_queue.h"],
+ deps = [
+ ":hlo",
+ ],
+)
+
+cc_library(
name = "instruction_fusion",
srcs = ["instruction_fusion.cc"],
hdrs = ["instruction_fusion.h"],
deps = [
+ ":fusion_queue",
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:util",
@@ -1833,42 +1842,6 @@ tf_cc_test(
)
cc_library(
- name = "inliner",
- srcs = ["inliner.cc"],
- hdrs = ["inliner.h"],
- deps = [
- ":hlo",
- ":hlo_pass",
- ":hlo_query",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/core:lib",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-tf_cc_test(
- name = "inliner_test",
- srcs = ["inliner_test.cc"],
- deps = [
- ":cpu_plugin",
- ":hlo",
- ":hlo_matchers",
- ":inliner",
- "//tensorflow/compiler/xla:literal",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:test",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/tests:hlo_test_base",
- "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
- "//tensorflow/compiler/xla/tests:literal_test_util",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
- "@com_google_absl//absl/memory",
- ],
-)
-
-cc_library(
name = "computation_placer",
srcs = ["computation_placer.cc"],
hdrs = ["computation_placer.h"],
@@ -2477,6 +2450,7 @@ tf_cc_test(
":hlo",
":hlo_parser",
":hlo_verifier",
+ ":layout_assignment",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
@@ -3483,6 +3457,39 @@ cc_library(
deps = ["//tensorflow/core:lib"],
)
+cc_library(
+ name = "map_inliner",
+ srcs = ["map_inliner.cc"],
+ hdrs = ["map_inliner.h"],
+ deps = [
+ ":hlo",
+ ":hlo_pass",
+ ":hlo_query",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/types:span",
+ ],
+)
+
+tf_cc_test(
+ name = "map_inliner_test",
+ srcs = ["map_inliner_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_matchers",
+ ":map_inliner",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_verified_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "@com_google_absl//absl/memory",
+ ],
+)
+
tf_cc_test(
name = "hlo_casting_utils_test",
srcs = ["hlo_casting_utils_test.cc"],
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 75dae7a714..86d9dbea90 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2057,6 +2057,12 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow(
return Status::OK();
}
+ // Bail on dilation.
+ if (window_util::HasDilation(window)) {
+ VLOG(10) << "Not folding pad into reduce-window as there is dilation.";
+ return Status::OK();
+ }
+
VLOG(10) << "Considering folding Pad: " << pad->ToString()
<< "\ninto reduce-window: " << reduce_window->ToString()
<< (convert != nullptr
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index ae4c6e962d..58abb330a6 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -94,6 +94,7 @@ cc_library(
":target_machine_features",
"@com_google_absl//absl/types:span",
"//tensorflow/compiler/tf2xla:cpu_function_runtime",
+ "//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:protobuf_util",
@@ -127,7 +128,6 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:indexed_array_analysis",
- "//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index afc94f2185..68c715a086 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -86,8 +86,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
-#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/scatter_expander.h"
@@ -249,7 +249,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
&pipeline, module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
- pipeline.AddPass<Inliner>();
+ pipeline.AddPass<MapInliner>();
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
// pass.
@@ -327,8 +327,13 @@ Status CpuCompiler::RunHloPassesAfterLayoutAssn(
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
"simplification after layout assignement");
- pass.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ // TODO(b/117156505): When the bug is fixed, the CPU backend should not
+ // produce layout changing elementwise operations. We will then pass
+ // LayoutAssignment::InstructionCanChangeLayout to the HLO verifier to
+ // enable stricter verification.
+ pass.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false);
pass.AddPass<HloPassFix<AlgebraicSimplifier>>(
/*is_layout_sensitive=*/true,
[](const Shape&, const Shape&) { return true; },
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index a70abb117a..b2abdb39a5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -688,8 +688,25 @@ StatusOr<llvm::Value*> IrEmitter::EmitTargetElementLoopBodyForReduceWindow(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* strided_index =
NSWMul(index[i], b_.getInt64(window.dimensions(i).stride()));
- input_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
- b_.getInt64(window.dimensions(i).padding_low()));
+ input_index[i] = NSWSub(
+ NSWAdd(strided_index,
+ NSWMul(window_index[i],
+ b_.getInt64(window.dimensions(i).window_dilation()))),
+ b_.getInt64(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i], b_.getInt64(window.dimensions(i).base_dilation())),
+ b_.getInt64(0));
+ if (in_bounds_condition == nullptr) {
+ in_bounds_condition = dilation_condition;
+ } else {
+ in_bounds_condition = And(in_bounds_condition, dilation_condition);
+ }
+
+ // Apply base dilation to the index.
+ input_index[i] =
+ SDiv(input_index[i], b_.getInt64(window.dimensions(i).base_dilation()));
// We need to check if 0 <= input_index[i] < bound, as otherwise we are in
// the padding so that we can skip the computation. That is equivalent to
@@ -728,12 +745,6 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) {
/*operands=*/{reduce_window->operand(0)},
/*supported_types=*/{F32, BF16, S32, F16}));
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(reduce_window->window())) {
- return Unimplemented(
- "Dilation for ReduceWindow is not implemented on CPU.");
- }
-
// Pseudo code for reduce window:
//
// for (coordinates O in the output)
diff --git a/tensorflow/compiler/xla/service/fusion_queue.h b/tensorflow/compiler/xla/service/fusion_queue.h
new file mode 100644
index 0000000000..1208a7dda8
--- /dev/null
+++ b/tensorflow/compiler/xla/service/fusion_queue.h
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
+
+#include <utility>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+
+namespace xla {
+
+// A queue interface that allows implementations to choose fusion candidates in
+// custom order.
+class FusionQueue {
+ public:
+ FusionQueue() = default;
+ virtual ~FusionQueue() = default;
+
+ // Dequeues the next fusion candidates: a consumer and the list of producers
+ // as operand indices.
+ virtual std::pair<HloInstruction*, std::vector<int64>>
+ DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
+
+ // A callback passed to the queue implementation right before the producer is
+ // fused into the consumer.
+ virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
+
+ // A callback passed to the queue implementation right after the fusion is
+ // created. Note that original_producer could have been destroyed.
+ virtual void OnFusingInstruction(HloInstruction* fusion,
+ HloInstruction* original_producer,
+ HloInstruction* original_consumer) {}
+
+ // A callback passed to the queue implementation to notify the removal of an
+ // instruction.
+ virtual void RemoveInstruction(HloInstruction* instruction) = 0;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_QUEUE_H_
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index f92fde7f46..bec02e14f9 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -98,7 +98,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal(
Status GenericTransferManager::TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
- const ShapedBuffer& device_buffer, TransferToDeviceHint /*hint*/) {
+ const ShapedBuffer& device_buffer) {
const Shape& shape = literal.shape();
VLOG(2) << "transferring literal shape to device: "
<< ShapeUtil::HumanString(shape)
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index b1cba82b9f..86c8b1c145 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -45,10 +45,9 @@ class GenericTransferManager : public TransferManager {
MutableBorrowingLiteral literal,
std::function<void(Status)> done) override;
- Status TransferLiteralToDeviceAsync(se::Stream* stream,
- const LiteralSlice& literal,
- const ShapedBuffer& device_buffer,
- TransferToDeviceHint hint) override;
+ Status TransferLiteralToDeviceAsync(
+ se::Stream* stream, const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 522e9f5948..350fd32537 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -404,6 +404,7 @@ cc_library(
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
+ "@com_google_absl//absl/time",
"@com_google_absl//absl/types:optional",
],
)
@@ -780,7 +781,6 @@ cc_library(
srcs = ["gpu_layout_assignment.cc"],
hdrs = ["gpu_layout_assignment.h"],
deps = [
- ":gpu_options",
":ir_emission_utils",
":stream_executor_util",
"//tensorflow/compiler/xla:shape_util",
@@ -882,16 +882,6 @@ cc_library(
)
cc_library(
- name = "gpu_options",
- srcs = ["gpu_options.cc"],
- hdrs = ["gpu_options.h"],
- deps = [
- "//tensorflow/compiler/xla/service:hlo_module_config",
- "//tensorflow/core:lib_internal",
- ],
-)
-
-cc_library(
name = "stream_executor_util",
srcs = ["stream_executor_util.cc"],
hdrs = ["stream_executor_util.h"],
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
index 7125673887..590c0a7d54 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc
@@ -145,7 +145,7 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
-StatusOr<std::tuple<int64, bool, int64>>
+StatusOr<CudnnConvolutionAlgorithmPicker::AutotuneResult>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
HloCustomCallInstruction* instr) {
// TODO(timshen): for now only check fp16. It can be expanded to other types,
@@ -316,9 +316,10 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< AlgorithmToString(best_result.algorithm()) << ", takes "
<< best_result.elapsed_time_in_ms() << "ms, and uses "
<< best_result_bytes_used << "B of scratch memory.";
- return std::make_tuple(best_result.algorithm().algo_id(),
- best_result.algorithm().tensor_ops_enabled(),
- best_result_bytes_used);
+ return AutotuneResult{best_result.algorithm().algo_id(),
+ best_result.algorithm().tensor_ops_enabled(),
+ best_result_bytes_used,
+ absl::Milliseconds(best_result.elapsed_time_in_ms())};
}
return InternalError(
@@ -331,37 +332,30 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
HloInstruction* instr) {
CHECK(IsCustomCallToDnnConvolution(*instr));
- StatusOr<std::tuple<int64, bool, int64>> alg_scratch_and_tc =
+ StatusOr<AutotuneResult> best_algo_or =
PickBestAlgorithm(Cast<HloCustomCallInstruction>(instr));
-
- if (!alg_scratch_and_tc.ok()) {
- LOG(ERROR) << alg_scratch_and_tc.status();
+ if (!best_algo_or.ok()) {
+ LOG(ERROR) << best_algo_or.status();
return false;
}
- int64 algorithm;
- bool tensor_ops_enabled;
- int64 scratch_bytes;
-
- std::tie(algorithm, tensor_ops_enabled, scratch_bytes) =
- alg_scratch_and_tc.ConsumeValueOrDie();
-
- VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
- << NumBytesToString(scratch_bytes)
+ auto best_algo = std::move(best_algo_or).ValueOrDie();
+ VLOG(1) << "Setting cudnn conv to use algorithm " << best_algo.algorithm
+ << " and " << NumBytesToString(best_algo.scratch_bytes)
<< " of scratch memory: " << instr->ToString()
- << " tensor_ops_enabled: " << tensor_ops_enabled;
+ << " tensor_ops_enabled: " << best_algo.tensor_ops_enabled;
// Replace instr with a new CustomCall which has the correct algorithm, and
// whose output shape has the appropriate amount of scratch memory.
HloComputation* computation = instr->parent();
- Shape new_call_shape =
- ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0),
- ShapeUtil::MakeShape(U8, {scratch_bytes})});
+ Shape new_call_shape = ShapeUtil::MakeTupleShape(
+ {instr->shape().tuple_shapes(0),
+ ShapeUtil::MakeShape(U8, {best_algo.scratch_bytes})});
TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
instr->backend_config<CudnnConvBackendConfig>());
- backend_config.set_algorithm(algorithm);
- backend_config.set_tensor_ops_enabled(tensor_ops_enabled);
+ backend_config.set_algorithm(best_algo.algorithm);
+ backend_config.set_tensor_ops_enabled(best_algo.tensor_ops_enabled);
HloInstruction* new_call = computation->AddInstruction(
instr->CloneWithNewOperands(new_call_shape, instr->operands()));
diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
index aeda2fc7f8..136c32210a 100644
--- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_
+#include "absl/time/time.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -47,10 +48,16 @@ class CudnnConvolutionAlgorithmPicker : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;
private:
+ struct AutotuneResult {
+ int64 algorithm;
+ bool tensor_ops_enabled;
+ int64 scratch_bytes;
+ absl::Duration runtime;
+ };
+
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- StatusOr<std::tuple<int64, bool, int64>> PickBestAlgorithm(
- HloCustomCallInstruction* instr);
+ StatusOr<AutotuneResult> PickBestAlgorithm(HloCustomCallInstruction* instr);
se::StreamExecutor* stream_exec_; // never null
DeviceMemoryAllocator* allocator_; // may be null
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index c1aaa4bf04..6dcdaf1cfe 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -358,13 +358,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
const Window& window = hlo->window();
- // TODO(b/31410564): Implement dilation for reduce-window.
- if (window_util::HasDilation(window)) {
- return Unimplemented(
- "Dilation for reduce-window not implemented on GPU. "
- "See b/31410564.");
- }
-
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
@@ -397,9 +390,24 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
for (size_t i = 0; i < index.size(); ++i) {
llvm::Value* stridden_index = NSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
+ input_index[i] = NSWSub(
+ NSWAdd(stridden_index,
+ NSWMul(window_index[i],
+ index_typed_const(
+ window.dimensions(i).window_dilation()))),
+ index_typed_const(window.dimensions(i).padding_low()));
+
+ // We need to verify that we are not in the dilated base area.
+ llvm::Value* dilation_condition = ICmpEQ(
+ SRem(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation())),
+ index_typed_const(0));
+ in_bounds = And(in_bounds, dilation_condition);
+
+ // Apply base dilation to the index.
input_index[i] =
- NSWSub(NSWAdd(stridden_index, window_index[i]),
- index_typed_const(window.dimensions(i).padding_low()));
+ SDiv(input_index[i],
+ index_typed_const(window.dimensions(i).base_dilation()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
// we are in the pad and so can skip the computation. This
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 74352f26aa..1ffe855750 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -125,14 +124,8 @@ Status GpuLayoutAssignment::AddBackendConstraintsToDnnConvCustomCall(
DataLayout input;
FilterLayout filter;
DataLayout output;
- if (ConvUseLayoutHeuristic(instr->GetModule()->config())) {
- std::tie(input, filter, output) =
- HeuristicLayoutAssignment(instr, stream_executor_);
- } else {
- input = DataLayout::kBatchDepthYX;
- filter = FilterLayout::kOutputInputYX;
- output = DataLayout::kBatchDepthYX;
- }
+ std::tie(input, filter, output) =
+ HeuristicLayoutAssignment(instr, stream_executor_);
TF_ASSIGN_OR_RETURN(
std::tie(*input_shape->mutable_layout(),
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index b4ae2e42c7..ac6c2c5565 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -239,8 +239,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassPipeline pipeline("post-layout_assignment");
- pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
// The LayoutAssignment pass may leave behind kCopy instructions which are
// duplicate or NOPs, so remove them with algebraic simplification and CSE.
@@ -286,8 +288,10 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
{
HloPassFix<HloPassPipeline> fusion("fusion");
- fusion.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ fusion.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/false);
fusion.AddPass<GpuInstructionFusion>(/*may_duplicate=*/true);
fusion.AddPass<FusionMerger>();
@@ -299,7 +303,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
HloPassPipeline reduce_pipeline("reduce-precision");
reduce_pipeline.AddInvariantChecker<HloVerifier>(
- /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false);
+ /*is_layout_sensitive=*/true, /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
ReducePrecisionInsertion::AddPasses(
&reduce_pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::AFTER_FUSION);
@@ -325,8 +330,10 @@ Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) {
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare");
- pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/true,
- /*allow_mixed_precision=*/false);
+ pipeline.AddInvariantChecker<HloVerifier>(
+ /*layout_sensitive=*/true,
+ /*allow_mixed_precision=*/false,
+ LayoutAssignment::InstructionCanChangeLayout);
// Copy insertion should be performed immediately before IR emission to avoid
// inserting unnecessary copies (later pass adds an instruction which
@@ -401,7 +408,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
"prefers >= 9.2.88). Compilation of XLA kernels below will likely "
"fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas "
"binary is sufficient.";
- } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) {
+ } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) {
LOG(WARNING)
<< "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "."
<< vdot
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 6ca1255ede..c6d02f9f67 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -42,18 +42,19 @@ namespace xla {
return std::move(domain_map);
}
-bool HloDomainMap::InSameDomain(HloInstruction* instruction1,
- HloInstruction* instruction2) const {
+bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
+ const HloInstruction* instruction2) const {
int64 domain_id1 = GetDomainId(instruction1);
int64 domain_id2 = GetDomainId(instruction2);
return domain_id1 >= 0 && domain_id1 == domain_id2;
}
-int64 HloDomainMap::GetDomainId(HloInstruction* instruction) const {
+int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
return FindOrDefault(instruction_to_domain_, instruction, -1);
}
-int64 HloDomainMap::GetDomainMetadataId(HloInstruction* instruction) const {
+int64 HloDomainMap::GetDomainMetadataId(
+ const HloInstruction* instruction) const {
return FindOrDie(domain_metadata_id_, instruction);
}
@@ -200,7 +201,8 @@ StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
return std::move(domain);
}
-bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
+bool HloDomainMap::IsDomainInstruction(
+ const HloInstruction* instruction) const {
if (instruction->opcode() != HloOpcode::kDomain) {
return false;
}
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index c8d581b746..bce7d1aa7c 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -58,21 +58,21 @@ class HloDomainMap {
}
// Checks whether two instructions are within the same domain.
- bool InSameDomain(HloInstruction* instruction1,
- HloInstruction* instruction2) const;
+ bool InSameDomain(const HloInstruction* instruction1,
+ const HloInstruction* instruction2) const;
// Checks whether instruction is a kDomain instruction of the kind we are
// currently processing.
- bool IsDomainInstruction(HloInstruction* instruction) const;
+ bool IsDomainInstruction(const HloInstruction* instruction) const;
// Retrieves the domain identifier of the instruction, or -1 in case
// instruction is not found within any domain.
- int64 GetDomainId(HloInstruction* instruction) const;
+ int64 GetDomainId(const HloInstruction* instruction) const;
// Returns the unique id of the domain metadata for the domain the given
// instruction belongs to. The given instruction must not be a kDomain
// instruction since each domain instruction is associated with 2 domains.
- int64 GetDomainMetadataId(HloInstruction* instruction) const;
+ int64 GetDomainMetadataId(const HloInstruction* instruction) const;
private:
// Map used for representing instruction ordering, i.e.
@@ -119,8 +119,8 @@ class HloDomainMap {
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
- absl::flat_hash_map<HloInstruction*, int64> instruction_to_domain_;
- absl::flat_hash_map<HloInstruction*, int64> domain_metadata_id_;
+ absl::flat_hash_map<const HloInstruction*, int64> instruction_to_domain_;
+ absl::flat_hash_map<const HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index cee11a8a21..608a42bb60 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -1463,6 +1463,58 @@ TEST_P(HloEvaluatorTest, ReduceWindowMax) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
}
+TEST_P(HloEvaluatorTest, ReduceWindowMaxWindowDilation) {
+ HloComputation::Builder b(TestName());
+
+ // arg:
+ // f32[3,3] {
+ // { 1, 2, 3 },
+ // { 5, 6, 7 },
+ // { 9, 10, 11 },
+ // }
+ auto arg_array = absl::make_unique<Array2D<float>>(3, 3);
+ arg_array->FillUnique(1.0f);
+ auto arg_literal = LiteralUtil::CreateR2FromArray2D<float>(*arg_array);
+
+ HloInstruction* arg_instruction =
+ b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal)));
+
+ auto init_value = b.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.f)));
+
+ HloComputation::Builder max_computation("max");
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto param_lhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "lhs"));
+ auto param_rhs = max_computation.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "rhs"));
+ max_computation.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs));
+ auto max_func = module().AddEmbeddedComputation(max_computation.Build());
+
+ Window window;
+ WindowDimension dim;
+ dim.set_size(2);
+ dim.set_stride(1);
+ dim.set_padding_low(0);
+ dim.set_padding_high(0);
+ dim.set_window_dilation(2);
+ dim.set_base_dilation(1);
+ *window.add_dimensions() = dim;
+ *window.add_dimensions() = dim;
+
+ Shape shape = ShapeUtil::MakeShape(F32, {1, 1});
+ b.AddInstruction(HloInstruction::CreateReduceWindow(
+ shape, arg_instruction, init_value, window, max_func));
+
+ module().AddEntryComputation(b.Build());
+
+ Literal result = Evaluate();
+
+ auto expected = LiteralUtil::CreateR2<float>({{11}});
+ EXPECT_TRUE(LiteralTestUtil::Equal(expected, result));
+}
+
TEST_P(HloEvaluatorTest, ReduceWindowAdd) {
HloComputation::Builder b(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index b2d12c94b8..a450dc6ff5 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -2613,8 +2613,17 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
std::vector<int64> base_index(rank);
bool out_of_bound = false;
for (int64 i = 0; i < rank; ++i) {
- base_index[i] = window_count_index[i] * window.dimensions(i).stride() +
- window_index[i] - window.dimensions(i).padding_low();
+ base_index[i] =
+ window_count_index[i] * window.dimensions(i).stride() +
+ window_index[i] * window.dimensions(i).window_dilation() -
+ window.dimensions(i).padding_low();
+ // We are not in the base area if the dilation placed us out of bounds.
+ if (base_index[i] % window.dimensions(i).base_dilation() != 0) {
+ out_of_bound = true;
+ break;
+ }
+ // Apply the dilation to the base area.
+ base_index[i] /= window.dimensions(i).base_dilation();
if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) {
out_of_bound = true;
break;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index fb91adc302..2f6db7cd7c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -465,8 +465,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kIota:
- TF_RET_CHECK(proto.dimensions_size() <= 1)
- << "Iota instruction should have at most 1 dimension but sees "
+ TF_RET_CHECK(proto.dimensions_size() == 1)
+ << "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
break;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 68d0979f5c..152d8eacdb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -643,14 +643,6 @@ HloTransposeInstruction::HloTransposeInstruction(
absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kTranspose, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
- CHECK_EQ(shape.dimensions().size(), dimensions.size());
- CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
- CHECK(std::equal(operand->shape().dimensions().begin(),
- operand->shape().dimensions().end(),
- Permute(dimensions, shape.dimensions()).begin()))
- << "shape: " << ShapeUtil::HumanString(shape)
- << ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
AppendOperand(operand);
}
@@ -1491,7 +1483,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
HloGetTupleElementInstruction::HloGetTupleElementInstruction(
const Shape& shape, HloInstruction* operand, int64 index)
: HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
- CHECK(ShapeUtil::IsTuple(operand->shape()));
AppendOperand(operand);
}
@@ -1613,9 +1604,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
: HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
outfeed_shape_(outfeed_shape),
outfeed_config_(outfeed_config) {
- CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
- << "Outfeed shape " << outfeed_shape
- << " must be compatible with operand shape " << operand->shape();
AppendOperand(operand);
AppendOperand(token_operand);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ab168800f6..e169604072 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -896,7 +896,6 @@ class HloOutfeedInstruction : public HloInstruction {
absl::string_view outfeed_config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
return outfeed_shape_;
}
// Returns the config for the Outfeed instruction.
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 55314d0ae9..5cee865b7a 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -195,13 +195,15 @@ class ListScheduler {
return entry;
}
- // Returns the number of bytes freed if the HLO instruction is scheduled.
- // If the instruction calls subcomputations, we count the memory used by the
- // subcomputations as memory "defined" by the instruction. This is not
- // entirely accurate, because subcomputation memory will be freed after the
- // instruction finishes. But it is more accurate than not taking
- // subcomputations into account at all. In the future, we may improve
- // accounting for subcomputation memory (b/65409243).
+ // Returns the number of bytes freed *after* the HLO instruction finishes.
+ // The current List algorithm only considers two states for an instruction:
+ // right before it runs, and after it finishes. We don't represent memory
+ // usage during the execution of an instruction. But if the instruction calls
+ // subcomputations, they are only live during the instruction's execution.
+ // We end up counting the memory used by subcomputations as memory "defined"
+ // by the instruction. This is not entirely accurate, but it is more accurate
+ // than not taking subcomputations into account at all. In the future, we may
+ // improve accounting for subcomputation memory (b/65409243).
int64 BytesFreedIfScheduled(const ReadyListEntry& entry) {
int64 freed_bytes = 0;
for (const auto& kv : entry.used_buffer_unscheduled_use_counts) {
@@ -223,7 +225,18 @@ class ListScheduler {
}
}
}
- return freed_bytes - entry.bytes_defined - max_subcomputation_bytes;
+ int64 bytes_defined;
+ if (max_subcomputation_bytes > 0 &&
+ (entry.instruction->opcode() == HloOpcode::kWhile ||
+ entry.instruction->opcode() == HloOpcode::kCall ||
+ entry.instruction->opcode() == HloOpcode::kConditional)) {
+ // The output buffer of while/call/conditional is always aliased with the
+ // output buffer of the root instruction in the body. Don't double count.
+ bytes_defined = max_subcomputation_bytes;
+ } else {
+ bytes_defined = entry.bytes_defined + max_subcomputation_bytes;
+ }
+ return freed_bytes - bytes_defined;
}
// Constructs the scheduling priority of the given instruction.
@@ -263,9 +276,8 @@ class ListScheduler {
};
for (auto* instruction : computation_.instructions()) {
- // Instruction with no operands or control predecessors will
- // not be in the map.
- if (unscheduled_pred_count.count(instruction) == 0) {
+ if (instruction->operands().empty() &&
+ instruction->control_predecessors().empty()) {
add_to_ready_queue(instruction);
}
}
@@ -356,9 +368,8 @@ class ListScheduler {
buffer_uses_;
// A map containing the count of unscheduled HLOs which using a particular
- // LogicalBuffer. We rely on iterator stability in this map, and that the map
- // entries are std::pair's.
- std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
+ // LogicalBuffer.
+ absl::flat_hash_map<const LogicalBuffer*, int64> unscheduled_use_count_;
// Set of instructions which have been scheduled.
absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index 7527e35c95..93e04eb3db 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -146,7 +146,8 @@ void HloModule::ReplaceComputations(
case HloOpcode::kCall:
case HloOpcode::kMap:
case HloOpcode::kReduce:
- case HloOpcode::kReduceWindow: {
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kScatter: {
HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
replacements, instruction->to_apply(), nullptr);
if (new_arg != nullptr) {
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 83352ef91b..b4aac4c807 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -59,7 +59,7 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
}
/* static */ StatusOr<std::unique_ptr<HloModuleGroupMetadata>>
-HloModuleGroupMetadata::Build(const std::vector<HloModule*>& modules) {
+HloModuleGroupMetadata::Build(absl::Span<HloModule* const> modules) {
auto metadata = absl::make_unique<HloModuleGroupMetadata>(modules);
TF_RETURN_IF_ERROR(metadata->Build());
return std::move(metadata);
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 0311b73207..928df0f5a7 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -102,14 +102,14 @@ class HloModuleGroupMetadata {
HloInstruction* recv_done = nullptr;
};
- explicit HloModuleGroupMetadata(const std::vector<HloModule*>& modules)
- : modules_(modules) {}
+ explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules)
+ : modules_(modules.begin(), modules.end()) {}
~HloModuleGroupMetadata() = default;
// Build and return the metadata for the given modules.
static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build(
- const std::vector<HloModule*>& modules);
+ absl::Span<HloModule* const> modules);
// Returns true if the instruction is one of the 4 channel instructions (Send,
// Recv, SendDone, RecvDone).
@@ -274,7 +274,7 @@ class HloModuleGroupMetadata {
int64 max_channel_id_ = -1;
// The modules that this metadata was built from.
- const std::vector<HloModule*>& modules_;
+ const std::vector<HloModule*> modules_;
absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
points_to_analyses_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index b618510640..255123d331 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1304,7 +1304,7 @@ TEST_F(HloParserTest, MoreConstants) {
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
- %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4}
+ %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,2]1,2,3,4}
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 94c7bafd3b..188f4acc79 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/compiler/xla/overflow_util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -377,6 +378,20 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
<< "Maximal sharding is expected to have single device assignment, but "
<< proto.tile_assignment_devices().size() << " has provided.";
+ TF_RET_CHECK(proto.tile_assignment_devices().size() > 1);
+ TF_RET_CHECK(!proto.tile_assignment_dimensions().empty());
+
+ // RE: the product of tile assignment tensor dimensions must be
+ // equal to tile_assignment_devices.size().
+ int64 product_of_dimensions = 1;
+ for (auto dimension : proto.tile_assignment_dimensions()) {
+ TF_RET_CHECK(dimension > 0);
+ product_of_dimensions =
+ MultiplyWithoutOverflow(product_of_dimensions, dimension);
+ TF_RET_CHECK(product_of_dimensions > 0);
+ }
+ TF_RET_CHECK(product_of_dimensions == proto.tile_assignment_devices().size());
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index a7727824fe..496fe1795d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -313,8 +313,9 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
operand_dimension < ShapeUtil::Rank(operand_shape);
++operand_dimension) {
int64 output_dimension = broadcast->dimensions()[operand_dimension];
- TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) ==
- operand_shape.dimensions(operand_dimension))
+ TF_RET_CHECK((output_dimension < ShapeUtil::Rank(broadcast->shape())) &&
+ (broadcast->shape().dimensions(output_dimension) ==
+ operand_shape.dimensions(operand_dimension)))
<< broadcast->ToString() << " operand shape " << operand_shape;
}
return Status::OK();
@@ -548,6 +549,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
case HloOpcode::kTupleSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
+ case HloOpcode::kSort:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
break;
@@ -763,7 +765,136 @@ Status VerifyHloStructure(HloModule* module) {
return Status::OK();
}
-Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+ bool contains_token = false;
+ ShapeUtil::ForEachSubshape(
+ shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsToken(subshape)) {
+ contains_token = true;
+ }
+ });
+ return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal.
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+ // Tokens cannot be passed as entry parameters.
+ // TODO(b/80000000): Remove this constraint.
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+ HloInstruction* param =
+ module.entry_computation()->parameter_instruction(i);
+ if (ShapeContainsToken(param->shape())) {
+ return InternalError(
+ "Entry parameter %d is or contains a token shape: %s", i,
+ ShapeUtil::HumanString(param->shape()));
+ }
+ }
+ return Status::OK();
+}
+
+// Checks if the given two instructions share the same channel id.
+Status CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ if (instr1->channel_id() != instr2->channel_id()) {
+ return InternalError(
+ "Expected to have the same channel id, actual channel ids are: %s "
+ "(%d), %s (%d)",
+ instr1->ToString(), instr1->channel_id(), instr2->ToString(),
+ instr2->channel_id());
+ }
+ return Status::OK();
+}
+
+// Checks if the given two instructions have the same is_host_transfer
+// attribute value. Intsructions must be send/recv instructions or their
+// 'done' variant.
+Status CheckSameIsHostTransfer(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ const HloSendRecvInstruction* send_recv1 =
+ DynCast<const HloSendRecvInstruction>(instr1);
+ const HloSendRecvInstruction* send_recv2 =
+ DynCast<const HloSendRecvInstruction>(instr2);
+ TF_RET_CHECK(send_recv1 != nullptr);
+ TF_RET_CHECK(send_recv2 != nullptr);
+ if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
+ return InternalError(
+ "Expected instructions to have the same is-host-transfer property: "
+ "%s, "
+ "%s ",
+ instr1->ToString(), instr2->ToString());
+ }
+ return Status::OK();
+}
+
+// Checks various invariants of send and recv instructions.
+Status VerifySendsAndRecvs(const HloModule& module) {
+ absl::flat_hash_map<int64, const HloInstruction*> host_channels;
+ // Host send/recv instructions must have their own unique channel.
+ auto check_unique_host_channel = [&](const HloInstruction* instruction) {
+ const HloSendRecvInstruction* sendrecv =
+ DynCast<const HloSendRecvInstruction>(instruction);
+ if (sendrecv->is_host_transfer()) {
+ auto it_inserted =
+ host_channels.insert({sendrecv->channel_id(), sendrecv});
+ if (!it_inserted.second) {
+ return FailedPrecondition(
+ "Channel %d is used for multiple host send/recv instructions: "
+ "%s "
+ "and "
+ "%s",
+ sendrecv->channel_id(), sendrecv->ToString(),
+ it_inserted.first->second->ToString());
+ }
+ }
+
+ return Status::OK();
+ };
+
+ // Send/Recv instruction must have a single user: the corresponding
+ // SendDone/RecvDone. with matching channel.
+ for (const HloComputation* computation : module.computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+ TF_RET_CHECK(instruction->users().size() == 1);
+ const HloInstruction* send_done = instruction->users().front();
+ TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
+ TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
+ break;
+ }
+ case HloOpcode::kRecv: {
+ TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+ TF_RET_CHECK(instruction->users().size() == 1);
+ const HloInstruction* recv_done = instruction->users().front();
+ TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
+ TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
+ break;
+ }
+ case HloOpcode::kSendDone:
+ TF_RET_CHECK(instruction->operands().size() == 1);
+ TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
+ break;
+ case HloOpcode::kRecvDone:
+ TF_RET_CHECK(instruction->operands().size() == 1);
+ TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// CHECKs various invariants of a fusion instruction.
+Status CheckFusionInstruction(HloInstruction* fusion) {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
@@ -866,50 +997,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
}
+ TF_RET_CHECK(fusion->called_computations() ==
+ absl::Span<HloComputation* const>(
+ {fusion->fused_instructions_computation()}))
+ << "Fusion HLO calls computations other than the "
+ "fused_instructions_computation: "
+ << fusion->ToString() << " fusion->fused_instructions_computation(): "
+ << fusion->fused_instructions_computation()->ToString()
+ << " fusion->called_computations(): "
+ << ComputationsToString(fusion->called_computations());
+
+ for (const auto& fused : fusion->fused_instructions()) {
+ TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
+ << "Fused HLO was missing a parent: " << fused->ToString()
+ << " parent: " << fused->parent()
+ << " computation: " << fusion->parent();
+ }
+
// TODO(b/65423525): We'd like to check that all operands are distinct.
// This is currently disabled due to the invariant being violated by
// multi-output fusion.
return Status::OK();
}
-Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
- auto* while_cond = instruction->while_condition();
- auto* while_body = instruction->while_body();
- if (while_cond->num_parameters() != 1) {
- return FailedPrecondition(
- "While condition must have exactly 1 parameter; had %d : %s",
- while_cond->num_parameters(), while_cond->ToString());
- }
- if (while_body->num_parameters() != 1) {
- return FailedPrecondition(
- "While body must have exactly 1 parameter; had %d : %s",
- while_body->num_parameters(), while_body->ToString());
- }
- if (instruction->operand_count() != 1) {
- return FailedPrecondition(
- "While loop must have exactly one operand; had %d : %s",
- instruction->operand_count(), instruction->ToString());
- }
- return Status::OK();
-}
-
-Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
- if (instruction->true_computation()->num_parameters() != 1) {
- return FailedPrecondition(
- "True computation %s of %s must have 1 parameter insted of %d",
- instruction->true_computation()->name(), instruction->ToString(),
- instruction->true_computation()->num_parameters());
- }
- if (instruction->false_computation()->num_parameters() != 1) {
- return FailedPrecondition(
- "False computation %s of %s must have 1 parameter insted of %d",
- instruction->false_computation()->name(), instruction->ToString(),
- instruction->false_computation()->num_parameters());
- }
- return Status::OK();
-}
-
-Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
+// Checks that the non-scalar operand shapes are compatible to the output
+// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
+Status CheckElementwiseInstruction(HloInstruction* instruction) {
const Shape& out_shape = instruction->shape();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
@@ -926,133 +1039,143 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
return Status::OK();
}
-namespace {
+// Visitor which verifies various fields on the HLO instruction. This class does
+// not check result shape as that is checked in the ShapeVerifier.
+class InstructionVerifier : public DfsHloVisitorWithDefault {
+ public:
+ explicit InstructionVerifier(std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
+ : instruction_can_change_layout_func_(
+ instruction_can_change_layout_func) {}
-// Returns true if the given Shape has a TOKEN shape as any subshape.
-bool ShapeContainsToken(const Shape& shape) {
- bool contains_token = false;
- ShapeUtil::ForEachSubshape(
- shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsToken(subshape)) {
- contains_token = true;
- }
- });
- return contains_token;
-}
+ Status DefaultAction(HloInstruction*) override { return Status::OK(); }
-// Verifies that all types entering and exiting the entry computation are
-// legal.
-Status VerifyEntryAndExitShapes(const HloModule& module) {
- // Tokens cannot be passed as entry parameters.
- // TODO(b/80000000): Remove this constraint.
- for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
- HloInstruction* param =
- module.entry_computation()->parameter_instruction(i);
- if (ShapeContainsToken(param->shape())) {
- return InternalError(
- "Entry parameter %d is or contains a token shape: %s", i,
- ShapeUtil::HumanString(param->shape()));
- }
+ Status HandleFusion(HloInstruction* fusion) override {
+ return CheckFusionInstruction(fusion);
}
- return Status::OK();
-}
-// Checks if the given two instructions share the same channel id.
-Status CheckSameChannel(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- if (instr1->channel_id() != instr2->channel_id()) {
- return InternalError(
- "Expected to have the same channel id, actual channel ids are: %s "
- "(%d), %s (%d)",
- instr1->ToString(), instr1->channel_id(), instr2->ToString(),
- instr2->channel_id());
+ Status HandleBroadcast(HloInstruction* broadcast) override {
+ // If you see this failure then someone has confused the difference
+ // between the HLO broadcast op, and the UserComputation broadcast
+ // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
+ // or ComputationLowerer::Visit()
+ TF_RET_CHECK(broadcast->dimensions().size() ==
+ ShapeUtil::Rank(broadcast->operand(0)->shape()))
+ << "Broadcast HLO (" << broadcast->ToShortString()
+ << ") has invalid number of dimensions: "
+ << broadcast->dimensions().size()
+ << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape());
+ return Status::OK();
}
- return Status::OK();
-}
-// Checks if the given two instructions have the same is_host_transfer
-// attribute value. Intsructions must be send/recv instructions or their
-// 'done' variant.
-Status CheckSameIsHostTransfer(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- const HloSendRecvInstruction* send_recv1 =
- DynCast<const HloSendRecvInstruction>(instr1);
- const HloSendRecvInstruction* send_recv2 =
- DynCast<const HloSendRecvInstruction>(instr2);
- TF_RET_CHECK(send_recv1 != nullptr);
- TF_RET_CHECK(send_recv2 != nullptr);
- if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
- return InternalError(
- "Expected instructions to have the same is-host-transfer property: "
- "%s, "
- "%s ",
- instr1->ToString(), instr2->ToString());
+ Status HandleWhile(HloInstruction* xla_while) override {
+ auto* while_cond = xla_while->while_condition();
+ auto* while_body = xla_while->while_body();
+ if (while_cond->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While condition must have exactly 1 parameter; had %d : %s",
+ while_cond->num_parameters(), while_cond->ToString());
+ }
+ if (while_body->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While body must have exactly 1 parameter; had %d : %s",
+ while_body->num_parameters(), while_body->ToString());
+ }
+ if (xla_while->operand_count() != 1) {
+ return FailedPrecondition(
+ "While loop must have exactly one operand; had %d : %s",
+ xla_while->operand_count(), xla_while->ToString());
+ }
+ return Status::OK();
}
- return Status::OK();
-}
-// Checks various invariants of send and recv instructions.
-Status VerifySendsAndRecvs(const HloModule& module) {
- absl::flat_hash_map<int64, const HloInstruction*> host_channels;
- // Host send/recv instructions must have their own unique channel.
- auto check_unique_host_channel = [&](const HloInstruction* instruction) {
- const HloSendRecvInstruction* sendrecv =
- DynCast<const HloSendRecvInstruction>(instruction);
- if (sendrecv->is_host_transfer()) {
- auto it_inserted =
- host_channels.insert({sendrecv->channel_id(), sendrecv});
- if (!it_inserted.second) {
- return FailedPrecondition(
- "Channel %d is used for multiple host send/recv instructions: "
- "%s "
- "and "
- "%s",
- sendrecv->channel_id(), sendrecv->ToString(),
- it_inserted.first->second->ToString());
- }
+ Status HandleConditional(HloInstruction* conditional) override {
+ if (conditional->true_computation()->num_parameters() != 1) {
+ return FailedPrecondition(
+ "True computation %s of %s must have 1 parameter insted of %d",
+ conditional->true_computation()->name(), conditional->ToString(),
+ conditional->true_computation()->num_parameters());
}
+ if (conditional->false_computation()->num_parameters() != 1) {
+ return FailedPrecondition(
+ "False computation %s of %s must have 1 parameter insted of %d",
+ conditional->false_computation()->name(), conditional->ToString(),
+ conditional->false_computation()->num_parameters());
+ }
+ return Status::OK();
+ }
+
+ Status HandleElementwiseUnary(HloInstruction* instruction) override {
+ return CheckElementwiseInstruction(instruction);
+ }
+
+ Status HandleElementwiseBinary(HloInstruction* instruction) override {
+ return CheckElementwiseInstruction(instruction);
+ }
+ Status HandleGetTupleElement(HloInstruction* gte) override {
+ TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape()));
return Status::OK();
- };
+ }
- // Send/Recv instruction must have a single user: the corresponding
- // SendDone/RecvDone. with matching channel.
- for (const HloComputation* computation : module.computations()) {
- for (const HloInstruction* instruction : computation->instructions()) {
- switch (instruction->opcode()) {
- case HloOpcode::kSend: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
- TF_RET_CHECK(instruction->users().size() == 1);
- const HloInstruction* send_done = instruction->users().front();
- TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
- TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
- break;
- }
- case HloOpcode::kRecv: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
- TF_RET_CHECK(instruction->users().size() == 1);
- const HloInstruction* recv_done = instruction->users().front();
- TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
- TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
- break;
+ Status HandleTranspose(HloInstruction* transpose) override {
+ const Shape& shape = transpose->shape();
+ const HloInstruction* operand = transpose->operand(0);
+ TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
+ TF_RET_CHECK(shape.dimensions().size() ==
+ transpose->operand(0)->shape().dimensions().size());
+ TF_RET_CHECK(std::equal(
+ operand->shape().dimensions().begin(),
+ operand->shape().dimensions().end(),
+ Permute(transpose->dimensions(), shape.dimensions()).begin()))
+ << "shape: " << shape << ", operand->shape(): " << shape
+ << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
+ << "}";
+ return Status::OK();
+ }
+
+ Status Preprocess(HloInstruction* instruction) override {
+ auto previous = instructions_by_name_.find(instruction->name());
+ TF_RET_CHECK(previous == instructions_by_name_.end())
+ << "HLO has name that is not unique within module:\n"
+ << instruction->ToString()
+ << " in computation: " << instruction->parent()->name()
+ << "\nPrevious HLO with same name:\n"
+ << previous->second->ToString()
+ << " in computation: " << previous->second->parent()->name();
+ instructions_by_name_[instruction->name()] = instruction;
+ return Status::OK();
+ }
+
+ Status Postprocess(HloInstruction* instruction) override {
+ if (instruction_can_change_layout_func_ &&
+ LayoutUtil::IsDenseArray(instruction->shape()) &&
+ !instruction_can_change_layout_func_(instruction)) {
+ const Shape& result_shape = instruction->shape();
+ const Layout& result_layout = result_shape.layout();
+ for (HloInstruction* operand : instruction->operands()) {
+ const Shape& operand_shape = operand->shape();
+ if (LayoutUtil::IsDenseArray(operand_shape) &&
+ ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(result_shape)) {
+ const Layout& operand_layout = operand_shape.layout();
+ TF_RET_CHECK(LayoutUtil::Equal(result_layout, operand_layout))
+ << "Instruction shouldn't change layouts "
+ << instruction->ToString() << " From "
+ << ShapeUtil::HumanString(result_shape) << " To "
+ << ShapeUtil::HumanString(operand_shape);
}
- case HloOpcode::kSendDone:
- TF_RET_CHECK(instruction->operands().size() == 1);
- TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
- break;
- case HloOpcode::kRecvDone:
- TF_RET_CHECK(instruction->operands().size() == 1);
- TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
- break;
- default:
- break;
}
}
+
+ return Status::OK();
}
- return Status::OK();
-}
+
+ private:
+ absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
+ // Determines whether an instruction can change layouts.
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
+};
} // namespace
@@ -1061,65 +1184,13 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
- absl::flat_hash_map<string, const HloInstruction*> instructions;
-
for (auto* computation : module->computations()) {
- for (const auto& instruction : computation->instructions()) {
- TF_RET_CHECK(instruction->parent() == computation);
- if (instruction->opcode() == HloOpcode::kFusion) {
- TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
- TF_RET_CHECK(instruction->called_computations() ==
- absl::Span<HloComputation* const>(
- {instruction->fused_instructions_computation()}))
- << "Fusion HLO calls computations other than the "
- "fused_instructions_computation: "
- << instruction->ToString()
- << " instruction->fused_instructions_computation(): "
- << instruction->fused_instructions_computation()->ToString()
- << " instruction->called_computations(): "
- << ComputationsToString(instruction->called_computations());
-
- for (const auto& fused : instruction->fused_instructions()) {
- TF_RET_CHECK(fused->parent() ==
- instruction->fused_instructions_computation())
- << "Fused HLO was missing a parent: " << fused->ToString()
- << " parent: " << fused->parent()
- << " computation: " << computation;
- }
- } else if (instruction->opcode() == HloOpcode::kBroadcast) {
- // If you see this failure then someone has confused the difference
- // between the HLO broadcast op, and the UserComputation broadcast
- // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
- // or ComputationLowerer::Visit()
- TF_RET_CHECK(instruction->dimensions().size() ==
- ShapeUtil::Rank(instruction->operand(0)->shape()))
- << "Broadcast HLO (" << instruction->ToShortString()
- << ") has invalid number of dimensions: "
- << instruction->dimensions().size()
- << " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
- } else if (instruction->opcode() == HloOpcode::kWhile) {
- TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
- } else if (instruction->opcode() == HloOpcode::kConditional) {
- TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
- } else if (instruction->opcode() !=
- HloOpcode::kRng /* Rng operands are always scalar. */
- && instruction->IsElementwise()) {
- TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction));
- }
-
- auto previous = instructions.find(instruction->name());
- TF_RET_CHECK(previous == instructions.end())
- << "HLO has name that is not unique within module:\n"
- << instruction->ToString()
- << " in computation: " << computation->name()
- << "\nPrevious HLO with same name:\n"
- << previous->second->ToString()
- << " in computation: " << previous->second->parent()->name();
- instructions[instruction->name()] = instruction;
- }
-
std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
+
+ InstructionVerifier instruction_verifier(
+ instruction_can_change_layout_func_);
+ TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
}
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 0cde4a31af..cb49cb95ba 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -155,11 +155,17 @@ class HloVerifier : public HloModulePass {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
- explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision)
+ explicit HloVerifier(bool layout_sensitive, bool allow_mixed_precision,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {})
: shape_verifier_factory_([layout_sensitive, allow_mixed_precision] {
return absl::make_unique<ShapeVerifier>(layout_sensitive,
allow_mixed_precision);
- }) {}
+ }),
+ instruction_can_change_layout_func_(
+ std::move(instruction_can_change_layout_func)) {
+ CHECK(instruction_can_change_layout_func_ == nullptr || layout_sensitive);
+ }
// Uses custom shape verification.
explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
@@ -172,22 +178,15 @@ class HloVerifier : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;
private:
- // CHECKs various invariants of a fusion instruction.
- Status CheckFusionInstruction(HloInstruction* fusion) const;
-
- Status CheckWhileInstruction(HloInstruction* instruction);
-
- Status CheckConditionalInstruction(HloInstruction* instruction);
-
- // Checks that the non-scalar operand shapes are compatible to the output
- // shape, i.e., that there are no implicit broadcasts of size-one dimensions.
- Status CheckElementwiseInstruction(HloInstruction* instruction);
-
// Creates a ShapeVerifier that checks that shapes match inferred
// expectations. This is a factory function because ShapeVerifier,
// being a DfsHloVisitor, is stateful. We want a clean object
// for each run of the verifier.
ShapeVerifierFactory shape_verifier_factory_;
+
+ // Determines whether an instruction can change layouts.
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 8f0423bb1c..afe01e5487 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/service/layout_assignment.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -50,6 +51,14 @@ class HloVerifierTestAllowMixedPrecision : public HloTestBase {
/*allow_mixed_precision_in_hlo_verifier=*/true) {}
};
+class HloVerifierTestLayoutSensitive : public HloTestBase {
+ public:
+ HloVerifierTestLayoutSensitive()
+ : HloTestBase(/*verifier_layout_sensitive=*/true,
+ /*allow_mixed_precision_in_hlo_verifier=*/false,
+ LayoutAssignment::InstructionCanChangeLayout) {}
+};
+
TEST_F(HloVerifierTest, NullInstructionParent) {
HloComputation::Builder builder(TestName());
const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
@@ -358,5 +367,63 @@ TEST_F(HloVerifierTest, ConvNegativeBaseDilationNotAllowed) {
HasSubstr("non-positive base area dilation factor"));
}
+static const char* const kAddWithLayoutChangeHlo = R"(
+ HloModule AddWithLayoutChange
+ ENTRY AddWithLayoutChange {
+ par0 = f32[3,4]{1,0} parameter(0)
+ par1 = f32[3,4]{0,1} parameter(1)
+ ROOT add0 = f32[3,4]{1,0} add(par0,par1)
+ }
+ )";
+
+TEST_F(HloVerifierTest, AddWithLayoutChange) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_TRUE(status.ok());
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, AddWithLayoutChangeNotAllowed) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kAddWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, SliceWithLayoutChangeNotAllowed) {
+ const char* const kSliceWithLayoutChangeHlo = R"(
+ HloModule SliceWithLayoutChange
+ ENTRY SliceWithLayoutChange {
+ par0 = f32[4,5]{0,1} parameter(0)
+ par1 = s32[2] parameter(1)
+ ROOT dslice0 = f32[3,4]{1,0} dynamic-slice(par0, par1),
+ dynamic_slice_sizes={3,4}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kSliceWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
+
+TEST_F(HloVerifierTestLayoutSensitive, ConcatWithLayoutChangeNotAllowed) {
+ const char* const kConcatWithLayoutChangeHlo = R"(
+ HloModule ConcatWithLayoutChange
+ ENTRY ConcatWithLayoutChange {
+ par0 = f32[3,5]{0,1} parameter(0)
+ par1 = f32[3,3]{1,0} parameter(1)
+ ROOT concat0 = f32[3,8]{1,0} concatenate(f32[3,5] par0, f32[3,3] par1),
+ dimensions={1}
+ }
+ )";
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kConcatWithLayoutChangeHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Instruction shouldn't change layouts"));
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 5a99c40df4..69a4c160ee 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index da2032f6c7..f14c667520 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -17,6 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_INSTRUCTION_FUSION_H_
+#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@@ -25,33 +26,6 @@ limitations under the License.
namespace xla {
-// A queue interface that allows implementations to choose fusion candidates in
-// custom order.
-class FusionQueue {
- public:
- FusionQueue() = default;
- virtual ~FusionQueue() = default;
-
- // Dequeues the next fusion candidates: a consumer and the list of producers
- // as operand indices.
- virtual std::pair<HloInstruction*, std::vector<int64>>
- DequeueNextInstructionAndOperandsToFuseInOrder() = 0;
-
- // A callback passed to the queue implementation right before the producer is
- // fused into the consumer.
- virtual void PreFusion(HloInstruction* producer, HloInstruction* consumer) {}
-
- // A callback passed to the queue implementation right after the fusion is
- // created. Note that original_producer could have been destroyed.
- virtual void OnFusingInstruction(HloInstruction* fusion,
- HloInstruction* original_producer,
- HloInstruction* original_consumer) {}
-
- // A callback passed to the queue implementation to notify the removal of an
- // instruction.
- virtual void RemoveInstruction(HloInstruction* instruction) = 0;
-};
-
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD
index 146c9052f1..1484e14df1 100644
--- a/tensorflow/compiler/xla/service/interpreter/BUILD
+++ b/tensorflow/compiler/xla/service/interpreter/BUILD
@@ -45,8 +45,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
- "//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:layout_assignment",
+ "//tensorflow/compiler/xla/service:map_inliner",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:while_loop_simplifier",
"//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc
index 27fe89375d..7c79eb7d79 100644
--- a/tensorflow/compiler/xla/service/interpreter/compiler.cc
+++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc
@@ -28,9 +28,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
-#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/interpreter/executable.h"
#include "tensorflow/compiler/xla/service/layout_assignment.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/map_inliner.cc
index 5fd779ebf9..2200ef054a 100644
--- a/tensorflow/compiler/xla/service/inliner.cc
+++ b/tensorflow/compiler/xla/service/map_inliner.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include <memory>
#include <string>
@@ -32,10 +32,10 @@ limitations under the License.
namespace xla {
-// InlinerVisitor traverses the HLO computation and inlines maps.
-class InlinerVisitor : public DfsHloVisitorWithDefault {
+// MapInlinerVisitor traverses the HLO computation and inlines maps.
+class MapInlinerVisitor : public DfsHloVisitorWithDefault {
public:
- explicit InlinerVisitor(HloComputation* computation)
+ explicit MapInlinerVisitor(HloComputation* computation)
: computation_(computation) {}
// Default visitor action is to do nothing and return OK.
@@ -49,48 +49,44 @@ class InlinerVisitor : public DfsHloVisitorWithDefault {
StatusOr<bool> Run(HloComputation* computation);
private:
- // Current HloComputation instance the InlinerVisitor is traversing.
+ // Current HloComputation instance the MapInlinerVisitor is traversing.
HloComputation* computation_;
// Whether algebraic simplification has occurred.
bool changed_ = false;
};
-StatusOr<bool> InlinerVisitor::Run(HloComputation* computation) {
+StatusOr<bool> MapInlinerVisitor::Run(HloComputation* computation) {
changed_ = false;
computation_ = computation;
TF_RETURN_IF_ERROR(computation->root_instruction()->Accept(this));
return changed_;
}
-Status InlinerVisitor::HandleMap(HloInstruction* map) {
+Status MapInlinerVisitor::HandleMap(HloInstruction* map) {
HloComputation* function = map->to_apply();
HloInstruction& root = *function->root_instruction();
- // TODO(b/29249531): Add DCE pass to remove unused HloComputations.
// Only inlining functions that are simply a single operation until a better
// profitability model for inlining is defined.
if (hlo_query::AllOperandsAreParameters(root)) {
if (root.opcode() == HloOpcode::kFusion ||
- root.opcode() == HloOpcode::kParameter ||
root.opcode() == HloOpcode::kTrace) {
// Cloning not supported for these instructions.
return Status::OK();
}
VLOG(10) << "inlining map({X ... Y}, op) => : op(X ... Y) with function "
<< root.ToShortString();
- // If the input is a constant then the shape of the constant could be
- // different than the map shape. Hence, a broadcast is needed, else the
- // cloned operand with new shape and operands work.
- if (root.opcode() != HloOpcode::kConstant) {
- std::vector<HloInstruction*> params;
- for (int64 o = 0; o < root.operands().size(); o++) {
- params.push_back(map->operands()[root.operand(o)->parameter_number()]);
- }
- HloInstruction* placed_instruction = computation_->AddInstruction(
- root.CloneWithNewOperands(map->shape(), params));
+ if (root.opcode() == HloOpcode::kParameter) {
+ // If the root is a parameter, then use the corresponding operand as the
+ // result of the computation.
TF_RETURN_IF_ERROR(
- computation_->ReplaceInstruction(map, placed_instruction));
- } else {
+ map->ReplaceAllUsesWith(map->operands()[root.parameter_number()]));
+ TF_RETURN_IF_ERROR(computation_->RemoveInstruction(map));
+ } else if (root.opcode() == HloOpcode::kConstant) {
+ // If the input is a constant then the shape of the constant could be
+ // different than the map shape. Hence, a broadcast is needed, else the
+ // cloned operand with new shape and operands work.
+ //
// The constant is in an embedded computation and needs to be recreated
// as part of the computation that the broadcast is inserted into.
HloInstruction* constant = computation_->AddInstruction(root.Clone());
@@ -98,6 +94,15 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
HloInstruction::CreateBroadcast(map->shape(), constant, {}));
TF_RETURN_IF_ERROR(
computation_->ReplaceInstruction(map, placed_instruction));
+ } else {
+ std::vector<HloInstruction*> params;
+ for (int64 o = 0; o < root.operands().size(); o++) {
+ params.push_back(map->operands()[root.operand(o)->parameter_number()]);
+ }
+ HloInstruction* placed_instruction = computation_->AddInstruction(
+ root.CloneWithNewOperands(map->shape(), params));
+ TF_RETURN_IF_ERROR(
+ computation_->ReplaceInstruction(map, placed_instruction));
}
changed_ = true;
return Status::OK();
@@ -106,8 +111,8 @@ Status InlinerVisitor::HandleMap(HloInstruction* map) {
return Status::OK();
}
-StatusOr<bool> Inliner::Run(HloModule* module) {
- InlinerVisitor visitor(/*computation=*/nullptr);
+StatusOr<bool> MapInliner::Run(HloModule* module) {
+ MapInlinerVisitor visitor(/*computation=*/nullptr);
bool changed = false;
for (HloComputation* computation : module->computations()) {
TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation));
diff --git a/tensorflow/compiler/xla/service/inliner.h b/tensorflow/compiler/xla/service/map_inliner.h
index e20af08fb7..b679118118 100644
--- a/tensorflow/compiler/xla/service/inliner.h
+++ b/tensorflow/compiler/xla/service/map_inliner.h
@@ -13,27 +13,27 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
-// A pass which performs inlining. Which can result, for example, in functions
-// that were previously being mapped by Map instead directly applied to the
-// forwarded operands (i.e., map({X, Y}, max) -> max(X, Y)).
-class Inliner : public HloModulePass {
+// A pass which performs map inlining. This replaces kMap instructions with
+// their equivalent sequence of array operations. For example:
+// map({X, Y}, add) -> add(X, Y)).
+class MapInliner : public HloModulePass {
public:
- ~Inliner() override = default;
- absl::string_view name() const override { return "inline"; }
+ ~MapInliner() override = default;
+ absl::string_view name() const override { return "map-inline"; }
- // Run inlining on the given computation. Returns whether the computation was
- // changed.
+ // Run map inlining on the given computation. Returns whether the computation
+ // was changed.
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_INLINER_H_
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MAP_INLINER_H_
diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/map_inliner_test.cc
index 7e967f035c..84059dd0f7 100644
--- a/tensorflow/compiler/xla/service/inliner_test.cc
+++ b/tensorflow/compiler/xla/service/map_inliner_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/xla/service/inliner.h"
+#include "tensorflow/compiler/xla/service/map_inliner.h"
#include <memory>
#include <utility>
@@ -35,10 +35,10 @@ namespace op = xla::testing::opcode_matchers;
namespace xla {
namespace {
-using InlinerTest = HloVerifiedTestBase;
+using MapInlinerTest = HloVerifiedTestBase;
// Test that `map` with `max` is transformed to `max`
-TEST_F(InlinerTest, MapMax) {
+TEST_F(MapInlinerTest, MapMax) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto max_builder = HloComputation::Builder(TestName());
@@ -63,7 +63,7 @@ TEST_F(InlinerTest, MapMax) {
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
- Inliner inliner;
+ MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Maximum(lhs, rhs));
@@ -75,7 +75,7 @@ TEST_F(InlinerTest, MapMax) {
}
// Test that `constant` function is changed to `broadcast`.
-TEST_F(InlinerTest, MapConstant) {
+TEST_F(MapInlinerTest, MapConstant) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
auto const2_builder = HloComputation::Builder(TestName());
@@ -97,7 +97,7 @@ TEST_F(InlinerTest, MapConstant) {
hlo_module->AddEmbeddedComputation(std::move(const2_f32));
hlo_module->AddEntryComputation(std::move(computation));
HloInstruction* root = hlo_module->entry_computation()->root_instruction();
- Inliner inliner;
+ MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
root = hlo_module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Broadcast(op::Constant()));
@@ -108,7 +108,7 @@ TEST_F(InlinerTest, MapConstant) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
-TEST_F(InlinerTest, MapSubtractOppositeOrder) {
+TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
// Note that the parameter ordinals are in the opposite order to their
@@ -135,7 +135,7 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
hlo_module->AddEmbeddedComputation(std::move(max_f32));
hlo_module->AddEntryComputation(std::move(computation));
- Inliner inliner;
+ MapInliner inliner;
EXPECT_TRUE(inliner.Run(hlo_module).ValueOrDie());
EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
op::Subtract(rhs, lhs));
@@ -146,6 +146,36 @@ TEST_F(InlinerTest, MapSubtractOppositeOrder) {
EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
+TEST_F(MapInlinerTest, MapParameter) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+
+ auto param_builder = HloComputation::Builder(TestName());
+ param_builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "p0"));
+ param_builder.AddInstruction(HloInstruction::CreateParameter(1, r0f32, "p1"));
+ auto param_f32 = param_builder.Build();
+
+ auto builder = HloComputation::Builder("MapParamFunction");
+ auto lhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
+ auto rhs = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
+ builder.AddInstruction(
+ HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, param_f32.get()));
+
+ auto computation = builder.Build();
+ auto hlo_module = CreateNewVerifiedModule();
+ hlo_module->AddEmbeddedComputation(std::move(param_f32));
+ hlo_module->AddEntryComputation(std::move(computation));
+
+ MapInliner inliner;
+ EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
+ EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);
+
+ // Verify execution on CPU.
+ auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
+ auto expected = LiteralUtil::CreateR0<float>(4);
+ EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
+}
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index 9199e32d0f..f952e64af2 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -89,16 +89,6 @@ class TransferManager {
const LiteralSlice& literal,
const ShapedBuffer& device_buffer);
- // Hint type given to TransferLiteralToDeviceAsync.
- enum TransferToDeviceHint {
- // No hint available.
- kNoHint,
-
- // The destination buffer is undefined on the device, meaning it can be
- // transferred to eagerly rather than waiting for Stream ordering.
- kBufferUndefined,
- };
-
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
@@ -106,13 +96,9 @@ class TransferManager {
//
// This operation is performed asynchronously on the given stream. It returns
// once the transfer is enqueued.
- //
- // The optional hint can allow implementations to optimize transfers. It is
- // not mandatory for an implementation to obey the hint.
virtual Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
- const ShapedBuffer& device_buffer,
- TransferToDeviceHint hint = kNoHint) = 0;
+ const ShapedBuffer& device_buffer) = 0;
// Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 476a9fe868..d244923532 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -869,11 +869,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
return Status::OK();
}
- if (Rank(shape) != shape.dimensions_size()) {
- return InvalidArgument(
- "shape's rank is mismatched with dimension count; rank=%d "
- "dimensions_size=%d",
- Rank(shape), shape.dimensions_size());
+ if (LayoutUtil::IsSparseArray(shape) && Rank(shape) == 0) {
+ return InvalidArgument("sparse arrays must have rank > 0");
}
for (int64 i = 0; i < Rank(shape); ++i) {
int64 dimension = shape.dimensions(i);
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 070b092d18..b851db14ec 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -91,7 +91,14 @@ class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest {
XlaBuilder builder(TestName());
auto lhs = ConstantR4FromArray4D<T>(&builder, *alhs);
auto rhs = ConstantR4FromArray4D<T>(&builder, *arhs);
- Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ PrecisionConfig precision;
+ // The left hand side of the convolution is numbers between 0 and 2304 which
+ // requires at least 11 mantissa bits and the DEFAULT precision config is
+ // allowed to round to bfloat16 which only has 7 mantissa bits.
+ precision.add_operand_precision(PrecisionConfig::HIGHEST);
+ precision.add_operand_precision(PrecisionConfig::DEFAULT);
+ Conv(lhs, rhs, {1, 1}, Padding::kValid, /*feature_group_count=*/1,
+ &precision);
ComputeAndCompare(&builder, {}, error_spec_);
}
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc
index bdd4fd7e3d..7ab2ecda58 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc
@@ -86,19 +86,25 @@ ProgramShape GetProgramShapeWithLayout(const HloModule& module) {
} // namespace
HloTestBase::HloTestBase(bool verifier_layout_sensitive,
- bool allow_mixed_precision_in_hlo_verifier)
+ bool allow_mixed_precision_in_hlo_verifier,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
: HloTestBase(GetTestPlatform(), GetReferencePlatform(),
verifier_layout_sensitive,
- allow_mixed_precision_in_hlo_verifier) {}
+ allow_mixed_precision_in_hlo_verifier,
+ instruction_can_change_layout_func) {}
HloTestBase::HloTestBase(se::Platform* test_platform,
se::Platform* reference_platform,
bool verifier_layout_sensitive,
- bool allow_mixed_precision_in_hlo_verifier)
+ bool allow_mixed_precision_in_hlo_verifier,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func)
: test_runner_(test_platform), reference_runner_(reference_platform) {
hlo_verifier_ = absl::make_unique<HloVerifier>(
/*layout_sensitive=*/verifier_layout_sensitive,
- /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier);
+ /*allow_mixed_precision=*/allow_mixed_precision_in_hlo_verifier,
+ instruction_can_change_layout_func);
}
std::unique_ptr<HloModule> HloTestBase::CreateNewModule(const string& name) {
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index 0ae4bdc104..217428befa 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -88,14 +88,18 @@ class HloTestBase : public ::testing::Test {
// interpreter is the only supported backend, it will be both the test backend
// and the reference backend.
HloTestBase(bool verifier_layout_sensitive = false,
- bool allow_mixed_precision_in_hlo_verifier = true);
+ bool allow_mixed_precision_in_hlo_verifier = true,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {});
// If your test doesn't use interpreter as the reference backend, you can use
// this constructor. Note that your test target is responsible for linking in
// both needed backends.
HloTestBase(se::Platform* test_platform, se::Platform* reference_platform,
bool verifier_layout_sensitive = false,
- bool allow_mixed_precision_in_hlo_verifier = true);
+ bool allow_mixed_precision_in_hlo_verifier = true,
+ std::function<bool(const HloInstruction*)>
+ instruction_can_change_layout_func = {});
~HloTestBase() override {}
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index c25ccafaf8..22fe4a2670 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -638,6 +638,8 @@ class R4ReduceWindowTest : public ReduceWindowTestBase,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
/*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
/*padding=*/padding);
CHECK(reducer == kAdd || reducer == kMax);
@@ -1158,7 +1160,10 @@ class R2ReduceWindowTest : public ReduceWindowTestBase,
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
@@ -1369,7 +1374,10 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*init_value=*/init_value,
/*computation=*/computation,
/*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/padding);
+ /*window_strides=*/param.strides,
+ /*base_dilations=*/{},
+ /*window_dilations=*/{},
+ /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index fbe0573d5d..fa06d351d4 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -29,6 +29,7 @@ py_library(
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/coder:coder_py",
"//tensorflow/contrib/compiler:compiler_py",
+ "//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/autograph",
"//tensorflow/contrib/constrained_optimization",
"//tensorflow/contrib/copy_graph:copy_graph_py",
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1056894f18..f4a8e16c99 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -60,6 +60,7 @@ class TPUClusterResolver(ClusterResolver):
if (self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('/bns')) or
+ self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('grpc://'))):
return False
return True
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index f675c135f4..244683765a 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -1,6 +1,16 @@
# Minimum CMake required
cmake_minimum_required(VERSION 3.5)
+if(WIN32)
+ if(${CMAKE_VERSION} VERSION_LESS "3.8")
+ message(WARNING "Your current cmake version is ${CMAKE_VERSION} which does not support setting the toolset architecture to x64. This may cause \"compiler out of heap space\" errors when building. Consider upgrading your cmake to > 3.8 and using the flag -Thost=x64 when running cmake.")
+ else()
+ if(NOT CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE OR NOT "${CMAKE_VS_PLATFORM_TOOLSET_HOST_ARCHITECTURE}" STREQUAL "x64")
+ message(WARNING "Your current cmake generator is set to use 32 bit toolset architecture. This may cause \"compiler out of heap space\" errors when building. Consider using the flag -Thost=x64 when running cmake.")
+ endif()
+ endif()
+endif()
+
# Project
project(tensorflow C CXX)
@@ -352,9 +362,7 @@ if (tensorflow_ENABLE_MKL_SUPPORT)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${mkldnn_STATIC_LIBRARIES})
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES mkldnn_copy_shared_to_destination)
include_directories(${mkldnn_INCLUDE_DIRS})
- else (tensorflow_ENABLE_MKLDNN_SUPPORT)
- add_definitions(-DINTEL_MKL_ML_ONLY)
- endif()
+ endif(tensorflow_ENABLE_MKLDNN_SUPPORT)
endif (tensorflow_ENABLE_MKL_SUPPORT)
if (tensorflow_ENABLE_GPU)
diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md
index 77242b34fd..84c679162c 100644
--- a/tensorflow/contrib/cmake/README.md
+++ b/tensorflow/contrib/cmake/README.md
@@ -108,180 +108,177 @@ ops or APIs.
Step-by-step Windows build
==========================
-1. Install the prerequisites detailed above, and set up your environment.
-
- * The following commands assume that you are using the Windows Command
- Prompt (`cmd.exe`). You will need to set up your environment to use the
- appropriate toolchain, i.e. the 64-bit tools. (Some of the binary targets
- we will build are too large for the 32-bit tools, and they will fail with
- out-of-memory errors.) The typical command to do set up your
- environment is:
-
- ```
- D:\temp> "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"
- ```
-
- * When building with GPU support after installing the CUDNN zip file from NVidia, append its
- bin directory to your PATH environment variable.
- In case TensorFlow fails to find the CUDA dll's during initialization, check your PATH environment variable.
- It should contain the directory of the CUDA dlls and the directory of the CUDNN dll.
- For example:
-
- ```
- D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin
- D:\local\cuda\bin
- ```
-
- * When building with MKL support after installing [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin directories to your PATH environment variable.
-
- In case TensorFlow fails to find the MKL dll's during initialization, check your PATH environment variable.
- It should contain the directory of the MKL dlls. For example:
-
- ```
- D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl
- D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler
- D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt
- ```
-
-
- * We assume that `cmake` and `git` are installed and in your `%PATH%`. If
- for example `cmake` is not in your path and it is installed in
- `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory
- to your `%PATH%` as follows:
-
- ```
- D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe"
- ```
-
-2. Clone the TensorFlow repository and create a working directory for your
- build:
-
- ```
- D:\temp> git clone https://github.com/tensorflow/tensorflow.git
- D:\temp> cd tensorflow\tensorflow\contrib\cmake
- D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build
- D:\temp\tensorflow\tensorflow\contrib\cmake> cd build
- D:\temp\tensorflow\tensorflow\contrib\cmake\build>
- ```
-
-3. Invoke CMake to create Visual Studio solution and project files.
-
- **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment
- variable. The other paths are for illustrative purposes only, and may
- be different on your platform. The `^` character is a line continuation
- and must be the last character on each line.
-
- ```
- D:\...\build> cmake .. -A x64 -DCMAKE_BUILD_TYPE=Release ^
- More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^
- More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^
- More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib
- ```
- To build with GPU support add "^" at the end of the last line above following with:
- ```
- More? -Dtensorflow_ENABLE_GPU=ON ^
- More? -DCUDNN_HOME="D:\...\cudnn"
- ```
- To build with MKL support add "^" at the end of the last line above following with:
-
- ```
- More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^
- More? -DMKL_HOME="D:\...\compilers_and_libraries"
- ```
-
- To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows:
-
- ```
- More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX
- ```
-
- Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build
- configuration that you choose when invoking `msbuild`. The known-good
- values are `Release` and `RelWithDebInfo`. The `Debug` build type is
- not currently supported, because it relies on a `Debug` library for
- Python (`python35d.lib`) that is not distributed by default.
-
- There are various options that can be specified when generating the
- solution and project files:
-
- * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the
- `CMAKE_BUILD_TYPE` option must match the build configuration that you
- choose when invoking MSBuild in step 4. The known-good values are
- `Release` and `RelWithDebInfo`. The `Debug` build type is not currently
- supported, because it relies on a `Debug` library for Python
- (`python35d.lib`) that is not distributed by default.
-
- * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can
- build a small subset of the kernels for a faster build by setting this
- option to `OFF`.
-
- * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate
- project files for a simple C++
- [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc).
-
- * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`. Generate
- project files for building a PIP package containing the TensorFlow runtime
- and its Python bindings.
-
- * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include
- gRPC support and the distributed client and server code in the TensorFlow
- runtime.
-
- * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include
- SSL support (for making secure HTTP requests) in the TensorFlow runtime.
- This support is incomplete, and will be used for Google Cloud Storage
- support.
-
- * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include
- GPU support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and CUDNN 5.1.
- CMake will expect the location of CUDNN in -DCUDNN_HOME=path_you_unzipped_cudnn.
-
- * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds cc unit tests.
- There are many of them and building will take a few hours.
- After cmake, build and execute the tests with
- ```
- MSBuild /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj
- ctest -C RelWithDebInfo
- ```
-
- * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python kernel tests.
- After building the python wheel, you need to install the new wheel before running the tests.
- To execute the tests, use
- ```
- ctest -C RelWithDebInfo
- ```
-
- * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on
- serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`.
- After building the python wheel, you need to install the new wheel before running the tests.
- To execute the tests, use
- ```
- ctest -C RelWithDebInfo
- ```
-
- * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL support. If MKL is enabled you need to install the [Intel Math Kernal Library](https://software.intel.com/en-us/mkl).
- CMake will expect the location of MKL in -MKL_HOME=path_you_install_mkl.
-
- * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for Deep Neural Networks (Intel(R) MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support.
-
-
-4. Invoke MSBuild to build TensorFlow.
-
- To build the C++ example program, which will be created as a `.exe`
- executable in the subdirectory `.\Release`:
-
- ```
- D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj
- D:\...\build> Release\tf_tutorials_example_trainer.exe
- ```
-
- To build the PIP package, which will be created as a `.whl` file in the
- subdirectory `.\tf_python\dist`:
-
- ```
- D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj
- ```
-
+1. Install the prerequisites detailed above, and set up your environment.
+
+ * When building with GPU support after installing the CUDNN zip file from
+ NVidia, append its bin directory to your PATH environment variable. In
+ case TensorFlow fails to find the CUDA dll's during initialization,
+ check your PATH environment variable. It should contain the directory of
+ the CUDA dlls and the directory of the CUDNN dll. For example:
+
+ ```
+ D:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin
+ D:\local\cuda\bin
+ ```
+
+ * When building with MKL support after installing
+ [MKL](https://software.intel.com/en-us/mkl) from INTEL, append its bin
+ directories to your PATH environment variable.
+
+ In case TensorFlow fails to find the MKL dll's during initialization,
+ check your PATH environment variable. It should contain the directory of
+ the MKL dlls. For example:
+
+ ```
+ D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\mkl
+ D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\compiler
+ D:\Tools\IntelSWTools\compilers_and_libraries\windows\redist\intel64\tbb\vc_mt
+ ```
+
+ * We assume that `cmake` and `git` are installed and in your `%PATH%`. If
+ for example `cmake` is not in your path and it is installed in
+ `C:\Program Files (x86)\CMake\bin\cmake.exe`, you can add this directory
+ to your `%PATH%` as follows:
+
+ ```
+ D:\temp> set PATH="%PATH%;C:\Program Files (x86)\CMake\bin\cmake.exe"
+ ```
+
+2. Clone the TensorFlow repository and create a working directory for your
+ build:
+
+ ```
+ D:\temp> git clone https://github.com/tensorflow/tensorflow.git
+ D:\temp> cd tensorflow\tensorflow\contrib\cmake
+ D:\temp\tensorflow\tensorflow\contrib\cmake> mkdir build
+ D:\temp\tensorflow\tensorflow\contrib\cmake> cd build
+ D:\temp\tensorflow\tensorflow\contrib\cmake\build>
+ ```
+
+3. Invoke CMake to create Visual Studio solution and project files.
+
+ **N.B.** This assumes that `cmake.exe` is in your `%PATH%` environment
+ variable. The other paths are for illustrative purposes only, and may be
+ different on your platform. The `^` character is a line continuation and
+ must be the last character on each line.
+
+ ```
+ D:\...\build> cmake .. -A x64 -Thost=x64 -DCMAKE_BUILD_TYPE=Release ^
+ More? -DSWIG_EXECUTABLE=C:/tools/swigwin-3.0.10/swig.exe ^
+ More? -DPYTHON_EXECUTABLE=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/python.exe ^
+ More? -DPYTHON_LIBRARIES=C:/Users/%USERNAME%/AppData/Local/Continuum/Anaconda3/libs/python35.lib
+ ```
+
+ To build with GPU support add "^" at the end of the last line above
+ following with: `More? -Dtensorflow_ENABLE_GPU=ON ^ More?
+ -DCUDNN_HOME="D:\...\cudnn"` To build with MKL support add "^" at the end of
+ the last line above following with:
+
+ ```
+ More? -Dtensorflow_ENABLE_MKL_SUPPORT=ON ^
+ More? -DMKL_HOME="D:\...\compilers_and_libraries"
+ ```
+
+ To enable SIMD instructions with MSVC, as AVX and SSE, define it as follows:
+
+ ```
+ More? -Dtensorflow_WIN_CPU_SIMD_OPTIONS=/arch:AVX
+ ```
+
+ Note that the `-DCMAKE_BUILD_TYPE=Release` flag must match the build
+ configuration that you choose when invoking `msbuild`. The known-good values
+ are `Release` and `RelWithDebInfo`. The `Debug` build type is not currently
+ supported, because it relies on a `Debug` library for Python
+ (`python35d.lib`) that is not distributed by default.
+
+ The `-Thost=x64` flag will ensure that the 64 bit compiler and linker is
+ used when building. Without this flag, MSBuild will use the 32 bit toolchain
+ which is prone to compile errors such as "compiler out of heap space".
+
+ There are various options that can be specified when generating the solution
+ and project files:
+
+ * `-DCMAKE_BUILD_TYPE=(Release|RelWithDebInfo)`: Note that the
+ `CMAKE_BUILD_TYPE` option must match the build configuration that you
+ choose when invoking MSBuild in step 4. The known-good values are
+ `Release` and `RelWithDebInfo`. The `Debug` build type is not currently
+ supported, because it relies on a `Debug` library for Python
+ (`python35d.lib`) that is not distributed by default.
+
+ * `-Dtensorflow_BUILD_ALL_KERNELS=(ON|OFF)`. Defaults to `ON`. You can
+ build a small subset of the kernels for a faster build by setting this
+ option to `OFF`.
+
+ * `-Dtensorflow_BUILD_CC_EXAMPLE=(ON|OFF)`. Defaults to `ON`. Generate
+ project files for a simple C++
+ [example training program](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/tutorials/example_trainer.cc).
+
+ * `-Dtensorflow_BUILD_PYTHON_BINDINGS=(ON|OFF)`. Defaults to `ON`.
+ Generate project files for building a PIP package containing the
+ TensorFlow runtime and its Python bindings.
+
+ * `-Dtensorflow_ENABLE_GRPC_SUPPORT=(ON|OFF)`. Defaults to `ON`. Include
+ gRPC support and the distributed client and server code in the
+ TensorFlow runtime.
+
+ * `-Dtensorflow_ENABLE_SSL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include
+ SSL support (for making secure HTTP requests) in the TensorFlow runtime.
+ This support is incomplete, and will be used for Google Cloud Storage
+ support.
+
+ * `-Dtensorflow_ENABLE_GPU=(ON|OFF)`. Defaults to `OFF`. Include GPU
+ support. If GPU is enabled you need to install the CUDA 8.0 Toolkit and
+ CUDNN 5.1. CMake will expect the location of CUDNN in
+ -DCUDNN_HOME=path_you_unzipped_cudnn.
+
+ * `-Dtensorflow_BUILD_CC_TESTS=(ON|OFF)`. Defaults to `OFF`. This builds
+ cc unit tests. There are many of them and building will take a few
+ hours. After cmake, build and execute the tests with `MSBuild
+ /p:Configuration=RelWithDebInfo ALL_BUILD.vcxproj ctest -C
+ RelWithDebInfo`
+
+ * `-Dtensorflow_BUILD_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This
+ enables python kernel tests. After building the python wheel, you need
+ to install the new wheel before running the tests. To execute the tests,
+ use `ctest -C RelWithDebInfo`
+
+ * `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This
+ enables python tests on serveral major packages. This option is only
+ valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`.
+ After building the python wheel, you need to install the new wheel
+ before running the tests. To execute the tests, use `ctest -C
+ RelWithDebInfo`
+
+ * `-Dtensorflow_ENABLE_MKL_SUPPORT=(ON|OFF)`. Defaults to `OFF`. Include
+ MKL support. If MKL is enabled you need to install the
+ [Intel Math Kernal Library](https://software.intel.com/en-us/mkl). CMake
+ will expect the location of MKL in -MKL_HOME=path_you_install_mkl.
+
+ * `-Dtensorflow_ENABLE_MKLDNN_SUPPORT=(ON|OFF)`. Defaults to `OFF`.
+ Include MKL DNN support. MKL DNN is [Intel(R) Math Kernel Library for
+ Deep Neural Networks (Intel(R)
+ MKL-DNN)](https://github.com/intel/mkl-dnn). You have to add
+ `-Dtensorflow_ENABLE_MKL_SUPPORT=ON` before including MKL DNN support.
+
+4. Invoke MSBuild to build TensorFlow.
+
+ Set up the path to find MSbuild: `D:\temp> "C:\Program Files (x86)\Microsoft
+ Visual Studio 14.0\VC\bin\amd64\vcvarsall.bat"`
+
+ To build the C++ example program, which will be created as a `.exe`
+ executable in the subdirectory `.\Release`:
+
+ ```
+ D:\...\build> MSBuild /p:Configuration=Release tf_tutorials_example_trainer.vcxproj
+ D:\...\build> Release\tf_tutorials_example_trainer.exe
+ ```
+
+ To build the PIP package, which will be created as a `.whl` file in the
+ subdirectory `.\tf_python\dist`:
+
+ ```
+ D:\...\build> MSBuild /p:Configuration=Release tf_python_build_pip_package.vcxproj
+ ```
Linux Continuous Integration build
==================================
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 8267612236..76d5b59ce1 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -412,6 +412,24 @@ cuda_py_test(
)
cuda_py_test(
+ name = "moving_averages_test",
+ srcs = ["moving_averages_test.py"],
+ additional_deps = [
+ ":combinations",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ ],
+ tags = [
+ "no_pip",
+ ],
+)
+
+cuda_py_test(
name = "optimizer_v2_test",
srcs = ["optimizer_v2_test.py"],
additional_deps = [
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index cff4b0a463..63a163e76c 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -349,26 +349,26 @@ mirrored_strategy_with_two_gpus = NamedDistribution(
required_gpus=2)
-adam_optimizer_v1_fn = NamedObject(
- "AdamV1", lambda: adam.AdamOptimizer(0.001, epsilon=1))
gradient_descent_optimizer_v1_fn = NamedObject(
"GradientDescentV1", lambda: gradient_descent.GradientDescentOptimizer(0.2))
adagrad_optimizer_v1_fn = NamedObject(
"AdagradV1", lambda: adagrad.AdagradOptimizer(0.001))
+adam_optimizer_v1_fn = NamedObject("AdamV1",
+ lambda: adam.AdamOptimizer(0.001, epsilon=1))
rmsprop_optimizer_v1_fn = NamedObject(
"RmsPropV1", lambda: rmsprop.RMSPropOptimizer(0.001))
-optimizers_v1 = [adam_optimizer_v1_fn, gradient_descent_optimizer_v1_fn,
- adagrad_optimizer_v1_fn]
-adam_optimizer_v2_fn = NamedObject(
- "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+optimizers_v1 = [gradient_descent_optimizer_v1_fn, adagrad_optimizer_v1_fn]
+
gradient_descent_optimizer_v2_fn = NamedObject(
"GradientDescentV2",
lambda: gradient_descent_v2.GradientDescentOptimizer(0.2))
adagrad_optimizer_v2_fn = NamedObject(
"AdagradV2", lambda: adagrad_v2.AdagradOptimizer(0.001))
-optimizers_v2 = [adam_optimizer_v2_fn, gradient_descent_optimizer_v2_fn,
- adagrad_optimizer_v2_fn]
+adam_optimizer_v2_fn = NamedObject(
+ "AdamV2", lambda: adam_v2.AdamOptimizer(0.001, epsilon=1))
+
+optimizers_v2 = [gradient_descent_optimizer_v2_fn, adagrad_optimizer_v2_fn]
graph_and_eager_modes = ["graph", "eager"]
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index a84ef04196..da7f8c548f 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -113,7 +113,7 @@ def main(_):
distribute=strategy)
# Train the model with the train dataset.
- model.fit(x=train_ds, epochs=20, steps_per_epoch=310)
+ model.fit(x=train_ds, epochs=20, steps_per_epoch=468)
# Evaluate the model with the eval dataset.
score = model.evaluate(eval_ds, steps=10, verbose=0)
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index ba147e7824..60e134055f 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -179,11 +179,6 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
def get_expected_variables(optimizer_fn, num_parameter_devices):
variables_map = {
"GradientDescent": ["dense/kernel", "dense/bias"],
- "Adam": [
- "dense/kernel", "dense/bias", "beta1_power", "beta2_power",
- "dense/kernel/Adam", "dense/kernel/Adam_1", "dense/bias/Adam",
- "dense/bias/Adam_1"
- ],
"Adagrad": [
"dense/kernel/Adagrad", "dense/kernel",
"dense/bias/Adagrad", "dense/bias"
diff --git a/tensorflow/contrib/distribute/python/moving_averages_test.py b/tensorflow/contrib/distribute/python/moving_averages_test.py
new file mode 100644
index 0000000000..119352ad91
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/moving_averages_test.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for training.moving_averages when using a DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.contrib.distribute.python import combinations
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.training import moving_averages
+
+
+all_combinations = combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu],
+ mode=["graph"])
+
+
+class AssignMovingAveragesTest(test.TestCase, parameterized.TestCase):
+
+ @combinations.generate(all_combinations)
+ def testTowerModeWithoutZeroDebias(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+ return var, assign
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(distribution.unwrap(assign))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testTowerMode(self, distribution):
+ tower_id = [0]
+
+ def tower_fn():
+ var = variables.Variable([0.0, 0.0])
+ val = constant_op.constant([1.0 + tower_id[0], 2.0 - tower_id[0]])
+ tower_id[0] += 1
+ decay = 0.25
+ assign = moving_averages.assign_moving_average(var, val, decay)
+ return var, assign.op
+
+ with distribution.scope(), self.cached_session() as sess:
+ var, assign_op = distribution.call_for_each_tower(tower_fn)
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(distribution.unwrap(assign_op))
+ # Mean of val across calls to tower_fn().
+ average_val = [1.0 + 0.5 * (tower_id[0] - 1),
+ 2.0 - 0.5 * (tower_id[0] - 1)]
+ self.assertAllClose(average_val, var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTowerWithoutZeroDebias(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([10.0, 11.0])
+ val = constant_op.constant([1.0, 2.0])
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(
+ var, val, decay, zero_debias=False)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([10.0, 11.0], var.eval())
+ sess.run(assign)
+ average_val = [1.0, 2.0]
+ val_weight = 1.0 - 0.25
+ self.assertAllClose(
+ [10.0 * 0.25 + average_val[0] * val_weight,
+ 11.0 * 0.25 + average_val[1] * val_weight],
+ var.eval())
+ # Also try assign.op.
+ sess.run(assign.op)
+ orig_weight = 0.25 * 0.25
+ val_weight = 1.0 - orig_weight
+ self.assertAllClose(
+ [10.0 * orig_weight + average_val[0] * val_weight,
+ 11.0 * orig_weight + average_val[1] * val_weight],
+ var.eval())
+
+ @combinations.generate(all_combinations)
+ def testCrossTower(self, distribution):
+ with distribution.scope(), self.cached_session() as sess:
+ var = variables.Variable([0.0, 0.0])
+ val = array_ops.placeholder(dtypes.float32)
+ decay = 0.25
+ # NOTE(josh11b): We currently generate an error if val is a PerDevice value.
+ assign = moving_averages.assign_moving_average(var, val, decay)
+
+ variables.global_variables_initializer().run()
+ self.assertAllClose([0.0, 0.0], var.eval())
+ sess.run(assign, feed_dict={val: [1.0, 2.0]})
+ self.assertAllClose([1.0, 2.0], var.eval())
+
+ # Also try assign.op.
+ sess.run(assign.op, feed_dict={val: [10.0, 0.0]})
+ self.assertAllClose(
+ [(1.0 * 0.25 + 10.0) / (1.0 * 0.25 + 1.0),
+ (2.0 * 0.25 + 0.0) / (1.0 * 0.25 + 1.0)],
+ var.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 353d11a583..9c112e4f85 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -262,7 +262,9 @@ class ParameterServerStrategyTestBase(
h = f + 1.0
self.assertEqual(
device_util.canonicalize(u.device), tower_variable_device)
- self.assertEqual(device_util.canonicalize(x.device), h.device)
+ self.assertEqual(
+ device_util.canonicalize(x.device),
+ device_util.canonicalize(h.device))
return y_add, z_add, f
y, z, f = d.call_for_each_tower(model_fn)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 18ceba42c2..0dd78ba185 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -571,6 +571,10 @@ class TPUMirroredVariable(checkpointable.CheckpointableBase):
ValueError("Device %s not found in %s (current device %s)" %
(device, self._index.keys(), device_util.current())), e)
+ @property
+ def device(self):
+ return self._get().device
+
# The arguments to update() are automatically unwrapped so the update()
# function would normally see regular variables, not MirroredVariables.
# However, the update function can still operate on wrapped MirroredVariables
diff --git a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
index 8fae622e12..446e340118 100644
--- a/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
+++ b/tensorflow/contrib/eager/python/examples/notebooks/automatic_differentiation.ipynb
@@ -65,7 +65,7 @@
"\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\n",
" \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
"\u003c/td\u003e\u003ctd\u003e\n",
- "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
+ "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/eager/automatic_differentiation.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e"
]
}
],
diff --git a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
index a1f1c5f3d7..b131ed4f12 100644
--- a/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
+++ b/tensorflow/contrib/estimator/python/estimator/boosted_trees.py
@@ -75,7 +75,7 @@ class _BoostedTreesEstimator(canned_boosted_trees._BoostedTreesBase): # pylint:
layer.
head: the `Head` instance defined for Estimator.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
@@ -199,7 +199,7 @@ def boosted_trees_classifier_train_in_memory(
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
n_classes: number of label classes. Default is binary classification.
Multiclass support is not yet implemented.
@@ -345,7 +345,7 @@ def boosted_trees_regressor_train_in_memory(
the model. All items in the set should be instances of classes derived
from `FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
label_dimension: Number of regression targets per example.
Multi-dimensional support is not yet implemented.
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
index 724bc2c82f..4e7965ef26 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_linear_combined.py
@@ -118,7 +118,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
head: A `_Head` instance constructed with a method such as
`tf.contrib.estimator.multi_label_head`.
model_dir: Directory to save model parameters, graph and etc. This can
- also be used to load checkpoints from the directory into a estimator
+ also be used to load checkpoints from the directory into an estimator
to continue training a previously saved model.
linear_feature_columns: An iterable containing all the feature columns
used by linear part of the model. All items in the set must be
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 6ca7aaf989..40a91175b7 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -248,7 +248,7 @@ def DNNClassifierWithLayerAnnotations( # pylint: disable=invalid-name
model. All items in the set should be instances of classes derived from
`_FeatureColumn`.
model_dir: Directory to save model parameters, graph and etc. This can also
- be used to load checkpoints from the directory into a estimator to
+ be used to load checkpoints from the directory into an estimator to
continue training a previously saved model.
n_classes: Number of label classes. Defaults to 2, namely binary
classification. Must be > 1.
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn.py b/tensorflow/contrib/estimator/python/estimator/rnn.py
index 98660bb731..c595f47395 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn.py
@@ -30,7 +30,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
@@ -92,55 +91,6 @@ def _make_rnn_cell_fn(num_units, cell_type='basic_rnn'):
return rnn_cell_fn
-def _concatenate_context_input(sequence_input, context_input):
- """Replicates `context_input` across all timesteps of `sequence_input`.
-
- Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
- This value is appended to `sequence_input` on dimension 2 and the result is
- returned.
-
- Args:
- sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
- padded_length, d0]`.
- context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
-
- Returns:
- A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
- d0 + d1]`.
-
- Raises:
- ValueError: If `sequence_input` does not have rank 3 or `context_input` does
- not have rank 2.
- """
- seq_rank_check = check_ops.assert_rank(
- sequence_input,
- 3,
- message='sequence_input must have rank 3',
- data=[array_ops.shape(sequence_input)])
- seq_type_check = check_ops.assert_type(
- sequence_input,
- dtypes.float32,
- message='sequence_input must have dtype float32; got {}.'.format(
- sequence_input.dtype))
- ctx_rank_check = check_ops.assert_rank(
- context_input,
- 2,
- message='context_input must have rank 2',
- data=[array_ops.shape(context_input)])
- ctx_type_check = check_ops.assert_type(
- context_input,
- dtypes.float32,
- message='context_input must have dtype float32; got {}.'.format(
- context_input.dtype))
- with ops.control_dependencies(
- [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
- padded_length = array_ops.shape(sequence_input)[1]
- tiled_context_input = array_ops.tile(
- array_ops.expand_dims(context_input, 1),
- array_ops.concat([[1], [padded_length], [1]], 0))
- return array_ops.concat([sequence_input, tiled_context_input], 2)
-
-
def _select_last_activations(activations, sequence_lengths):
"""Selects the nth set of activations for each n in `sequence_length`.
@@ -222,8 +172,8 @@ def _rnn_logit_fn_builder(output_units, rnn_cell_fn, sequence_feature_columns,
context_input = feature_column_lib.input_layer(
features=features,
feature_columns=context_feature_columns)
- sequence_input = _concatenate_context_input(sequence_input,
- context_input)
+ sequence_input = seq_fc.concatenate_context_input(
+ context_input, sequence_input)
cell = rnn_cell_fn(mode)
# Ignore output state.
diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD
index aab7d0c9e8..a926ffd598 100644
--- a/tensorflow/contrib/feature_column/BUILD
+++ b/tensorflow/contrib/feature_column/BUILD
@@ -27,6 +27,7 @@ py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:tensor_shape",
@@ -46,9 +47,29 @@ py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
"//tensorflow/python/feature_column",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "sequence_feature_column_integration_test",
+ srcs = ["python/feature_column/sequence_feature_column_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":sequence_feature_column",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python/feature_column",
+ "//tensorflow/python/keras:layers",
],
)
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
index 05bcdac2ca..dd6da35ed0 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
@@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
# pylint: disable=protected-access
-# TODO(b/73827486): Support SequenceExample.
def sequence_input_layer(
@@ -110,6 +109,7 @@ def sequence_input_layer(
output_tensors = []
sequence_lengths = []
ordered_columns = []
+
for column in sorted(feature_columns, key=lambda x: x.name):
ordered_columns.append(column)
with variable_scope.variable_scope(
@@ -121,17 +121,67 @@ def sequence_input_layer(
# Flattens the final dimension to produce a 3D Tensor.
num_elements = column._variable_shape.num_elements()
shape = array_ops.shape(dense_tensor)
+ target_shape = [shape[0], shape[1], num_elements]
output_tensors.append(
- array_ops.reshape(
- dense_tensor,
- shape=array_ops.concat([shape[:2], [num_elements]], axis=0)))
+ array_ops.reshape(dense_tensor, shape=target_shape))
sequence_lengths.append(sequence_length)
+
fc._verify_static_batch_size_equality(output_tensors, ordered_columns)
fc._verify_static_batch_size_equality(sequence_lengths, ordered_columns)
sequence_length = _assert_all_equal_and_return(sequence_lengths)
+
return array_ops.concat(output_tensors, -1), sequence_length
+def concatenate_context_input(context_input, sequence_input):
+ """Replicates `context_input` across all timesteps of `sequence_input`.
+
+ Expands dimension 1 of `context_input` then tiles it `sequence_length` times.
+ This value is appended to `sequence_input` on dimension 2 and the result is
+ returned.
+
+ Args:
+ context_input: A `Tensor` of dtype `float32` and shape `[batch_size, d1]`.
+ sequence_input: A `Tensor` of dtype `float32` and shape `[batch_size,
+ padded_length, d0]`.
+
+ Returns:
+ A `Tensor` of dtype `float32` and shape `[batch_size, padded_length,
+ d0 + d1]`.
+
+ Raises:
+ ValueError: If `sequence_input` does not have rank 3 or `context_input` does
+ not have rank 2.
+ """
+ seq_rank_check = check_ops.assert_rank(
+ sequence_input,
+ 3,
+ message='sequence_input must have rank 3',
+ data=[array_ops.shape(sequence_input)])
+ seq_type_check = check_ops.assert_type(
+ sequence_input,
+ dtypes.float32,
+ message='sequence_input must have dtype float32; got {}.'.format(
+ sequence_input.dtype))
+ ctx_rank_check = check_ops.assert_rank(
+ context_input,
+ 2,
+ message='context_input must have rank 2',
+ data=[array_ops.shape(context_input)])
+ ctx_type_check = check_ops.assert_type(
+ context_input,
+ dtypes.float32,
+ message='context_input must have dtype float32; got {}.'.format(
+ context_input.dtype))
+ with ops.control_dependencies(
+ [seq_rank_check, seq_type_check, ctx_rank_check, ctx_type_check]):
+ padded_length = array_ops.shape(sequence_input)[1]
+ tiled_context_input = array_ops.tile(
+ array_ops.expand_dims(context_input, 1),
+ array_ops.concat([[1], [padded_length], [1]], 0))
+ return array_ops.concat([sequence_input, tiled_context_input], 2)
+
+
def sequence_categorical_column_with_identity(
key, num_buckets, default_value=None):
"""Returns a feature column that represents sequences of integers.
@@ -453,9 +503,17 @@ class _SequenceNumericColumn(
[array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape],
axis=0)
dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape)
- sequence_length = fc._sequence_length_from_sparse_tensor(
- sp_tensor, num_elements=self._variable_shape.num_elements())
+
+ # Get the number of timesteps per example
+ # For the 2D case, the raw values are grouped according to num_elements;
+ # for the 3D case, the grouping happens in the third dimension, and
+ # sequence length is not affected.
+ num_elements = (self._variable_shape.num_elements()
+ if sp_tensor.shape.ndims == 2 else 1)
+ seq_length = fc._sequence_length_from_sparse_tensor(
+ sp_tensor, num_elements=num_elements)
+
return fc._SequenceDenseColumn.TensorSequenceLengthPair(
- dense_tensor=dense_tensor, sequence_length=sequence_length)
+ dense_tensor=dense_tensor, sequence_length=seq_length)
# pylint: enable=protected-access
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py
new file mode 100644
index 0000000000..d8ca363627
--- /dev/null
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_integration_test.py
@@ -0,0 +1,280 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for sequence feature columns with SequenceExamples."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import string
+import tempfile
+
+from google.protobuf import text_format
+
+from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.feature_column import feature_column as fc
+from tensorflow.python.keras.layers import recurrent
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class SequenceFeatureColumnIntegrationTest(test.TestCase):
+
+ def _make_sequence_example(self):
+ example = example_pb2.SequenceExample()
+ example.context.feature['int_ctx'].int64_list.value.extend([5])
+ example.context.feature['float_ctx'].float_list.value.extend([123.6])
+ for val in range(0, 10, 2):
+ feat = feature_pb2.Feature()
+ feat.int64_list.value.extend([val] * val)
+ example.feature_lists.feature_list['int_list'].feature.extend([feat])
+ for val in range(1, 11, 2):
+ feat = feature_pb2.Feature()
+ feat.bytes_list.value.extend([compat.as_bytes(str(val))] * val)
+ example.feature_lists.feature_list['str_list'].feature.extend([feat])
+
+ return example
+
+ def _build_feature_columns(self):
+ col = fc.categorical_column_with_identity(
+ 'int_ctx', num_buckets=100)
+ ctx_cols = [
+ fc.embedding_column(col, dimension=10),
+ fc.numeric_column('float_ctx')]
+
+ identity_col = sfc.sequence_categorical_column_with_identity(
+ 'int_list', num_buckets=10)
+ bucket_col = sfc.sequence_categorical_column_with_hash_bucket(
+ 'bytes_list', hash_bucket_size=100)
+ seq_cols = [
+ fc.embedding_column(identity_col, dimension=10),
+ fc.embedding_column(bucket_col, dimension=20)]
+
+ return ctx_cols, seq_cols
+
+ def test_sequence_example_into_input_layer(self):
+ examples = [_make_sequence_example().SerializeToString()] * 100
+ ctx_cols, seq_cols = self._build_feature_columns()
+
+ def _parse_example(example):
+ ctx, seq = parsing_ops.parse_single_sequence_example(
+ example,
+ context_features=fc.make_parse_example_spec(ctx_cols),
+ sequence_features=fc.make_parse_example_spec(seq_cols))
+ ctx.update(seq)
+ return ctx
+
+ ds = dataset_ops.Dataset.from_tensor_slices(examples)
+ ds = ds.map(_parse_example)
+ ds = ds.batch(20)
+
+ # Test on a single batch
+ features = ds.make_one_shot_iterator().get_next()
+
+ # Tile the context features across the sequence features
+ seq_layer, _ = sfc.sequence_input_layer(features, seq_cols)
+ ctx_layer = fc.input_layer(features, ctx_cols)
+ input_layer = sfc.concatenate_context_input(ctx_layer, seq_layer)
+
+ rnn_layer = recurrent.RNN(recurrent.SimpleRNNCell(10))
+ output = rnn_layer(input_layer)
+
+ with self.cached_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ features_r = sess.run(features)
+ self.assertAllEqual(features_r['int_list'].dense_shape, [20, 3, 6])
+
+ output_r = sess.run(output)
+ self.assertAllEqual(output_r.shape, [20, 10])
+
+
+class SequenceExampleParsingTest(test.TestCase):
+
+ def test_seq_ex_in_sequence_categorical_column_with_identity(self):
+ self._test_parsed_sequence_example(
+ 'int_list', sfc.sequence_categorical_column_with_identity,
+ 10, [3, 6], [2, 4, 6])
+
+ def test_seq_ex_in_sequence_categorical_column_with_hash_bucket(self):
+ self._test_parsed_sequence_example(
+ 'bytes_list', sfc.sequence_categorical_column_with_hash_bucket,
+ 10, [3, 4], [compat.as_bytes(x) for x in 'acg'])
+
+ def test_seq_ex_in_sequence_categorical_column_with_vocabulary_list(self):
+ self._test_parsed_sequence_example(
+ 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_list,
+ list(string.ascii_lowercase), [3, 4],
+ [compat.as_bytes(x) for x in 'acg'])
+
+ def test_seq_ex_in_sequence_categorical_column_with_vocabulary_file(self):
+ _, fname = tempfile.mkstemp()
+ with open(fname, 'w') as f:
+ f.write(string.ascii_lowercase)
+ self._test_parsed_sequence_example(
+ 'bytes_list', sfc.sequence_categorical_column_with_vocabulary_file,
+ fname, [3, 4], [compat.as_bytes(x) for x in 'acg'])
+
+ def _test_parsed_sequence_example(
+ self, col_name, col_fn, col_arg, shape, values):
+ """Helper function to check that each FeatureColumn parses correctly.
+
+ Args:
+ col_name: string, name to give to the feature column. Should match
+ the name that the column will parse out of the features dict.
+ col_fn: function used to create the feature column. For example,
+ sequence_numeric_column.
+ col_arg: second arg that the target feature column is expecting.
+ shape: the expected dense_shape of the feature after parsing into
+ a SparseTensor.
+ values: the expected values at index [0, 2, 6] of the feature
+ after parsing into a SparseTensor.
+ """
+ example = _make_sequence_example()
+ columns = [
+ fc.categorical_column_with_identity('int_ctx', num_buckets=100),
+ fc.numeric_column('float_ctx'),
+ col_fn(col_name, col_arg)
+ ]
+ context, seq_features = parsing_ops.parse_single_sequence_example(
+ example.SerializeToString(),
+ context_features=fc.make_parse_example_spec(columns[:2]),
+ sequence_features=fc.make_parse_example_spec(columns[2:]))
+
+ with self.cached_session() as sess:
+ ctx_result, seq_result = sess.run([context, seq_features])
+ self.assertEqual(list(seq_result[col_name].dense_shape), shape)
+ self.assertEqual(
+ list(seq_result[col_name].values[[0, 2, 6]]), values)
+ self.assertEqual(list(ctx_result['int_ctx'].dense_shape), [1])
+ self.assertEqual(ctx_result['int_ctx'].values[0], 5)
+ self.assertEqual(list(ctx_result['float_ctx'].shape), [1])
+ self.assertAlmostEqual(ctx_result['float_ctx'][0], 123.6, places=1)
+
+
+_SEQ_EX_PROTO = """
+context {
+ feature {
+ key: "float_ctx"
+ value {
+ float_list {
+ value: 123.6
+ }
+ }
+ }
+ feature {
+ key: "int_ctx"
+ value {
+ int64_list {
+ value: 5
+ }
+ }
+ }
+}
+feature_lists {
+ feature_list {
+ key: "bytes_list"
+ value {
+ feature {
+ bytes_list {
+ value: "a"
+ }
+ }
+ feature {
+ bytes_list {
+ value: "b"
+ value: "c"
+ }
+ }
+ feature {
+ bytes_list {
+ value: "d"
+ value: "e"
+ value: "f"
+ value: "g"
+ }
+ }
+ }
+ }
+ feature_list {
+ key: "float_list"
+ value {
+ feature {
+ float_list {
+ value: 1.0
+ }
+ }
+ feature {
+ float_list {
+ value: 3.0
+ value: 3.0
+ value: 3.0
+ }
+ }
+ feature {
+ float_list {
+ value: 5.0
+ value: 5.0
+ value: 5.0
+ value: 5.0
+ value: 5.0
+ }
+ }
+ }
+ }
+ feature_list {
+ key: "int_list"
+ value {
+ feature {
+ int64_list {
+ value: 2
+ value: 2
+ }
+ }
+ feature {
+ int64_list {
+ value: 4
+ value: 4
+ value: 4
+ value: 4
+ }
+ }
+ feature {
+ int64_list {
+ value: 6
+ value: 6
+ value: 6
+ value: 6
+ value: 6
+ value: 6
+ }
+ }
+ }
+ }
+}
+"""
+
+
+def _make_sequence_example():
+ example = example_pb2.SequenceExample()
+ return text_format.Parse(_SEQ_EX_PROTO, example)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
index 45d7b74046..929e83523a 100644
--- a/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
+++ b/tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as sfc
@@ -28,28 +29,61 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
from tensorflow.python.training import monitored_session
-class SequenceInputLayerTest(test.TestCase):
+class SequenceInputLayerTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [2, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2)),
+ 'expected_input_layer': [
+ # example 0, ids_a [2], ids_b [1]
+ [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [2, 0]
+ [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],],
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[2], [0, 1]]
+ # feature 1, ids [[0, 0], [1]]
+ indices=(
+ (0, 0, 0), (0, 1, 0), (0, 1, 1),
+ (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 0, 0, 1),
+ dense_shape=(2, 2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[1, 1], [1]]
+ # feature 1, ids [[2], [0]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(1, 1, 1, 2, 0),
+ dense_shape=(2, 2, 2)),
+ 'expected_input_layer': [
+ # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
+ [[5., 6., 14., 15., 16.], [2., 3., 14., 15., 16.]],
+ # feature 1, [a: 0, 0, b: 2, -], [a: 1, -, b: 0, -]
+ [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]]],
+ 'expected_sequence_length': [2, 2]},
+ )
+ def test_embedding_column(
+ self, sparse_input_a, sparse_input_b, expected_input_layer,
+ expected_sequence_length):
- def test_embedding_column(self):
vocabulary_size = 3
- sparse_input_a = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
- sparse_input_b = sparse_tensor.SparseTensorValue(
- # example 0, ids [1]
- # example 1, ids [2, 0]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 2, 0),
- dense_shape=(2, 2))
-
embedding_dimension_a = 2
embedding_values_a = (
(1., 2.), # id 0
@@ -70,14 +104,6 @@ class SequenceInputLayerTest(test.TestCase):
return embedding_values
return _initializer
- expected_input_layer = [
- # example 0, ids_a [2], ids_b [1]
- [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]],
- # example 1, ids_a [0, 1], ids_b [2, 0]
- [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]],
- ]
- expected_sequence_length = [1, 2]
-
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column_a = fc.embedding_column(
@@ -233,29 +259,53 @@ class SequenceInputLayerTest(test.TestCase):
},
feature_columns=shared_embedding_columns)
- def test_indicator_column(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # example 0, ids [1]
+ # example 1, ids [1, 0]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 1, 0),
+ dense_shape=(2, 2)),
+ 'expected_input_layer': [
+ # example 0, ids_a [2], ids_b [1]
+ [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]],
+ # example 1, ids_a [0, 1], ids_b [1, 0]
+ [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'sparse_input_a': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[2], [0, 1]]
+ # feature 1, ids [[0, 0], [1]]
+ indices=(
+ (0, 0, 0), (0, 1, 0), (0, 1, 1),
+ (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 0, 0, 1),
+ dense_shape=(2, 2, 2)),
+ 'sparse_input_b': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[1, 1], [1]]
+ # feature 1, ids [[1], [0]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(1, 1, 1, 1, 0),
+ dense_shape=(2, 2, 2)),
+ 'expected_input_layer': [
+ # feature 0, [a: 2, -, b: 1, 1], [a: 0, 1, b: 1, -]
+ [[0., 0., 1., 0., 2.], [1., 1., 0., 0., 1.]],
+ # feature 1, [a: 0, 0, b: 1, -], [a: 1, -, b: 0, -]
+ [[2., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]]],
+ 'expected_sequence_length': [2, 2]},
+ )
+ def test_indicator_column(
+ self, sparse_input_a, sparse_input_b, expected_input_layer,
+ expected_sequence_length):
vocabulary_size_a = 3
- sparse_input_a = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
vocabulary_size_b = 2
- sparse_input_b = sparse_tensor.SparseTensorValue(
- # example 0, ids [1]
- # example 1, ids [1, 0]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 1, 0),
- dense_shape=(2, 2))
-
- expected_input_layer = [
- # example 0, ids_a [2], ids_b [1]
- [[0., 0., 1., 0., 1.], [0., 0., 0., 0., 0.]],
- # example 1, ids_a [0, 1], ids_b [1, 0]
- [[1., 0., 0., 0., 1.], [0., 1., 0., 1., 0.]],
- ]
- expected_sequence_length = [1, 2]
categorical_column_a = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size_a)
@@ -298,18 +348,32 @@ class SequenceInputLayerTest(test.TestCase):
features={'aaa': sparse_input},
feature_columns=[indicator_column_a])
- def test_numeric_column(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0.], [1]]
- # example 1, [[10.]]
- indices=((0, 0), (0, 1), (1, 0)),
- values=(0., 1., 10.),
- dense_shape=(2, 2))
- expected_input_layer = [
- [[0.], [1.]],
- [[10.], [0.]],
- ]
- expected_sequence_length = [2, 1]
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [0., 1]
+ # example 1, [10.]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2)),
+ 'expected_input_layer': [
+ [[0.], [1.]],
+ [[10.], [0.]]],
+ 'expected_sequence_length': [2, 1]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[20, 3], [5]]
+ # feature 1, ids [[3], [8]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(20, 3, 5., 3., 8.),
+ dense_shape=(2, 2, 2)),
+ 'expected_input_layer': [
+ [[20.], [3.], [5.], [0.]],
+ [[3.], [0.], [8.], [0.]]],
+ 'expected_sequence_length': [2, 2]},
+ )
+ def test_numeric_column(
+ self, sparse_input, expected_input_layer, expected_sequence_length):
numeric_column = sfc.sequence_numeric_column('aaa')
input_layer, sequence_length = sfc.sequence_input_layer(
@@ -321,21 +385,38 @@ class SequenceInputLayerTest(test.TestCase):
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
- def test_numeric_column_multi_dim(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [0., 1., 2., 3., 4., 5., 6., 7.]
+ # example 1, [10., 11., 12., 13.]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
+ (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8)),
+ 'expected_input_layer': [
+ # The output of numeric_column._get_dense_tensor should be flattened.
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]],
+ [[10., 11., 12., 13.], [0., 0., 0., 0.]]],
+ 'expected_sequence_length': [2, 1]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
+ # example 1, [[10., 11., 12., 13.], []]
+ indices=((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
+ (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3),
+ (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 2, 4)),
+ 'expected_input_layer': [
+ # The output of numeric_column._get_dense_tensor should be flattened.
+ [[0., 1., 2., 3.], [4., 5., 6., 7.]],
+ [[10., 11., 12., 13.], [0., 0., 0., 0.]]],
+ 'expected_sequence_length': [2, 1]},
+ )
+ def test_numeric_column_multi_dim(
+ self, sparse_input, expected_input_layer, expected_sequence_length):
"""Tests sequence_input_layer for multi-dimensional numeric_column."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
- # example 1, [[[10., 11.], [12., 13.]]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
- (1, 0), (1, 1), (1, 2), (1, 3)),
- values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
- dense_shape=(2, 8))
- # The output of numeric_column._get_dense_tensor should be flattened.
- expected_input_layer = [
- [[0., 1., 2., 3.], [4., 5., 6., 7.]],
- [[10., 11., 12., 13.], [0., 0., 0., 0.]],
- ]
- expected_sequence_length = [2, 1]
numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
input_layer, sequence_length = sfc.sequence_input_layer(
@@ -377,6 +458,134 @@ class SequenceInputLayerTest(test.TestCase):
r'\[y \(sequence_input_layer/bbb/sequence_length:0\) = \] \[1 1\]'):
sess.run(sequence_length)
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
+ # example 1, [[[10., 11.], [12., 13.]]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
+ (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8)),
+ 'expected_shape': [2, 2, 4]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[0., 1., 2., 3.]], [[4., 5., 6., 7.]]
+ # example 1, [[10., 11., 12., 13.], []]
+ indices=((0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 0, 3),
+ (0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 2),
+ (1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 0, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 2, 4)),
+ 'expected_shape': [2, 2, 4]},
+ )
+ def test_static_shape_from_tensors_numeric(
+ self, sparse_input, expected_shape):
+ """Tests that we return a known static shape when we have one."""
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
+
+ input_layer, _ = sfc.sequence_input_layer(
+ features={'aaa': sparse_input},
+ feature_columns=[numeric_column])
+ shape = input_layer.get_shape()
+ self.assertEqual(shape, expected_shape)
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2)),
+ 'expected_shape': [4, 2, 3]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ # example 2, ids []
+ # example 3, ids [[1], [0, 2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
+ (3, 0, 0), (3, 1, 0), (3, 1, 1)),
+ values=(2, 0, 1, 2, 1, 0, 2),
+ dense_shape=(4, 2, 2)),
+ 'expected_shape': [4, 2, 3]}
+ )
+ def test_static_shape_from_tensors_indicator(
+ self, sparse_input, expected_shape):
+ """Tests that we return a known static shape when we have one."""
+ categorical_column = sfc.sequence_categorical_column_with_identity(
+ key='aaa', num_buckets=3)
+ indicator_column = fc.indicator_column(categorical_column)
+
+ input_layer, _ = sfc.sequence_input_layer(
+ features={'aaa': sparse_input}, feature_columns=[indicator_column])
+ shape = input_layer.get_shape()
+ self.assertEqual(shape, expected_shape)
+
+
+class ConcatenateContextInputTest(test.TestCase, parameterized.TestCase):
+ """Tests the utility fn concatenate_context_input."""
+
+ def test_concatenate_context_input(self):
+ seq_input = ops.convert_to_tensor(np.arange(12).reshape(2, 3, 2))
+ context_input = ops.convert_to_tensor(np.arange(10).reshape(2, 5))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ input_layer = sfc.concatenate_context_input(context_input, seq_input)
+
+ expected = np.array([
+ [[0, 1, 0, 1, 2, 3, 4], [2, 3, 0, 1, 2, 3, 4], [4, 5, 0, 1, 2, 3, 4]],
+ [[6, 7, 5, 6, 7, 8, 9], [8, 9, 5, 6, 7, 8, 9], [10, 11, 5, 6, 7, 8, 9]]
+ ], dtype=np.float32)
+ with monitored_session.MonitoredSession() as sess:
+ output = sess.run(input_layer)
+ self.assertAllEqual(expected, output)
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'rank_lt_3',
+ 'seq_input': ops.convert_to_tensor(np.arange(100).reshape(10, 10))},
+ {'testcase_name': 'rank_gt_3',
+ 'seq_input': ops.convert_to_tensor(np.arange(100).reshape(5, 5, 2, 2))}
+ )
+ def test_sequence_input_throws_error(self, seq_input):
+ context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(ValueError, 'sequence_input must have rank 3'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
+ @parameterized.named_parameters(
+ {'testcase_name': 'rank_lt_2',
+ 'context_input': ops.convert_to_tensor(np.arange(100))},
+ {'testcase_name': 'rank_gt_2',
+ 'context_input': ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))}
+ )
+ def test_context_input_throws_error(self, context_input):
+ seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(ValueError, 'context_input must have rank 2'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
+ def test_integer_seq_input_throws_error(self):
+ seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
+ context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
+ context_input = math_ops.cast(context_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(
+ TypeError, 'sequence_input must have dtype float32'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
+ def test_integer_context_input_throws_error(self):
+ seq_input = ops.convert_to_tensor(np.arange(100).reshape(5, 5, 4))
+ context_input = ops.convert_to_tensor(np.arange(100).reshape(10, 10))
+ seq_input = math_ops.cast(seq_input, dtype=dtypes.float32)
+ with self.assertRaisesRegexp(
+ TypeError, 'context_input must have dtype float32'):
+ sfc.concatenate_context_input(context_input, seq_input)
+
class InputLayerTest(test.TestCase):
"""Tests input_layer with sequence feature columns."""
@@ -443,75 +652,79 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual):
test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
-class SequenceCategoricalColumnWithIdentityTest(test.TestCase):
-
- def test_get_sparse_tensors(self):
- column = sfc.sequence_categorical_column_with_identity(
- 'aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(1, 2, 0),
- dense_shape=(2, 2))
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=np.array((1, 2, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
+class SequenceCategoricalColumnWithIdentityTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(1, 2, 0),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((1, 2, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=(6, 7, 8),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=(6, 7, 8),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
+ column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9)
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_value(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
-
- def test_get_sparse_tensors_inputs3d(self):
- """Tests _get_sparse_tensors when the input is already 3D Tensor."""
- column = sfc.sequence_categorical_column_with_identity(
- 'aaa', num_buckets=3)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=(1, 2, 0),
- dense_shape=(2, 2, 1))
-
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r'Column aaa expected ID tensor of rank 2\.\s*'
- r'id_tensor shape:\s*\[2 2 1\]'):
- id_weight_pair = column._get_sparse_tensors(
- _LazyBuilder({'aaa': inputs}))
- with monitored_session.MonitoredSession() as sess:
- id_weight_pair.id_tensor.eval(session=sess)
-
-
-class SequenceCategoricalColumnWithHashBucketTest(test.TestCase):
-
- def test_get_sparse_tensors(self):
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithHashBucketTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ # Ignored to avoid hash dependence in test.
+ values=np.array((0, 0, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=('omar', 'stringer', 'marlo'),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ # Ignored to avoid hash dependence in test.
+ values=np.array((0, 0, 0), dtype=np.int64),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
column = sfc.sequence_categorical_column_with_hash_bucket(
'aaa', hash_bucket_size=10)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('omar', 'stringer', 'marlo'),
- dense_shape=(2, 2))
-
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- # Ignored to avoid hash dependence in test.
- values=np.array((0, 0, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_indices_shape(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
-class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
+class SequenceCategoricalColumnWithVocabularyFileTest(
+ test.TestCase, parameterized.TestCase):
def _write_vocab(self, vocab_strings, file_name):
vocab_file = os.path.join(self.get_temp_dir(), file_name)
@@ -527,68 +740,120 @@ class SequenceCategoricalColumnWithVocabularyFileTest(test.TestCase):
'wire_vocabulary.txt')
self._wire_vocabulary_size = 3
- def test_get_sparse_tensors(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=('omar', 'skywalker', 'marlo'),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=np.array((0, -1, 2), dtype=np.int64),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
column = sfc.sequence_categorical_column_with_vocabulary_file(
key='aaa',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size)
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=np.array((2, -1, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_value(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
-
-
-class SequenceCategoricalColumnWithVocabularyListTest(test.TestCase):
-
- def test_get_sparse_tensors(self):
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
+
+
+class SequenceCategoricalColumnWithVocabularyListTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=('marlo', 'skywalker', 'omar'),
+ dense_shape=(2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
+ values=np.array((2, -1, 0), dtype=np.int64),
+ dense_shape=(2, 2, 1))},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=('omar', 'skywalker', 'marlo'),
+ dense_shape=(2, 2, 2)),
+ 'expected': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 2), (1, 0, 0), (1, 2, 0)),
+ values=np.array((0, -1, 2), dtype=np.int64),
+ dense_shape=(2, 2, 2))}
+ )
+ def test_get_sparse_tensors(self, inputs, expected):
column = sfc.sequence_categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'))
- inputs = sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- expected_sparse_ids = sparse_tensor.SparseTensorValue(
- indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)),
- values=np.array((2, -1, 0), dtype=np.int64),
- dense_shape=(2, 2, 1))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
_assert_sparse_tensor_value(
- self,
- expected_sparse_ids,
- id_weight_pair.id_tensor.eval(session=sess))
-
-
-class SequenceEmbeddingColumnTest(test.TestCase):
-
- def test_get_sequence_dense_tensor(self):
+ self, expected, id_weight_pair.id_tensor.eval(session=sess))
+
+
+class SequenceEmbeddingColumnTest(
+ test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2)),
+ 'expected': [
+ # example 0, ids [2]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 2.], [3., 5.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [1]
+ [[3., 5.], [0., 0.]]]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ # example 2, ids []
+ # example 3, ids [[1], [0, 2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
+ (3, 0, 0), (3, 1, 0), (3, 1, 1)),
+ values=(2, 0, 1, 2, 1, 0, 2),
+ dense_shape=(4, 2, 2)),
+ 'expected': [
+ # example 0, ids [[2]]
+ [[7., 11.], [0., 0.]],
+ # example 1, ids [[0, 1], [2]]
+ [[2, 3.5], [7., 11.]],
+ # example 2, ids []
+ [[0., 0.], [0., 0.]],
+ # example 3, ids [[1], [0, 2]]
+ [[3., 5.], [4., 6.5]]]}
+ )
+ def test_get_sequence_dense_tensor(self, inputs, expected):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 1), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 2))
-
embedding_dimension = 2
embedding_values = (
(1., 2.), # id 0
@@ -601,17 +866,6 @@ class SequenceEmbeddingColumnTest(test.TestCase):
self.assertIsNone(partition_info)
return embedding_values
- expected_lookups = [
- # example 0, ids [2]
- [[7., 11.], [0., 0.]],
- # example 1, ids [0, 1]
- [[1., 2.], [3., 5.]],
- # example 2, ids []
- [[0., 0.], [0., 0.]],
- # example 3, ids [1]
- [[3., 5.], [0., 0.]],
- ]
-
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column = fc.embedding_column(
@@ -619,24 +873,35 @@ class SequenceEmbeddingColumnTest(test.TestCase):
initializer=_initializer)
embedding_lookup, _ = embedding_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
('embedding_weights:0',), tuple([v.name for v in global_vars]))
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
- self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess))
-
- def test_sequence_length(self):
+ self.assertAllEqual(expected, embedding_lookup.eval(session=sess))
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 2),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2]}
+ )
+ def test_sequence_length(self, inputs, expected_sequence_length):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
@@ -644,7 +909,7 @@ class SequenceEmbeddingColumnTest(test.TestCase):
categorical_column, dimension=2)
_, sequence_length = embedding_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
@@ -855,56 +1120,87 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase):
expected_sequence_length_b, sequence_length_b.eval(session=sess))
-class SequenceIndicatorColumnTest(test.TestCase):
-
- def test_get_sequence_dense_tensor(self):
+class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ # example 2, ids []
+ # example 3, ids [1]
+ indices=((0, 0), (1, 0), (1, 1), (3, 0)),
+ values=(2, 0, 1, 1),
+ dense_shape=(4, 2)),
+ 'expected': [
+ # example 0, ids [2]
+ [[0., 0., 1.], [0., 0., 0.]],
+ # example 1, ids [0, 1]
+ [[1., 0., 0.], [0., 1., 0.]],
+ # example 2, ids []
+ [[0., 0., 0.], [0., 0., 0.]],
+ # example 3, ids [1]
+ [[0., 1., 0.], [0., 0., 0.]]]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ # example 2, ids []
+ # example 3, ids [[1], [2, 2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0),
+ (3, 0, 0), (3, 1, 0), (3, 1, 1)),
+ values=(2, 0, 1, 2, 1, 2, 2),
+ dense_shape=(4, 2, 2)),
+ 'expected': [
+ # example 0, ids [[2]]
+ [[0., 0., 1.], [0., 0., 0.]],
+ # example 1, ids [[0, 1], [2]]
+ [[1., 1., 0.], [0., 0., 1.]],
+ # example 2, ids []
+ [[0., 0., 0.], [0., 0., 0.]],
+ # example 3, ids [[1], [2, 2]]
+ [[0., 1., 0.], [0., 0., 2.]]]}
+ )
+ def test_get_sequence_dense_tensor(self, inputs, expected):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 1), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(4, 2))
-
- expected_lookups = [
- # example 0, ids [2]
- [[0., 0., 1.], [0., 0., 0.]],
- # example 1, ids [0, 1]
- [[1., 0., 0.], [0., 1., 0.]],
- # example 2, ids []
- [[0., 0., 0.], [0., 0., 0.]],
- # example 3, ids [1]
- [[0., 1., 0.], [0., 0., 0.]],
- ]
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc.indicator_column(categorical_column)
indicator_tensor, _ = indicator_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(expected_lookups, indicator_tensor.eval(session=sess))
-
- def test_sequence_length(self):
+ self.assertAllEqual(expected, indicator_tensor.eval(session=sess))
+
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2, 0, 1),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 2]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2, 0, 1, 2),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2]}
+ )
+ def test_sequence_length(self, inputs, expected_sequence_length):
vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- indices=((0, 0), (1, 0), (1, 1)),
- values=(2, 0, 1),
- dense_shape=(2, 2))
- expected_sequence_length = [1, 2]
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc.indicator_column(categorical_column)
_, sequence_length = indicator_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
@@ -938,7 +1234,7 @@ class SequenceIndicatorColumnTest(test.TestCase):
expected_sequence_length, sequence_length.eval(session=sess))
-class SequenceNumericColumnTest(test.TestCase):
+class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase):
def test_defaults(self):
a = sfc.sequence_numeric_column('aaa')
@@ -971,25 +1267,36 @@ class SequenceNumericColumnTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, 'must be a callable'):
sfc.sequence_numeric_column('aaa', normalizer_fn='NotACallable')
- def test_get_sequence_dense_tensor(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0.], [1]]
- # example 1, [[10.]]
- indices=((0, 0), (0, 1), (1, 0)),
- values=(0., 1., 10.),
- dense_shape=(2, 2))
- expected_dense_tensor = [
- [[0.], [1.]],
- [[10.], [0.]],
- ]
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, values [0., 1]
+ # example 1, [10.]
+ indices=((0, 0), (0, 1), (1, 0)),
+ values=(0., 1., 10.),
+ dense_shape=(2, 2)),
+ 'expected': [
+ [[0.], [1.]],
+ [[10.], [0.]]]},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # feature 0, ids [[20, 3], [5]]
+ # feature 1, ids [[3], [8]]
+ indices=((0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)),
+ values=(20, 3, 5., 3., 8.),
+ dense_shape=(2, 2, 2)),
+ 'expected': [
+ [[20.], [3.], [5.], [0.]],
+ [[3.], [0.], [8.], [0.]]]},
+ )
+ def test_get_sequence_dense_tensor(self, inputs, expected):
numeric_column = sfc.sequence_numeric_column('aaa')
dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_dense_tensor, dense_tensor.eval(session=sess))
+ self.assertAllEqual(expected, dense_tensor.eval(session=sess))
def test_get_sequence_dense_tensor_with_normalizer_fn(self):
@@ -1026,41 +1333,34 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
- def test_get_sequence_dense_tensor_with_shape(self):
- """Tests get_sequence_dense_tensor with shape !=(1,)."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0., 1., 2.], [3., 4., 5.]]
- # example 1, [[10., 11., 12.]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
- (1, 0), (1, 1), (1, 2)),
- values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
- dense_shape=(2, 6))
- expected_dense_tensor = [
- [[0., 1., 2.], [3., 4., 5.]],
- [[10., 11., 12.], [0., 0., 0.]],
- ]
- numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
-
- dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_dense_tensor, dense_tensor.eval(session=sess))
-
- def test_get_dense_tensor_multi_dim(self):
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
+ # example 1, [[[10., 11.], [12., 13.]]]
+ indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6),
+ (0, 7), (1, 0), (1, 1), (1, 2), (1, 3)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 8)),
+ 'expected_dense_tensor': [
+ [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]],
+ [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]]]},
+ {'testcase_name': '3D',
+ 'sparse_input': sparse_tensor.SparseTensorValue(
+ indices=((0, 0, 0), (0, 0, 2), (0, 0, 4), (0, 0, 6),
+ (0, 1, 0), (0, 1, 2), (0, 1, 4), (0, 1, 6),
+ (1, 0, 0), (1, 0, 2), (1, 0, 4), (1, 0, 6)),
+ values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
+ dense_shape=(2, 2, 8)),
+ 'expected_dense_tensor': [
+ [[[0., 0.], [1., 0.]], [[2., 0.], [3., 0.]],
+ [[4., 0.], [5., 0.]], [[6., 0.], [7., 0.]]],
+ [[[10., 0.], [11., 0.]], [[12., 0.], [13., 0.]],
+ [[0., 0.], [0., 0.]], [[0., 0.], [0., 0.]]]]},
+ )
+ def test_get_dense_tensor_multi_dim(
+ self, sparse_input, expected_dense_tensor):
"""Tests get_sequence_dense_tensor for multi-dim numeric_column."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]]
- # example 1, [[[10., 11.], [12., 13.]]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7),
- (1, 0), (1, 1), (1, 2), (1, 3)),
- values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.),
- dense_shape=(2, 8))
- expected_dense_tensor = [
- [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]],
- [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]],
- ]
numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2))
dense_tensor, _ = numeric_column._get_sequence_dense_tensor(
@@ -1070,43 +1370,55 @@ class SequenceNumericColumnTest(test.TestCase):
self.assertAllEqual(
expected_dense_tensor, dense_tensor.eval(session=sess))
- def test_sequence_length(self):
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0., 1., 2.], [3., 4., 5.]]
- # example 1, [[10., 11., 12.]]
- indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5),
- (1, 0), (1, 1), (1, 2)),
- values=(0., 1., 2., 3., 4., 5., 10., 11., 12.),
- dense_shape=(2, 6))
- expected_sequence_length = [2, 1]
- numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,))
+ @parameterized.named_parameters(
+ {'testcase_name': '2D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2., 0., 1.),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 2],
+ 'shape': (1,)},
+ {'testcase_name': '3D',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2., 0., 1., 2.),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2],
+ 'shape': (1,)},
+ {'testcase_name': '2D_with_shape',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [2]
+ # example 1, ids [0, 1]
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(2., 0., 1.),
+ dense_shape=(2, 2)),
+ 'expected_sequence_length': [1, 1],
+ 'shape': (2,)},
+ {'testcase_name': '3D_with_shape',
+ 'inputs': sparse_tensor.SparseTensorValue(
+ # example 0, ids [[2]]
+ # example 1, ids [[0, 1], [2]]
+ indices=((0, 0, 0), (1, 0, 0), (1, 0, 1), (1, 1, 0)),
+ values=(2., 0., 1., 2.),
+ dense_shape=(2, 2, 2)),
+ 'expected_sequence_length': [1, 2],
+ 'shape': (2,)},
+ )
+ def test_sequence_length(self, inputs, expected_sequence_length, shape):
+ numeric_column = sfc.sequence_numeric_column('aaa', shape=shape)
_, sequence_length = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
+ _LazyBuilder({'aaa': inputs}))
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
self.assertAllEqual(expected_sequence_length, sequence_length)
self.assertEqual(np.int64, sequence_length.dtype)
- def test_sequence_length_with_shape(self):
- """Tests _sequence_length with shape !=(1,)."""
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, values [[0.], [1]]
- # example 1, [[10.]]
- indices=((0, 0), (0, 1), (1, 0)),
- values=(0., 1., 10.),
- dense_shape=(2, 2))
- expected_sequence_length = [2, 1]
- numeric_column = sfc.sequence_numeric_column('aaa')
-
- _, sequence_length = numeric_column._get_sequence_dense_tensor(
- _LazyBuilder({'aaa': sparse_input}))
-
- with monitored_session.MonitoredSession() as sess:
- self.assertAllEqual(
- expected_sequence_length, sequence_length.eval(session=sess))
-
def test_sequence_length_with_empty_rows(self):
"""Tests _sequence_length when some examples do not have ids."""
sparse_input = sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc
index bb06f1c41c..3549cedb70 100644
--- a/tensorflow/contrib/gdr/gdr_memory_manager.cc
+++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <fstream>
#include <list>
#include <map>
-#include <set>
#include <fcntl.h>
#include <rdma/rdma_cma.h>
@@ -30,19 +29,17 @@ limitations under the License.
#include <sys/epoll.h>
#include "tensorflow/contrib/gdr/gdr.pb.h"
-#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/common_runtime/process_state.h"
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#endif // GOOGLE_CUDA
-#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
@@ -70,14 +67,11 @@ bool IsGDRAvailable() {
int TryToReadNumaNode(ibv_device* device) {
#if defined(__APPLE__)
LOG(INFO) << "OS X does not support NUMA - returning NUMA node 0";
- return 0;
+ return port::kNUMANoAffinity;
#elif defined(PLATFORM_WINDOWS)
// Windows support for NUMA is not currently implemented. Return node 0.
- return 0;
+ return port::kNUMANoAffinity;
#else
- VLOG(2) << "Trying to read NUMA node for device: " << device->name;
- static const int kUnknownNumaNode = -1;
-
auto filename = string(device->ibdev_path) + "/device/numa_node";
std::ifstream ifs(filename.c_str());
@@ -91,12 +85,12 @@ int TryToReadNumaNode(ibv_device* device) {
<< value
<< "), but there must be at least one NUMA node"
", so returning NUMA node zero";
- return 0;
+ return port::kNUMANoAffinity;
}
LOG(INFO) << "NUMA node for device: " << device->name << " is " << value;
return value;
}
- return kUnknownNumaNode;
+ return port::kNUMANoAffinity;
#endif
}
@@ -138,8 +132,6 @@ class GdrMemoryManager : public RemoteMemoryManager {
Device* device, DeviceContext* device_context, bool on_host,
StatusCallback done) override;
- static void RegMemVisitors();
-
protected:
Status CreateEndpoint(const string& host, const string& port,
RdmaEndpointPtr& endpoint);
@@ -150,7 +142,8 @@ class GdrMemoryManager : public RemoteMemoryManager {
ibv_mr* FindMemoryRegion(void* addr, size_t length);
- void InsertMemoryRegion(void* addr, size_t length);
+ void InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name);
void EvictMemoryRegion(void* addr, size_t length);
@@ -160,6 +153,7 @@ class GdrMemoryManager : public RemoteMemoryManager {
RdmaEndpointPtr listening_;
std::atomic<bool> stopped_;
int epfd_;
+ int numa_node_;
// Server side endpoints
// Accessed sequentially in Run() so not protected by lock
@@ -190,46 +184,10 @@ GdrMemoryManager::GdrMemoryManager(const string& host, const string& port)
port_(port),
listening_(nullptr, EndpointDeleter),
stopped_(true),
- next_key_(0) {
- static std::once_flag flag;
- std::call_once(flag, []() { RegMemVisitors(); });
-}
+ next_key_(0) {}
GdrMemoryManager::~GdrMemoryManager() { close(epfd_); }
-/*static*/ void GdrMemoryManager::RegMemVisitors() {
- SubAllocator::Visitor alloc_visitor = [](void* ptr, int numa_node,
- size_t num_bytes) {
- GdrMemoryManager::Singleton().InsertMemoryRegion(
- ptr, num_bytes, strings::StrCat("CPU:", numa_node));
- };
- SubAllocator::Visitor free_visitor = [](void* ptr, int numa_node,
- size_t num_bytes) {
- GdrMemoryManager::Singleton().EvictMemoryRegion(ptr, num_bytes);
- };
- ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
- ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
-
-#if GOOGLE_CUDA
- if (IsGDRAvailable()) {
- int32_t bus_id = TryToReadNumaNode(rdma_adapter_->context_->device) + 1;
-
- // Note we don't free allocated GPU memory so there is no free visitor
- SubAllocator::Visitor cuda_alloc_visitor = [](void* ptr, int gpu_id,
- size_t num_bytes) {
- RdmaMemoryMgr::Singleton().InsertMemoryRegion(
- ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
- };
- GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
- cuda_alloc_visitor);
- GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
- alloc_visitor);
- GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
- LOG(INFO) << "Instrumenting GPU allocator with bus_id " << bus_id;
- }
-#endif // GOOGLE_CUDA
-}
-
Status GdrMemoryManager::Init() {
epfd_ = epoll_create1(0);
if (epfd_ == -1) {
@@ -289,6 +247,42 @@ Status GdrMemoryManager::Init() {
"cannot add server to epoll");
}
+ numa_node_ = TryToReadNumaNode(listening_->verbs->device);
+
+ SubAllocator::Visitor alloc_visitor = [this](void* ptr, int numa_node,
+ size_t num_bytes) {
+ VLOG(2) << "Registering RDMA capable memory region on numa_node "
+ << numa_node;
+ InsertMemoryRegion(ptr, num_bytes, strings::StrCat("CPU:", numa_node));
+ };
+ SubAllocator::Visitor free_visitor = [this](void* ptr, int numa_node,
+ size_t num_bytes) {
+ VLOG(2) << "De-registering RDMA capable memory region on numa_node "
+ << numa_node;
+ EvictMemoryRegion(ptr, num_bytes);
+ };
+ ProcessState::singleton()->AddCPUAllocVisitor(alloc_visitor);
+ ProcessState::singleton()->AddCPUFreeVisitor(free_visitor);
+ LOG(INFO) << "Instrumenting CPU allocator(s)";
+
+#if GOOGLE_CUDA
+ if (IsGDRAvailable()) {
+ int bus_id = numa_node_ + 1;
+
+ SubAllocator::Visitor cuda_alloc_visitor = [this](void* ptr, int gpu_id,
+ size_t num_bytes) {
+ VLOG(2) << "Registering RDMA capable memory region on GPU " << gpu_id;
+ InsertMemoryRegion(ptr, num_bytes, strings::StrCat("GPU:", gpu_id));
+ };
+ GPUProcessState::singleton()->AddGPUAllocVisitor(bus_id,
+ cuda_alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostAllocVisitor(bus_id,
+ alloc_visitor);
+ GPUProcessState::singleton()->AddCUDAHostFreeVisitor(bus_id, free_visitor);
+ LOG(INFO) << "Instrumenting GPU allocator(s) with bus_id " << bus_id;
+ }
+#endif // GOOGLE_CUDA
+
return Status::OK();
}
@@ -405,7 +399,7 @@ void GdrMemoryManager::TransportOptionsFromTensor(
ibv_mr* mr = FindMemoryRegion(addr, length);
#if GOOGLE_CUDA
- if (!on_host) {
+ if (device->tensorflow_gpu_device_info() && !on_host) {
Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
Tensor* host_copy = new Tensor(alloc, tensor.dtype(), tensor.shape());
GPUUtil::CopyGPUTensorToCPU(
@@ -456,11 +450,27 @@ void GdrMemoryManager::TransportOptionsFromTensor(
#endif
if (mr == nullptr) {
- done(errors::Unavailable("Cannot find pinned memory region"));
- return;
+ Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_);
+ Tensor host_copy(alloc, tensor.dtype(), tensor.shape());
+
+ std::memcpy(DMAHelper::buffer(&host_copy)->data(), buffer->data(), length);
+ VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer";
+
+ buffer = DMAHelper::buffer(&host_copy);
+ addr = buffer->data();
+ length = buffer->size();
+
+ mr = FindMemoryRegion(addr, length);
+ if (mr == nullptr) {
+ done(errors::Unavailable("Cannot find pinned memory region"));
+ return;
+ }
+
+ buffer->Ref();
+ } else {
+ buffer->Ref();
}
- buffer->Ref();
TensorKey tensor_key = next_key_++;
{
mutex_lock l(server_mu_);
@@ -470,7 +480,7 @@ void GdrMemoryManager::TransportOptionsFromTensor(
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
- if (!on_host) {
+ if (device->tensorflow_gpu_device_info() && !on_host) {
checksum = GPUUtil::Checksum(device, device_context, tensor);
} else {
checksum = GPUUtil::Checksum(tensor);
@@ -508,7 +518,8 @@ void GdrMemoryManager::TensorFromTransportOptions(
Tensor host_copy;
#if GOOGLE_CUDA
if (mr == nullptr && !on_host) {
- Allocator* alloc = GPUProcessState::singleton()->GetCUDAHostAllocator(0);
+ Allocator* alloc =
+ GPUProcessState::singleton()->GetCUDAHostAllocator(numa_node_);
host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
buffer = DMAHelper::buffer(&host_copy);
addr = buffer->data();
@@ -518,8 +529,18 @@ void GdrMemoryManager::TensorFromTransportOptions(
#endif // GOOGLE_CUDA
if (mr == nullptr) {
- done(errors::Unavailable("Cannot find pinned memory region"));
- return;
+ Allocator* alloc = ProcessState::singleton()->GetCPUAllocator(numa_node_);
+ host_copy = Tensor(alloc, tensor->dtype(), tensor->shape());
+
+ buffer = DMAHelper::buffer(&host_copy);
+ addr = buffer->data();
+ length = buffer->size();
+
+ mr = FindMemoryRegion(addr, length);
+ if (mr == nullptr) {
+ done(errors::Unavailable("Cannot find pinned memory region"));
+ return;
+ }
}
decltype(clients_)::iterator iter;
@@ -568,7 +589,8 @@ void GdrMemoryManager::TensorFromTransportOptions(
}
#if GOOGLE_CUDA
- if (host_copy.NumElements() > 0) {
+ if (device->tensorflow_gpu_device_info() && !on_host &&
+ host_copy.NumElements() > 0) {
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
checksum = GPUUtil::Checksum(host_copy);
@@ -598,6 +620,12 @@ void GdrMemoryManager::TensorFromTransportOptions(
}
#endif // GOOGLE_CUDA
+ if ((on_host || !device->tensorflow_gpu_device_info()) &&
+ host_copy.NumElements() > 0) {
+ std::memcpy(DMAHelper::buffer(tensor)->data(), addr, length);
+ VLOG(2) << "Copying " << length << " bytes unpinned tensor buffer";
+ }
+
uint64_t end = Env::Default()->NowMicros();
VLOG(2) << "RDMA from remote memory region " << remote_mr.rkey()
@@ -607,7 +635,7 @@ void GdrMemoryManager::TensorFromTransportOptions(
uint64_t checksum = 0;
if (VLOG_IS_ON(2)) {
#ifdef GOOGLE_CUDA
- if (device->tensorflow_gpu_device_info() && (!on_host)) {
+ if (device->tensorflow_gpu_device_info() && !on_host) {
checksum = GPUUtil::Checksum(device, device_context, *tensor);
} else {
checksum = GPUUtil::Checksum(*tensor);
@@ -668,7 +696,8 @@ ibv_mr* GdrMemoryManager::FindMemoryRegion(void* addr, size_t length) {
}
}
-void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
+void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length,
+ const std::string& allocator_name) {
if (length == 0) return;
ibv_mr* mr = rdma_reg_read(listening_.get(), addr, length);
if (mr != nullptr) {
@@ -676,7 +705,8 @@ void GdrMemoryManager::InsertMemoryRegion(void* addr, size_t length) {
auto iter = std::upper_bound(mrs_.begin(), mrs_.end(), addr, &Comparator);
mrs_.insert(iter, {mr, &MRDeleter});
} else {
- LOG(WARNING) << "Cannot register memory region";
+ LOG(WARNING) << "Cannot register memory region allocated by "
+ << allocator_name;
}
}
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f3ebe3b245..787a85644c 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -4,6 +4,7 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops")
exports_files(glob([
@@ -165,10 +166,6 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
- defines = select({
- ":with_tflite_flex": ["TFLITE_FLEX"],
- "//conditions:default": [],
- }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -276,6 +273,7 @@ cc_test(
"testdata/0_subgraphs.bin",
"testdata/2_subgraphs.bin",
"testdata/empty_model.bin",
+ "testdata/multi_add_flex.bin",
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
],
@@ -283,6 +281,26 @@ cc_test(
":framework",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ "//tensorflow/contrib/lite/testing:util",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Test model framework with the flex library linked into the target.
+tf_cc_test(
+ name = "model_flex_test",
+ size = "small",
+ srcs = ["model_flex_test.cc"],
+ data = [
+ "testdata/multi_add_flex.bin",
+ ],
+ tags = ["no_windows"], # TODO(b/116667551): No weak symbols with MSVC.
+ deps = [
+ ":framework",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
"//tensorflow/contrib/lite/testing:util",
"@com_google_googletest//:gtest",
],
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index 44daf7adaa..1e65c3cee2 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -191,6 +191,13 @@ typedef struct {
TfLiteFusedActivation activation;
float cell_clip;
float proj_clip;
+} TfLiteUnidirectionalSequenceLSTMParams;
+
+typedef struct {
+ // Parameters for the LSTM kernel.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
// If true, store the outputs of both directions in the first output.
bool merge_outputs;
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index eac7db9a88..b092e5ee54 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
@@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
+ auto* params =
+ allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>();
+ if (auto* seq_lstm_params =
+ op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) {
+ params->activation =
+ parse_activation(seq_lstm_params->fused_activation_function());
+ params->cell_clip = seq_lstm_params->cell_clip();
+ params->proj_clip = seq_lstm_params->proj_clip();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
auto params =
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD
index 9dd38958e5..9b89ed4f84 100644
--- a/tensorflow/contrib/lite/delegates/flex/BUILD
+++ b/tensorflow/contrib/lite/delegates/flex/BUILD
@@ -2,7 +2,7 @@
# This is a TF Lite delegate that is powered by TensorFlow's Eager.
#
package(default_visibility = [
- "//visibility:public",
+ "//visibility:private",
])
licenses(["notice"]) # Apache 2.0
@@ -50,6 +50,7 @@ cc_library(
hdrs = [
"delegate.h",
],
+ visibility = ["//visibility:public"],
deps = [
":buffer_map",
":delegate_data",
@@ -66,6 +67,7 @@ cc_library(
"//tensorflow/core:lib",
],
}),
+ alwayslink = 1,
)
tf_cc_test(
diff --git a/tensorflow/contrib/lite/delegates/flex/delegate.cc b/tensorflow/contrib/lite/delegates/flex/delegate.cc
index ba065a8ff5..c72b0cf513 100644
--- a/tensorflow/contrib/lite/delegates/flex/delegate.cc
+++ b/tensorflow/contrib/lite/delegates/flex/delegate.cc
@@ -83,6 +83,15 @@ TfLiteStatus CopyFromBufferHandle(TfLiteContext* context,
} // namespace delegate
} // namespace flex
+// Corresponding weak declaration found in lite/model.cc.
+std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>
+AcquireFlexDelegate() {
+ return std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>(
+ tflite::FlexDelegate::Create().release(), [](TfLiteDelegate* delegate) {
+ delete reinterpret_cast<tflite::FlexDelegate*>(delegate);
+ });
+}
+
std::unique_ptr<FlexDelegate> FlexDelegate::Create() {
std::unique_ptr<flex::DelegateData> delegate_data;
if (!flex::DelegateData::Create(&delegate_data).ok()) {
diff --git a/tensorflow/contrib/lite/experimental/micro/BUILD b/tensorflow/contrib/lite/experimental/micro/BUILD
new file mode 100644
index 0000000000..df1036bc8b
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/BUILD
@@ -0,0 +1,76 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl",
+ "tflite_micro_cc_test",
+)
+
+cc_library(
+ name = "micro_framework",
+ srcs = [
+ "micro_error_reporter.cc",
+ "micro_interpreter.cc",
+ "micro_mutable_op_resolver.cc",
+ "simple_tensor_allocator.cc",
+ ],
+ hdrs = [
+ "compatibility.h",
+ "micro_error_reporter.h",
+ "micro_interpreter.h",
+ "micro_mutable_op_resolver.h",
+ "simple_tensor_allocator.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "micro_error_reporter_test",
+ srcs = [
+ "micro_error_reporter_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "micro_mutable_op_resolver_test",
+ srcs = [
+ "micro_mutable_op_resolver_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "micro_interpreter_test",
+ srcs = [
+ "micro_interpreter_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "simple_tensor_allocator_test",
+ srcs = [
+ "simple_tensor_allocator_test.cc",
+ ],
+ deps = [
+ ":micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/README.md b/tensorflow/contrib/lite/experimental/micro/README.md
new file mode 100644
index 0000000000..414cafde4d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/README.md
@@ -0,0 +1,114 @@
+# TensorFlow Lite for Microcontrollers
+
+This an experimental port of TensorFlow Lite aimed at micro controllers and other devices with only kilobytes of memory. It doesn't require any operating system support, any standard C or C++ libraries, or dynamic memory allocation, so it's designed to be portable even to 'bare metal' systems. The core runtime fits in 16KB on a Cortex M3, and with enough operators to run a speech keyword detection model, takes up a total of 22KB.
+
+The design goals are for the framework to be:
+
+- **Readable**: We want embedded software engineers to be able to understand what's required to run ML inference without having to study research papers. We've tried to keep the code base small, modular, and have reference implementations of all operations to help with this.
+
+- **Easy to modify**: We know that there are a lot of different platforms and requirements in the embedded world, and we don't expect to cover all of them in one framework. Instead, we're hoping that it can be a good starting point for developers to build on top of to meet their own needs. For example, we tried to make it easy to replace the implementations of key computational operators that are often crucial for performance, without having to touch the data flow and other runtime code. We want it to make more sense to use our workflow to handle things like model import and less-important operations, and customize the parts that matter, rather than having to reimplement everything in your own engine.
+
+- **Well-tested**: If you're modifying code, you need to know if your changes are correct. Having an easy way to test lets you develop much faster. To help there, we've written tests for all the components, and we've made sure that the tests can be run on almost any platform, with no dependencies apart from the ability to log text to a debug console somewhere. We also provide an easy way to run all the tests on-device as part of an automated test framework, and we use qemu/Renode emulation so that tests can be run even without physical devices present.
+
+- **Easy to integrate**: We want to be as open a system as possible, and use the best code available for each platform. To do that, we're going to rely on projects like [CMSIS-NN](https://www.keil.com/pack/doc/CMSIS/NN/html/index.html), [uTensor](https://github.com/uTensor/uTensor), and other vendor libraries to handle as much performance-critical code as possible. We know that there are an increasing number of options to accelerate neural networks on microcontrollers, so we're aiming to be a good host for deploying those hardware technologies too.
+
+- **Compatible**: We're using the same file schema, interpreter API, and kernel interface as regular TensorFlow Lite, so we leverage the large existing set of tools, documentation, and examples for the project. The biggest barrier to deploying ML models is getting them from a training environment into a form that's easy to run inference on, so we see reusing this rich ecosystem as being crucial to being easily usable. We also hope to integrate this experimental work back into the main codebase in the future.
+
+To meet those goals, we've made some tradeoffs:
+
+- **Simple C++**: To help with readability, our code is written in a modern version of C++, but we generally treat it as a "better C", rather relying on more complex features such as template meta-programming. As mentioned earlier, we avoid any use of dynamic memory allocation (new/delete) or the standard C/C++ libraries, so we believe this should still be fairly portable. It does mean that some older devices with C-only toolchains won't be supported, but we're hoping that the reference operator implementations (which are simple C-like functions) can still be useful in those cases. The interfaces are also designed to be C-only, so it should be possible to integrate the resulting library with pure C projects.
+
+- **Interpreted**: Code generation is a popular pattern for embedded code, because it gives standalone code that's easy to modify and step through, but we've chosen to go with an interpreted approach. In our internal microcontroller work we've found that using an extremely stripped-down interpreter with almost no dependencies gives us a lot of the same advantages, but is easier to maintain. For example, when new updates come out for the underlying library, you can just merge your local modifications in a single step, rather than having to regenerate new code and then patch in any changes you subsequently made. The coarse granularity of the interpreted primitives means that each operation call typically takes hundreds of thousands of instruction cycles at least, so we don't see noticeable performance gains from avoiding what's essentially a single switch statement at the interpreter level to call each operation. We're still working on improving the packaging though, for example we're considering having the ability to snapshot all the source files and headers used for a particular model, being able to compile the code and data together as a library, and then access it through a minimal set of C interface calls which hide the underlying complexity.
+
+- **Flatbuffers**: We represent our models using [the standard flatbuffer schema used by the rest of TensorFlow Lite](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema.fbs), with the difference that we always keep it in read-only program memory (typically flash) rather than relying on having a file system to read it from. This is a good fit because flatbuffer's serialized format is designed to be mapped into memory without requiring any extra memory allocations or modifications to access it. All of the functions to read model values work directly on the serialized bytes, and large sections of data like weights are directly accessible as sequential C-style arrays of their data type, with no strides or unpacking needed. We do get a lot of value from using flatbuffers, but there is a cost in complexity. The flat buffer library code is all inline [inside the main headers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/schema/schema_generated.h), but it isn't straightforward to inspect their implementations, and the model data structures aren't easy to comprehend from the debugger. The header for the schema itself also has to be periodically updated when new information is added to the file format, though we try to handle that transparently for most developers by checking in a pre-generated version.
+
+- **Code Duplication**: Some of the code in this prototype largely duplicates the logic in other parts of the TensorFlow Lite code base, for example the operator wrappers. We've tried to keep share as much as we can between the two interpreters, but there are some assumptions built into the original runtime that make this difficult. We'll be working on modularizing the main interpreter so that we can move to an entirely shared system.
+
+This initial preview release is designed to get early feedback, and is not intended to be a final product. It only includes enough operations to run a simple keyword recognition model, and the implementations are not optimized. We're hoping this will be a good way to get feedback and collaborate to improve the framework.
+
+## Getting Started
+
+Building requires a Linux or OS X machine.
+
+ - Open a terminal
+ - Download the TensorFlow source with `git clone https://github.com/tensorflow`
+ - Enter the source root directory by running `cd tensorflow`
+ - Download the dependencies by running `tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh`. This may take a few minutes
+ - Build and test the library with `make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile test`
+
+You should see a series of compilation steps, followed by "~~~ALL TESTS PASSED~~~" for the various tests of the code that it will run. If there's an error, you should get an informative message from make about what went wrong.
+
+These tests are all built as simple binaries with few dependencies, so you can run them manually. For example, here's how to run the depthwise convolution test, and its output:
+
+```
+tensorflow/contrib/lite/experimental/micro/tools/make/gen/linux_x86_64/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test
+
+Testing SimpleTest
+Testing SimpleTestQuantized
+Testing SimpleTestRelu
+Testing SimpleTestReluQuantized
+4/4 tests passed
+~ALL TESTS PASSED~~~
+```
+
+Looking at the [depthwise_conv_test.cc](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc) code, you'll see a sequence that looks like this:
+
+```
+...
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+...
+}
+...
+TF_LITE_MICRO_TESTS_END
+```
+
+These macros work a lot like [the Google test framework](https://github.com/google/googletest), but they don't require any dependencies and just write results to stderr, rather than aborting the program. If all the tests pass, then "~~~ALL TESTS PASSED~~~" is output, and the test harness that runs the binary during the make process knows that everything ran correctly. If there's an error, the lack of the expected string lets the harness know that the test failed.
+
+So, why are we running tests in this complicated way? So far, we've been building binaries that run locally on the Mac OS or Linux machine you're building on, but this approach becomes important when we're targeting simple micro controller devices.
+
+## Building for the "Blue Pill" STM32F103
+
+The goal of this library is to enable machine learning on resource-constrained micro controllers and DSPs, and as part of that we've targeted the ["Blue Pill" STM32F103-compatible development board](https://github.com/google/googletest) as a cheap and popular platform. It only has 20KB of RAM and 64KB of flash, so it's a good device to ensure we can run efficiently on small chips.
+
+It's fairly easy to [buy and wire up a physical board](https://github.com/google/stm32_bare_lib#wiring-up-your-blue-pill), but even if you don't have an actual device, the [Renode project](https://renode.io/) makes it easy to run a faithful emulation on your desktop machine. You'll need [Docker](https://www.docker.com/) installed, but once you have that set up, try running the following command:
+
+`make -f tensorflow/contrib/lite/experimental/micro/tools/make/Makefile TARGET=bluepill test`
+
+You should see a similar set of outputs as you did in the previous section, with the addition of some extra Docker logging messages. These are because we're using Docker to run the Renode micro controller emulation tool, and the tests themselves are being run on a simulated STM32F103 device. The communication channels between an embedded device and the host are quite limited, so the test harness looks at the output of the debug log to see if tests have passed, just as it did in the previous section. This makes it a very flexible way to run cross-platform tests, even when a platform has no operating system facilities, as long as it can output debugging text logs.
+
+To understand what's happening here, try running the same depthwise convolution test, but through the emulated device test harness, with the following command:
+
+```
+tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh \
+tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test
+
+```
+
+You should see output that looks something like this:
+
+```
+Sending build context to Docker daemon 21.5kB
+Step 1/2 : FROM antmicro/renode:latest
+ ---> 1b670a243e8f
+Step 2/2 : LABEL maintainer="Pete Warden <petewarden@google.com>"
+ ---> Using cache
+ ---> 3afcd410846d
+Successfully built 3afcd410846d
+Successfully tagged renode_bluepill:latest
+LOGS:
+...
+03:27:32.4340 [INFO] machine-0: Machine started.
+03:27:32.4790 [DEBUG] cpu.uartSemihosting: [+0.22s host +0s virt 0s virt from start] Testing SimpleTest
+03:27:32.4812 [DEBUG] cpu.uartSemihosting: [+2.21ms host +0s virt 0s virt from start] Testing SimpleTestQuantized
+03:27:32.4833 [DEBUG] cpu.uartSemihosting: [+2.14ms host +0s virt 0s virt from start] Testing SimpleTestRelu
+03:27:32.4834 [DEBUG] cpu.uartSemihosting: [+0.18ms host +0s virt 0s virt from start] Testing SimpleTestReluQuantized
+03:27:32.4838 [DEBUG] cpu.uartSemihosting: [+0.4ms host +0s virt 0s virt from start] 4/4 tests passed
+03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+41µs host +0s virt 0s virt from start] ~~~ALL TESTS PASSED~~~
+03:27:32.4839 [DEBUG] cpu.uartSemihosting: [+5µs host +0s virt 0s virt from start]
+...
+tensorflow/contrib/lite/experimental/micro/tools/make/gen/bluepill_cortex-m3/bin/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test: PASS
+```
+
+There's a lot of output here, but you should be able to see that the same tests that were covered when we ran locally on the development machine show up in the debug logs here, along with the magic string "~~~ALL TESTS PASSED~~~". This is the exact same code as before, just compiled and run on the STM32F103 rather than your desktop. We hope that the simplicity of this testing approach will help make adding support for new platforms as easy as possible.
diff --git a/tensorflow/contrib/lite/experimental/micro/compatibility.h b/tensorflow/contrib/lite/experimental/micro/compatibility.h
new file mode 100644
index 0000000000..4f0fd9f312
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/compatibility.h
@@ -0,0 +1,32 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_
+
+// C++ will automatically create class-specific delete operators for virtual
+// objects, which by default call the global delete function. For embedded
+// applications we want to avoid this, and won't be calling new/delete on these
+// objects, so we need to override the default implementation with one that does
+// nothing to avoid linking in ::delete().
+// This macro needs to be included in all subclasses of a virtual base class in
+// the private section.
+#ifdef TF_LITE_STATIC_MEMORY
+#define TF_LITE_REMOVE_VIRTUAL_DELETE \
+ void operator delete(void* p) {}
+#else
+#define TF_LITE_REMOVE_VIRTUAL_DELETE
+#endif
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD
new file mode 100644
index 0000000000..dad58b6c1c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/BUILD
@@ -0,0 +1,31 @@
+# Description:
+# TensorFlow Lite microcontroller example.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl",
+ "tflite_micro_cc_test",
+)
+
+tflite_micro_cc_test(
+ name = "micro_speech_test",
+ srcs = [
+ "micro_speech_test.cc",
+ "tiny_conv_model_data.cc",
+ "tiny_conv_model_data.h",
+ ],
+ tags = [
+ "nomsan",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite:schema_fbs_version",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/kernels:all_ops_resolver",
+ "//tensorflow/contrib/lite/experimental/micro/kernels:micro_ops",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ "//tensorflow/contrib/lite/schema:schema_fbs",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc
new file mode 100644
index 0000000000..86cd056a72
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+#include "tensorflow/contrib/lite/version.h"
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestInvoke) {
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = &micro_error_reporter;
+
+ const tflite::Model* model = ::tflite::GetModel(g_tiny_conv_model_data);
+ if (model->version() != TFLITE_SCHEMA_VERSION) {
+ error_reporter->Report(
+ "Model provided is schema version %d not equal "
+ "to supported version %d.\n",
+ model->version(), TFLITE_SCHEMA_VERSION);
+ }
+ tflite::ops::micro::AllOpsResolver resolver;
+
+ const int tensor_arena_size = 10 * 1024;
+ uint8_t tensor_arena[tensor_arena_size];
+ tflite::SimpleTensorAllocator tensor_allocator(tensor_arena,
+ tensor_arena_size);
+
+ tflite::MicroInterpreter interpreter(model, resolver, &tensor_allocator,
+ error_reporter);
+ TfLiteStatus invoke_status = interpreter.Invoke();
+ if (invoke_status != kTfLiteOk) {
+ error_reporter->Report("Invoke failed\n");
+ }
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
+
+ error_reporter->Report("Ran successfully\n");
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
new file mode 100644
index 0000000000..f1f9e0e219
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
@@ -0,0 +1,1672 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Automatically created from a TensorFlow Lite flatbuffer using the command:
+// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc
+
+#include "tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h"
+
+const unsigned char g_tiny_conv_model_data[] = {
+ 0x18, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x0e, 0x00,
+ 0x18, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00,
+ 0x0e, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x08, 0x4d, 0x00, 0x00,
+ 0x0c, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0xf4, 0x47, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00,
+ 0x54, 0x4f, 0x43, 0x4f, 0x20, 0x43, 0x6f, 0x6e, 0x76, 0x65, 0x72, 0x74,
+ 0x65, 0x64, 0x2e, 0x00, 0x09, 0x00, 0x00, 0x00, 0xd4, 0x47, 0x00, 0x00,
+ 0x04, 0x03, 0x00, 0x00, 0xfc, 0x02, 0x00, 0x00, 0xf4, 0x02, 0x00, 0x00,
+ 0x64, 0x00, 0x00, 0x00, 0x44, 0x00, 0x00, 0x00, 0x3c, 0x00, 0x00, 0x00,
+ 0x0c, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0xb8, 0xb3, 0xff, 0xff,
+ 0x16, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0xd7, 0x02, 0x00, 0x00, 0x2f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe8, 0xb3, 0xff, 0xff,
+ 0x46, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00,
+ 0xab, 0x00, 0x00, 0x00, 0x1e, 0xff, 0xff, 0xff, 0xed, 0xff, 0xff, 0xff,
+ 0x4a, 0x00, 0x00, 0x00, 0x62, 0xb4, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
+ 0x80, 0x02, 0x00, 0x00, 0xce, 0xad, 0xaf, 0x3c, 0xc8, 0xe9, 0xb0, 0x83,
+ 0xa1, 0xbf, 0xb2, 0xb1, 0xab, 0xd0, 0xa7, 0x53, 0xa5, 0xe9, 0xb5, 0xac,
+ 0xa2, 0xd3, 0xc4, 0x9e, 0x8b, 0xb2, 0x64, 0xb3, 0x9d, 0xa2, 0xae, 0xa6,
+ 0xd5, 0xbe, 0x43, 0x9f, 0x9c, 0x54, 0xb5, 0xa8, 0x49, 0x78, 0x86, 0xa2,
+ 0xa3, 0x55, 0x35, 0x96, 0x3d, 0x7f, 0xe2, 0xb5, 0xb0, 0x47, 0x28, 0xa9,
+ 0x9d, 0xbb, 0xd6, 0xff, 0xb7, 0x79, 0x63, 0xb5, 0xaf, 0xa7, 0xab, 0x7e,
+ 0xbc, 0xc7, 0xa0, 0xc3, 0xb1, 0xb6, 0xb2, 0xa1, 0xc2, 0xbb, 0x79, 0x57,
+ 0xbe, 0xc1, 0xb7, 0xb0, 0x6b, 0xb7, 0xa5, 0x75, 0x97, 0xb8, 0xe7, 0xac,
+ 0xad, 0x7e, 0xb1, 0x9b, 0xc3, 0xba, 0x6b, 0xa2, 0x7f, 0x58, 0xb9, 0x7a,
+ 0x4c, 0x91, 0x74, 0x9e, 0xa7, 0x3d, 0xc2, 0x94, 0x75, 0xa1, 0xa4, 0xac,
+ 0xab, 0x45, 0x2e, 0xb4, 0xb6, 0xbf, 0xc1, 0xdb, 0xaf, 0x6c, 0x67, 0xb1,
+ 0xa9, 0xa6, 0xa8, 0xca, 0xc2, 0xc4, 0xb9, 0xbf, 0xb4, 0xb9, 0xaa, 0x9d,
+ 0x9f, 0xb9, 0xb2, 0x71, 0xb2, 0xca, 0xbe, 0xaf, 0x5f, 0xbc, 0xa0, 0x5b,
+ 0xa8, 0xb4, 0xa4, 0xa8, 0xd8, 0x69, 0xb7, 0x8a, 0xbc, 0xb8, 0xaf, 0x9c,
+ 0x7c, 0x5d, 0xb3, 0x6b, 0x49, 0x95, 0x64, 0xa0, 0xa2, 0x49, 0xcb, 0x87,
+ 0xa5, 0xb5, 0xa1, 0xb2, 0xa3, 0x40, 0x6d, 0x9f, 0xc5, 0xb6, 0xbb, 0xd4,
+ 0x9c, 0x6d, 0x69, 0xa9, 0xa8, 0x91, 0xad, 0xb8, 0xd2, 0xc6, 0xaf, 0xb8,
+ 0xac, 0xa9, 0xa2, 0xa7, 0x60, 0xa6, 0xa1, 0xc9, 0xb8, 0xd6, 0xcf, 0xb1,
+ 0x56, 0xb4, 0xac, 0x40, 0xae, 0xbd, 0xbf, 0xa2, 0x54, 0x72, 0x9b, 0x8c,
+ 0xc2, 0xb5, 0xc2, 0x9b, 0x64, 0x6d, 0xb4, 0x62, 0x4e, 0x9b, 0x6c, 0xa6,
+ 0x8f, 0x4c, 0xca, 0x95, 0xb6, 0xbf, 0x92, 0xae, 0x9c, 0x49, 0xae, 0xb2,
+ 0xc0, 0xb6, 0xbc, 0xd1, 0xa4, 0x7b, 0x64, 0xa0, 0xa6, 0x81, 0xac, 0xa6,
+ 0xbd, 0xc8, 0xbc, 0xae, 0xaa, 0x9e, 0x61, 0xb1, 0x57, 0xac, 0xbf, 0xbf,
+ 0xbb, 0xe0, 0xa6, 0xae, 0x47, 0xc9, 0xbc, 0x57, 0xb0, 0xb5, 0xc7, 0x98,
+ 0xf4, 0x93, 0xb6, 0x70, 0xc3, 0xb3, 0xca, 0xab, 0x77, 0x9a, 0xac, 0x45,
+ 0x5c, 0x9e, 0x9a, 0xa9, 0x9b, 0x35, 0xc0, 0x6f, 0xc6, 0xc7, 0x91, 0xb4,
+ 0xa8, 0x3c, 0xce, 0xb8, 0xad, 0xb9, 0xb5, 0xdd, 0x9c, 0x6d, 0xbf, 0x91,
+ 0xb2, 0x7d, 0xa0, 0xaf, 0x9f, 0xbd, 0xb9, 0xcf, 0x9b, 0x5d, 0x3f, 0xac,
+ 0x64, 0xae, 0xaf, 0xb8, 0xbc, 0xb8, 0x86, 0xb5, 0x36, 0xcf, 0xb4, 0xa9,
+ 0xad, 0xcd, 0xdb, 0xa4, 0x68, 0xa6, 0xa4, 0x67, 0xc8, 0xb7, 0xe5, 0xa4,
+ 0x76, 0xb8, 0xa8, 0x28, 0x6b, 0xa5, 0xba, 0xad, 0x9f, 0x3a, 0xa5, 0x42,
+ 0xc5, 0xb0, 0x88, 0xad, 0xa5, 0x4d, 0xea, 0x8a, 0xb8, 0xb5, 0xb3, 0xd9,
+ 0xa0, 0x77, 0xbb, 0x92, 0x9e, 0x80, 0xbd, 0xbd, 0x6d, 0xcc, 0xab, 0x99,
+ 0x88, 0x58, 0x4d, 0xb0, 0x6c, 0xbc, 0x96, 0xbd, 0xae, 0xab, 0x5b, 0xac,
+ 0x2f, 0xc3, 0x9a, 0xbe, 0xac, 0xb3, 0x84, 0x9b, 0xe3, 0xaf, 0x95, 0x6b,
+ 0xc2, 0xb5, 0xca, 0xb7, 0x4e, 0xbc, 0x9d, 0x24, 0x75, 0xa9, 0xd2, 0xae,
+ 0xa0, 0x2b, 0x90, 0x34, 0xd1, 0xb5, 0x96, 0xae, 0xaa, 0x4d, 0xc1, 0xa3,
+ 0xb1, 0xb4, 0xaa, 0xd2, 0x9c, 0x7d, 0xc0, 0x91, 0x91, 0x7a, 0xb8, 0x83,
+ 0x44, 0xcb, 0xaf, 0x9b, 0x6b, 0x5b, 0x75, 0xb2, 0x62, 0xb6, 0xaa, 0xcb,
+ 0x99, 0xa8, 0x63, 0xae, 0x24, 0xc7, 0x8a, 0xbe, 0xa9, 0xb6, 0xa0, 0xa1,
+ 0x41, 0xac, 0x84, 0xb5, 0xb9, 0xb3, 0x9b, 0xad, 0x77, 0xbf, 0xa8, 0x7e,
+ 0x82, 0xb9, 0xbe, 0xaa, 0xa3, 0x47, 0x6d, 0xb5, 0xc3, 0xb1, 0xbf, 0xa7,
+ 0xb1, 0x57, 0x75, 0xb5, 0xb0, 0xb6, 0xb9, 0xce, 0xa4, 0x86, 0xb0, 0xa4,
+ 0x98, 0x80, 0xc5, 0x3e, 0x90, 0xca, 0x9b, 0xa2, 0x5a, 0x50, 0xc5, 0xa5,
+ 0xad, 0xc1, 0x9c, 0x91, 0x83, 0x8f, 0x21, 0xab, 0xac, 0xba, 0x70, 0xb4,
+ 0xae, 0x85, 0x7e, 0xa7, 0xbd, 0xba, 0x7c, 0xb2, 0xb5, 0xb2, 0x7e, 0xb3,
+ 0xc3, 0xcd, 0x82, 0xac, 0x9b, 0xb3, 0xa6, 0xb0, 0xbc, 0x6f, 0x52, 0xb9,
+ 0xbf, 0xb1, 0xa6, 0xa4, 0xc1, 0x7a, 0x90, 0xc0, 0xae, 0xab, 0x94, 0xd8,
+ 0xab, 0xa4, 0x98, 0xbb, 0x8b, 0x86, 0x94, 0x01, 0xad, 0xe7, 0xb1, 0x9b,
+ 0x57, 0x48, 0xc1, 0x88, 0xbf, 0xcc, 0xb4, 0x4b, 0x62, 0x8b, 0x48, 0xa7,
+ 0xbe, 0xe1, 0x80, 0xa6, 0xb3, 0x64, 0xaa, 0xa4, 0xcf, 0xba, 0x6d, 0xa6,
+ 0xb8, 0xa0, 0x8f, 0xb3, 0xce, 0xc3, 0x87, 0xb2, 0xa0, 0xc0, 0x78, 0xb0,
+ 0xb9, 0xaa, 0x40, 0xb8, 0xd8, 0xa3, 0x9a, 0xaa, 0xcc, 0xa2, 0x9f, 0xb9,
+ 0xbe, 0xc2, 0x89, 0xd6, 0xc6, 0x9c, 0xa3, 0xc7, 0x94, 0xb6, 0xff, 0xff,
+ 0x98, 0xb6, 0xff, 0xff, 0xf6, 0xb6, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
+ 0xc0, 0x44, 0x00, 0x00, 0x4a, 0x4d, 0x59, 0x60, 0x5a, 0x45, 0x3d, 0x50,
+ 0x4a, 0x43, 0x3d, 0x59, 0x3e, 0x49, 0x4a, 0x59, 0x45, 0x44, 0x41, 0x5d,
+ 0x50, 0x2f, 0x4e, 0x34, 0x46, 0x48, 0x41, 0x4a, 0x4c, 0x3b, 0x4b, 0x3e,
+ 0x49, 0x49, 0x43, 0x4b, 0x3e, 0x49, 0x47, 0x41, 0x3e, 0x4a, 0x46, 0x43,
+ 0x41, 0x43, 0x47, 0x49, 0x4a, 0x4c, 0x46, 0x58, 0x3f, 0x4c, 0x4b, 0x4c,
+ 0x4d, 0x4b, 0x45, 0x52, 0x45, 0x42, 0x52, 0x52, 0x48, 0x40, 0x46, 0x5f,
+ 0x4c, 0x41, 0x47, 0x48, 0x48, 0x4c, 0x43, 0x61, 0x50, 0x4b, 0x49, 0x49,
+ 0x46, 0x3f, 0x40, 0x67, 0x40, 0x4d, 0x45, 0x40, 0x40, 0x45, 0x47, 0x56,
+ 0x44, 0x3a, 0x4a, 0x4c, 0x52, 0x48, 0x46, 0x50, 0x4b, 0x44, 0x51, 0x45,
+ 0x40, 0x45, 0x45, 0x48, 0x4e, 0x4e, 0x43, 0x48, 0x44, 0x4b, 0x45, 0x4a,
+ 0x53, 0x45, 0x4a, 0x4b, 0x3f, 0x43, 0x45, 0x53, 0x4d, 0x43, 0x46, 0x3f,
+ 0x47, 0x4e, 0x51, 0x50, 0x48, 0x4f, 0x4f, 0x4a, 0x4a, 0x4e, 0x45, 0x4e,
+ 0x46, 0x41, 0x4a, 0x46, 0x45, 0x47, 0x45, 0x4b, 0x50, 0x4c, 0x46, 0x45,
+ 0x41, 0x47, 0x41, 0x47, 0x46, 0x4f, 0x3f, 0x4f, 0x4a, 0x51, 0x4f, 0x53,
+ 0x54, 0x48, 0x51, 0x43, 0x4b, 0x48, 0x4d, 0x46, 0x48, 0x4f, 0x49, 0x44,
+ 0x43, 0x53, 0x50, 0x59, 0x56, 0x3d, 0x45, 0x44, 0x48, 0x38, 0x3b, 0x5f,
+ 0x39, 0x43, 0x43, 0x52, 0x46, 0x3e, 0x43, 0x58, 0x43, 0x1e, 0x50, 0x3c,
+ 0x46, 0x4b, 0x46, 0x50, 0x3c, 0x37, 0x4c, 0x47, 0x47, 0x4b, 0x47, 0x54,
+ 0x43, 0x3e, 0x47, 0x4f, 0x4b, 0x41, 0x53, 0x50, 0x42, 0x46, 0x4f, 0x4b,
+ 0x4e, 0x3f, 0x49, 0x52, 0x4a, 0x4a, 0x49, 0x53, 0x52, 0x47, 0x52, 0x5a,
+ 0x40, 0x42, 0x4d, 0x4b, 0x50, 0x43, 0x49, 0x59, 0x47, 0x4c, 0x4d, 0x50,
+ 0x4e, 0x3c, 0x44, 0x61, 0x51, 0x49, 0x49, 0x46, 0x49, 0x47, 0x4b, 0x5a,
+ 0x45, 0x4b, 0x43, 0x40, 0x44, 0x52, 0x4d, 0x54, 0x49, 0x47, 0x44, 0x48,
+ 0x46, 0x48, 0x3e, 0x40, 0x45, 0x4f, 0x4d, 0x4b, 0x4c, 0x40, 0x3d, 0x40,
+ 0x3e, 0x48, 0x50, 0x4e, 0x4c, 0x42, 0x48, 0x4b, 0x3d, 0x48, 0x4b, 0x44,
+ 0x52, 0x4b, 0x49, 0x4f, 0x49, 0x3f, 0x47, 0x43, 0x4d, 0x3f, 0x53, 0x4e,
+ 0x4a, 0x4f, 0x4e, 0x4e, 0x53, 0x42, 0x46, 0x4c, 0x44, 0x4c, 0x46, 0x51,
+ 0x45, 0x48, 0x4a, 0x50, 0x47, 0x41, 0x45, 0x54, 0x4a, 0x44, 0x50, 0x49,
+ 0x48, 0x50, 0x51, 0x4b, 0x50, 0x4c, 0x4a, 0x49, 0x43, 0x47, 0x50, 0x4a,
+ 0x4d, 0x4c, 0x4e, 0x49, 0x42, 0x50, 0x52, 0x48, 0x45, 0x5a, 0x4e, 0x55,
+ 0x51, 0x3d, 0x3d, 0x4d, 0x42, 0x32, 0x36, 0x64, 0x39, 0x4c, 0x41, 0x48,
+ 0x44, 0x35, 0x43, 0x56, 0x47, 0x1e, 0x4b, 0x3e, 0x47, 0x3f, 0x43, 0x52,
+ 0x51, 0x34, 0x41, 0x4d, 0x3e, 0x41, 0x41, 0x48, 0x3c, 0x4b, 0x45, 0x3b,
+ 0x40, 0x43, 0x4c, 0x46, 0x46, 0x47, 0x3e, 0x4f, 0x4b, 0x48, 0x42, 0x47,
+ 0x4e, 0x3e, 0x49, 0x47, 0x43, 0x43, 0x4e, 0x52, 0x51, 0x45, 0x3f, 0x54,
+ 0x46, 0x44, 0x48, 0x5d, 0x3e, 0x4a, 0x47, 0x52, 0x53, 0x3a, 0x4f, 0x5d,
+ 0x41, 0x4c, 0x48, 0x51, 0x43, 0x4b, 0x4b, 0x67, 0x48, 0x4b, 0x45, 0x4d,
+ 0x4b, 0x43, 0x4a, 0x54, 0x4c, 0x46, 0x43, 0x4a, 0x4d, 0x43, 0x4c, 0x47,
+ 0x4a, 0x48, 0x4d, 0x42, 0x4d, 0x48, 0x3f, 0x43, 0x4c, 0x44, 0x4e, 0x4c,
+ 0x40, 0x45, 0x4b, 0x48, 0x47, 0x47, 0x3e, 0x4c, 0x52, 0x41, 0x44, 0x4e,
+ 0x4d, 0x44, 0x49, 0x4d, 0x3d, 0x45, 0x48, 0x4f, 0x4c, 0x4a, 0x55, 0x51,
+ 0x4d, 0x4c, 0x45, 0x4e, 0x46, 0x45, 0x44, 0x49, 0x4e, 0x44, 0x40, 0x48,
+ 0x49, 0x44, 0x53, 0x51, 0x42, 0x41, 0x51, 0x49, 0x51, 0x45, 0x51, 0x3f,
+ 0x4b, 0x3f, 0x52, 0x3c, 0x50, 0x4d, 0x4f, 0x4b, 0x44, 0x4f, 0x40, 0x52,
+ 0x49, 0x4a, 0x50, 0x3f, 0x3d, 0x54, 0x4c, 0x53, 0x52, 0x45, 0x41, 0x43,
+ 0x47, 0x2d, 0x40, 0x63, 0x3a, 0x51, 0x43, 0x4e, 0x40, 0x2b, 0x36, 0x5b,
+ 0x4b, 0x12, 0x4d, 0x35, 0x4b, 0x3f, 0x44, 0x4a, 0x46, 0x31, 0x54, 0x48,
+ 0x43, 0x42, 0x3d, 0x51, 0x41, 0x45, 0x49, 0x4b, 0x47, 0x49, 0x3d, 0x3e,
+ 0x46, 0x3d, 0x4d, 0x48, 0x3d, 0x45, 0x48, 0x4b, 0x49, 0x52, 0x44, 0x4c,
+ 0x45, 0x44, 0x45, 0x49, 0x50, 0x48, 0x45, 0x46, 0x45, 0x44, 0x52, 0x55,
+ 0x46, 0x45, 0x4b, 0x3d, 0x42, 0x4a, 0x3e, 0x57, 0x48, 0x4b, 0x3c, 0x42,
+ 0x4a, 0x46, 0x47, 0x6c, 0x54, 0x4b, 0x41, 0x49, 0x49, 0x50, 0x43, 0x56,
+ 0x44, 0x43, 0x4d, 0x3e, 0x44, 0x41, 0x47, 0x40, 0x4a, 0x4b, 0x4d, 0x4d,
+ 0x3e, 0x46, 0x45, 0x47, 0x3e, 0x42, 0x4a, 0x45, 0x49, 0x3d, 0x3f, 0x43,
+ 0x40, 0x44, 0x47, 0x4a, 0x45, 0x4d, 0x4b, 0x4c, 0x43, 0x40, 0x3d, 0x3e,
+ 0x4c, 0x4c, 0x42, 0x4d, 0x48, 0x4d, 0x49, 0x42, 0x51, 0x51, 0x4c, 0x4b,
+ 0x53, 0x4f, 0x48, 0x4d, 0x40, 0x46, 0x45, 0x4b, 0x47, 0x47, 0x4b, 0x46,
+ 0x54, 0x42, 0x42, 0x46, 0x46, 0x4a, 0x4c, 0x55, 0x3f, 0x3c, 0x52, 0x4b,
+ 0x4b, 0x4d, 0x4e, 0x48, 0x53, 0x4c, 0x4b, 0x42, 0x52, 0x54, 0x50, 0x4b,
+ 0x40, 0x5f, 0x58, 0x53, 0x50, 0x42, 0x35, 0x48, 0x39, 0x24, 0x3c, 0x5e,
+ 0x41, 0x50, 0x3c, 0x51, 0x42, 0x26, 0x42, 0x56, 0x41, 0x0c, 0x3e, 0x3d,
+ 0x48, 0x3e, 0x50, 0x4b, 0x3a, 0x2c, 0x43, 0x3d, 0x48, 0x3e, 0x43, 0x48,
+ 0x4c, 0x3f, 0x4a, 0x3e, 0x51, 0x4a, 0x4f, 0x40, 0x47, 0x43, 0x50, 0x4c,
+ 0x43, 0x4d, 0x3f, 0x45, 0x4d, 0x3e, 0x4c, 0x44, 0x51, 0x47, 0x4b, 0x51,
+ 0x45, 0x49, 0x44, 0x3f, 0x46, 0x46, 0x46, 0x57, 0x49, 0x4c, 0x49, 0x4e,
+ 0x47, 0x4c, 0x47, 0x5e, 0x43, 0x46, 0x45, 0x4b, 0x52, 0x49, 0x45, 0x5f,
+ 0x47, 0x41, 0x46, 0x43, 0x4f, 0x3b, 0x43, 0x51, 0x46, 0x53, 0x4a, 0x4e,
+ 0x4b, 0x43, 0x4e, 0x40, 0x48, 0x49, 0x46, 0x3f, 0x48, 0x50, 0x4b, 0x41,
+ 0x4a, 0x47, 0x4b, 0x3d, 0x46, 0x49, 0x4b, 0x43, 0x43, 0x42, 0x3e, 0x47,
+ 0x47, 0x4a, 0x45, 0x46, 0x51, 0x48, 0x51, 0x4e, 0x3f, 0x50, 0x44, 0x4b,
+ 0x4d, 0x4e, 0x44, 0x4d, 0x3d, 0x49, 0x4a, 0x4e, 0x42, 0x51, 0x43, 0x42,
+ 0x46, 0x3e, 0x48, 0x4b, 0x4f, 0x50, 0x3d, 0x48, 0x4c, 0x4f, 0x46, 0x44,
+ 0x44, 0x48, 0x42, 0x4b, 0x48, 0x41, 0x43, 0x46, 0x4d, 0x49, 0x4f, 0x43,
+ 0x41, 0x44, 0x3f, 0x3d, 0x45, 0x4f, 0x45, 0x41, 0x40, 0x58, 0x4f, 0x54,
+ 0x5b, 0x4b, 0x3a, 0x47, 0x3d, 0x28, 0x3d, 0x57, 0x3e, 0x51, 0x3f, 0x47,
+ 0x3f, 0x2e, 0x3e, 0x54, 0x4e, 0x0b, 0x41, 0x3d, 0x3b, 0x3d, 0x43, 0x47,
+ 0x47, 0x28, 0x4d, 0x43, 0x43, 0x3b, 0x4e, 0x4a, 0x4d, 0x42, 0x51, 0x46,
+ 0x4f, 0x3d, 0x4c, 0x3a, 0x49, 0x49, 0x4a, 0x43, 0x42, 0x4b, 0x47, 0x42,
+ 0x42, 0x49, 0x3f, 0x4d, 0x46, 0x4a, 0x49, 0x4e, 0x42, 0x3c, 0x4a, 0x41,
+ 0x4c, 0x40, 0x4d, 0x5a, 0x49, 0x46, 0x51, 0x46, 0x4b, 0x4c, 0x46, 0x62,
+ 0x45, 0x42, 0x51, 0x4e, 0x4d, 0x3e, 0x4d, 0x5b, 0x4d, 0x43, 0x45, 0x50,
+ 0x4b, 0x40, 0x50, 0x53, 0x4f, 0x4f, 0x51, 0x53, 0x46, 0x41, 0x4e, 0x3a,
+ 0x4b, 0x47, 0x3f, 0x3e, 0x4d, 0x48, 0x53, 0x3f, 0x45, 0x42, 0x4c, 0x45,
+ 0x55, 0x4c, 0x4b, 0x39, 0x4a, 0x45, 0x48, 0x4d, 0x47, 0x40, 0x48, 0x4f,
+ 0x4d, 0x49, 0x3e, 0x41, 0x46, 0x4e, 0x40, 0x49, 0x4b, 0x47, 0x4c, 0x45,
+ 0x44, 0x51, 0x4f, 0x4b, 0x48, 0x49, 0x44, 0x41, 0x43, 0x46, 0x51, 0x45,
+ 0x40, 0x48, 0x4b, 0x42, 0x44, 0x4f, 0x53, 0x4d, 0x44, 0x46, 0x4e, 0x4c,
+ 0x48, 0x50, 0x41, 0x45, 0x42, 0x48, 0x4d, 0x4d, 0x47, 0x45, 0x41, 0x45,
+ 0x48, 0x58, 0x4e, 0x46, 0x43, 0x53, 0x57, 0x52, 0x5e, 0x42, 0x45, 0x4e,
+ 0x39, 0x24, 0x32, 0x56, 0x47, 0x56, 0x49, 0x52, 0x46, 0x26, 0x3a, 0x51,
+ 0x4b, 0x05, 0x3e, 0x43, 0x3f, 0x38, 0x4d, 0x4b, 0x4f, 0x27, 0x51, 0x46,
+ 0x47, 0x41, 0x4a, 0x47, 0x4a, 0x3e, 0x44, 0x51, 0x3f, 0x3a, 0x43, 0x46,
+ 0x4d, 0x49, 0x46, 0x52, 0x43, 0x48, 0x49, 0x3e, 0x47, 0x46, 0x4a, 0x4d,
+ 0x47, 0x46, 0x52, 0x50, 0x44, 0x48, 0x4c, 0x47, 0x45, 0x41, 0x49, 0x5b,
+ 0x4d, 0x4b, 0x47, 0x4c, 0x4a, 0x47, 0x45, 0x5b, 0x49, 0x46, 0x52, 0x47,
+ 0x47, 0x3d, 0x55, 0x59, 0x40, 0x4b, 0x3e, 0x50, 0x42, 0x43, 0x40, 0x4f,
+ 0x48, 0x3f, 0x47, 0x53, 0x4d, 0x44, 0x4e, 0x37, 0x4c, 0x43, 0x51, 0x4d,
+ 0x46, 0x4e, 0x40, 0x41, 0x52, 0x44, 0x43, 0x4a, 0x50, 0x48, 0x47, 0x42,
+ 0x48, 0x45, 0x50, 0x4d, 0x42, 0x52, 0x44, 0x43, 0x45, 0x43, 0x4c, 0x4d,
+ 0x44, 0x51, 0x47, 0x48, 0x51, 0x4f, 0x48, 0x45, 0x49, 0x4a, 0x3e, 0x43,
+ 0x4d, 0x4e, 0x4e, 0x46, 0x54, 0x4d, 0x49, 0x4d, 0x47, 0x46, 0x4b, 0x41,
+ 0x4a, 0x49, 0x44, 0x45, 0x4d, 0x3e, 0x53, 0x50, 0x47, 0x4d, 0x4e, 0x43,
+ 0x4f, 0x45, 0x4e, 0x4a, 0x47, 0x49, 0x4c, 0x4c, 0x4d, 0x54, 0x42, 0x4c,
+ 0x43, 0x5d, 0x59, 0x50, 0x5e, 0x4b, 0x44, 0x43, 0x3c, 0x25, 0x31, 0x5b,
+ 0x46, 0x5a, 0x50, 0x4d, 0x41, 0x2a, 0x41, 0x4f, 0x44, 0x00, 0x41, 0x3d,
+ 0x43, 0x4b, 0x47, 0x45, 0x4e, 0x2e, 0x44, 0x46, 0x53, 0x3d, 0x43, 0x41,
+ 0x44, 0x46, 0x49, 0x42, 0x45, 0x4f, 0x4d, 0x3a, 0x43, 0x3c, 0x47, 0x53,
+ 0x43, 0x4e, 0x3f, 0x41, 0x4d, 0x50, 0x4b, 0x4c, 0x51, 0x47, 0x53, 0x4f,
+ 0x45, 0x4a, 0x44, 0x45, 0x41, 0x46, 0x47, 0x50, 0x51, 0x3f, 0x3e, 0x41,
+ 0x48, 0x45, 0x46, 0x5d, 0x45, 0x4a, 0x4c, 0x46, 0x4a, 0x49, 0x50, 0x51,
+ 0x51, 0x4c, 0x4f, 0x47, 0x47, 0x42, 0x45, 0x47, 0x4e, 0x48, 0x46, 0x40,
+ 0x45, 0x46, 0x4d, 0x3b, 0x4d, 0x52, 0x4c, 0x51, 0x49, 0x51, 0x47, 0x3d,
+ 0x4d, 0x42, 0x4f, 0x4e, 0x43, 0x43, 0x45, 0x3a, 0x42, 0x50, 0x4c, 0x4a,
+ 0x41, 0x53, 0x4c, 0x45, 0x51, 0x3f, 0x54, 0x43, 0x4b, 0x54, 0x56, 0x4d,
+ 0x4f, 0x4a, 0x50, 0x4b, 0x44, 0x45, 0x4f, 0x4f, 0x47, 0x3e, 0x50, 0x4f,
+ 0x4b, 0x48, 0x4d, 0x49, 0x55, 0x4d, 0x45, 0x4d, 0x4a, 0x53, 0x43, 0x46,
+ 0x4c, 0x45, 0x41, 0x46, 0x49, 0x49, 0x4f, 0x4b, 0x49, 0x50, 0x52, 0x49,
+ 0x41, 0x54, 0x44, 0x4c, 0x44, 0x63, 0x4a, 0x49, 0x40, 0x59, 0x52, 0x52,
+ 0x59, 0x3f, 0x3e, 0x3e, 0x40, 0x25, 0x3c, 0x5c, 0x4f, 0x57, 0x44, 0x50,
+ 0x41, 0x2a, 0x48, 0x4f, 0x43, 0x08, 0x47, 0x43, 0x49, 0x48, 0x4d, 0x49,
+ 0x46, 0x2b, 0x48, 0x44, 0x4e, 0x47, 0x47, 0x43, 0x44, 0x3e, 0x4a, 0x52,
+ 0x3f, 0x4a, 0x53, 0x42, 0x49, 0x47, 0x4c, 0x50, 0x43, 0x46, 0x46, 0x3c,
+ 0x4c, 0x47, 0x4e, 0x4d, 0x42, 0x41, 0x53, 0x52, 0x4f, 0x40, 0x54, 0x50,
+ 0x46, 0x43, 0x50, 0x56, 0x51, 0x48, 0x48, 0x48, 0x49, 0x39, 0x47, 0x5e,
+ 0x4e, 0x4b, 0x4f, 0x4e, 0x43, 0x45, 0x42, 0x58, 0x4a, 0x3b, 0x48, 0x4d,
+ 0x43, 0x3e, 0x4b, 0x43, 0x3c, 0x45, 0x46, 0x4b, 0x42, 0x42, 0x4e, 0x3d,
+ 0x4b, 0x4e, 0x51, 0x52, 0x48, 0x3e, 0x4b, 0x3f, 0x4c, 0x4a, 0x4b, 0x4c,
+ 0x46, 0x48, 0x3e, 0x48, 0x47, 0x4d, 0x4a, 0x46, 0x49, 0x4d, 0x4a, 0x48,
+ 0x50, 0x4b, 0x40, 0x48, 0x4b, 0x52, 0x46, 0x50, 0x4f, 0x3e, 0x42, 0x44,
+ 0x44, 0x42, 0x43, 0x49, 0x4f, 0x4f, 0x46, 0x42, 0x4a, 0x54, 0x42, 0x48,
+ 0x50, 0x4f, 0x4f, 0x4c, 0x4c, 0x47, 0x52, 0x49, 0x4c, 0x45, 0x4a, 0x4d,
+ 0x4a, 0x41, 0x47, 0x4a, 0x4d, 0x4a, 0x4c, 0x46, 0x51, 0x44, 0x4b, 0x49,
+ 0x53, 0x5e, 0x45, 0x4a, 0x3b, 0x57, 0x5a, 0x4c, 0x59, 0x43, 0x3e, 0x4a,
+ 0x3e, 0x20, 0x36, 0x5d, 0x47, 0x5b, 0x3f, 0x55, 0x3e, 0x24, 0x41, 0x52,
+ 0x3f, 0x01, 0x49, 0x41, 0x40, 0x45, 0x42, 0x46, 0x49, 0x2a, 0x47, 0x40,
+ 0x44, 0x3f, 0x42, 0x47, 0x4e, 0x42, 0x4b, 0x3d, 0x45, 0x4c, 0x47, 0x3d,
+ 0x4c, 0x44, 0x48, 0x43, 0x43, 0x41, 0x4a, 0x3d, 0x48, 0x4b, 0x46, 0x4e,
+ 0x4c, 0x45, 0x48, 0x4d, 0x54, 0x4d, 0x3e, 0x46, 0x3e, 0x47, 0x44, 0x4e,
+ 0x48, 0x49, 0x53, 0x4b, 0x41, 0x45, 0x4c, 0x57, 0x52, 0x4e, 0x40, 0x48,
+ 0x4d, 0x43, 0x44, 0x5a, 0x4a, 0x4c, 0x48, 0x4d, 0x3f, 0x52, 0x41, 0x50,
+ 0x4a, 0x47, 0x3e, 0x43, 0x4c, 0x42, 0x48, 0x3e, 0x4f, 0x4b, 0x41, 0x43,
+ 0x49, 0x40, 0x43, 0x36, 0x3f, 0x4b, 0x49, 0x49, 0x51, 0x43, 0x48, 0x40,
+ 0x4c, 0x51, 0x4d, 0x4a, 0x49, 0x3f, 0x4b, 0x3d, 0x4f, 0x4b, 0x43, 0x4d,
+ 0x46, 0x40, 0x46, 0x4d, 0x49, 0x48, 0x4d, 0x4c, 0x52, 0x4c, 0x49, 0x4f,
+ 0x53, 0x40, 0x49, 0x53, 0x47, 0x43, 0x4c, 0x45, 0x42, 0x48, 0x42, 0x4e,
+ 0x49, 0x43, 0x42, 0x40, 0x4f, 0x46, 0x50, 0x47, 0x51, 0x4a, 0x52, 0x45,
+ 0x4c, 0x51, 0x48, 0x47, 0x40, 0x41, 0x52, 0x4f, 0x41, 0x5a, 0x53, 0x47,
+ 0x42, 0x5f, 0x55, 0x4f, 0x53, 0x3e, 0x41, 0x49, 0x3d, 0x20, 0x3f, 0x54,
+ 0x42, 0x5b, 0x49, 0x4d, 0x3d, 0x22, 0x3e, 0x48, 0x41, 0x01, 0x4c, 0x3d,
+ 0x43, 0x4a, 0x46, 0x43, 0x4f, 0x2b, 0x49, 0x46, 0x47, 0x4a, 0x51, 0x3d,
+ 0x4b, 0x44, 0x49, 0x41, 0x47, 0x47, 0x45, 0x3a, 0x44, 0x42, 0x40, 0x52,
+ 0x46, 0x51, 0x4a, 0x41, 0x4a, 0x52, 0x44, 0x52, 0x4a, 0x40, 0x46, 0x45,
+ 0x52, 0x4c, 0x4e, 0x42, 0x42, 0x48, 0x40, 0x4f, 0x4b, 0x4f, 0x51, 0x4c,
+ 0x4e, 0x48, 0x4a, 0x5a, 0x46, 0x3d, 0x41, 0x50, 0x52, 0x4c, 0x44, 0x53,
+ 0x4b, 0x4d, 0x4f, 0x49, 0x47, 0x4c, 0x48, 0x45, 0x48, 0x4a, 0x44, 0x4e,
+ 0x4c, 0x40, 0x4d, 0x35, 0x40, 0x49, 0x4a, 0x51, 0x49, 0x4a, 0x46, 0x36,
+ 0x46, 0x47, 0x4a, 0x4c, 0x40, 0x4e, 0x42, 0x38, 0x48, 0x45, 0x42, 0x49,
+ 0x54, 0x4c, 0x3f, 0x49, 0x4c, 0x39, 0x47, 0x45, 0x4e, 0x4a, 0x42, 0x44,
+ 0x4b, 0x53, 0x43, 0x40, 0x46, 0x51, 0x3d, 0x50, 0x4b, 0x43, 0x4a, 0x4c,
+ 0x55, 0x54, 0x4a, 0x43, 0x48, 0x40, 0x44, 0x3f, 0x47, 0x45, 0x3e, 0x41,
+ 0x49, 0x44, 0x4d, 0x49, 0x44, 0x41, 0x4a, 0x50, 0x44, 0x49, 0x4d, 0x47,
+ 0x4a, 0x49, 0x46, 0x49, 0x40, 0x5b, 0x4d, 0x51, 0x47, 0x57, 0x49, 0x4f,
+ 0x56, 0x46, 0x3a, 0x4a, 0x3e, 0x22, 0x36, 0x5c, 0x44, 0x56, 0x46, 0x48,
+ 0x3a, 0x2d, 0x4a, 0x48, 0x44, 0x17, 0x41, 0x42, 0x40, 0x3d, 0x4e, 0x45,
+ 0x40, 0x26, 0x43, 0x52, 0x41, 0x40, 0x44, 0x4a, 0x48, 0x42, 0x4f, 0x47,
+ 0x46, 0x4c, 0x4a, 0x3b, 0x42, 0x3e, 0x3e, 0x49, 0x4e, 0x44, 0x4e, 0x49,
+ 0x47, 0x41, 0x47, 0x44, 0x4c, 0x45, 0x4d, 0x49, 0x49, 0x48, 0x55, 0x3d,
+ 0x4a, 0x45, 0x50, 0x4f, 0x46, 0x4c, 0x46, 0x45, 0x3c, 0x51, 0x4b, 0x5a,
+ 0x46, 0x47, 0x54, 0x41, 0x44, 0x40, 0x4f, 0x53, 0x49, 0x46, 0x46, 0x48,
+ 0x44, 0x40, 0x50, 0x49, 0x49, 0x43, 0x50, 0x41, 0x52, 0x4b, 0x46, 0x3e,
+ 0x44, 0x44, 0x46, 0x4e, 0x47, 0x48, 0x3e, 0x38, 0x4c, 0x4c, 0x48, 0x43,
+ 0x48, 0x3e, 0x50, 0x42, 0x51, 0x50, 0x4a, 0x48, 0x4a, 0x42, 0x44, 0x3d,
+ 0x4a, 0x46, 0x46, 0x3d, 0x4e, 0x47, 0x3d, 0x48, 0x4c, 0x46, 0x50, 0x4d,
+ 0x49, 0x45, 0x4a, 0x4c, 0x4c, 0x47, 0x4a, 0x42, 0x4a, 0x45, 0x50, 0x52,
+ 0x4b, 0x4d, 0x4c, 0x43, 0x42, 0x53, 0x41, 0x45, 0x49, 0x41, 0x4b, 0x4c,
+ 0x52, 0x54, 0x4b, 0x41, 0x48, 0x4c, 0x47, 0x4c, 0x41, 0x49, 0x4a, 0x47,
+ 0x50, 0x59, 0x4e, 0x45, 0x3c, 0x5d, 0x53, 0x4c, 0x5a, 0x3e, 0x3a, 0x51,
+ 0x3a, 0x22, 0x35, 0x59, 0x40, 0x5a, 0x43, 0x46, 0x41, 0x32, 0x44, 0x4b,
+ 0x47, 0x04, 0x4c, 0x3a, 0x4a, 0x49, 0x48, 0x3d, 0x45, 0x2b, 0x50, 0x41,
+ 0x3e, 0x44, 0x4f, 0x43, 0x4a, 0x3f, 0x48, 0x4b, 0x53, 0x49, 0x4b, 0x38,
+ 0x44, 0x40, 0x48, 0x4c, 0x41, 0x3f, 0x47, 0x3e, 0x47, 0x49, 0x45, 0x42,
+ 0x43, 0x3e, 0x46, 0x44, 0x53, 0x4d, 0x48, 0x44, 0x45, 0x42, 0x43, 0x53,
+ 0x55, 0x49, 0x4d, 0x4b, 0x45, 0x44, 0x47, 0x5f, 0x48, 0x44, 0x4a, 0x48,
+ 0x45, 0x4d, 0x4f, 0x5e, 0x4e, 0x46, 0x49, 0x49, 0x4d, 0x49, 0x44, 0x48,
+ 0x4d, 0x41, 0x50, 0x48, 0x3d, 0x3f, 0x4d, 0x38, 0x46, 0x4a, 0x50, 0x4a,
+ 0x45, 0x3e, 0x43, 0x36, 0x42, 0x48, 0x53, 0x54, 0x49, 0x43, 0x4b, 0x3a,
+ 0x45, 0x48, 0x50, 0x45, 0x4a, 0x4c, 0x4a, 0x4d, 0x43, 0x4c, 0x55, 0x4e,
+ 0x4c, 0x42, 0x45, 0x52, 0x52, 0x45, 0x46, 0x40, 0x54, 0x4c, 0x3d, 0x4e,
+ 0x49, 0x4e, 0x44, 0x47, 0x45, 0x48, 0x4b, 0x50, 0x49, 0x4b, 0x44, 0x4b,
+ 0x4f, 0x49, 0x47, 0x47, 0x53, 0x3f, 0x4b, 0x42, 0x45, 0x3e, 0x4d, 0x4d,
+ 0x48, 0x51, 0x45, 0x40, 0x43, 0x43, 0x4e, 0x44, 0x51, 0x55, 0x4a, 0x3e,
+ 0x45, 0x55, 0x58, 0x50, 0x50, 0x38, 0x44, 0x4f, 0x3b, 0x23, 0x3c, 0x55,
+ 0x3c, 0x54, 0x49, 0x42, 0x44, 0x2f, 0x3e, 0x47, 0x42, 0x01, 0x42, 0x37,
+ 0x3f, 0x42, 0x45, 0x45, 0x47, 0x2a, 0x52, 0x4b, 0x45, 0x3c, 0x47, 0x44,
+ 0x44, 0x40, 0x50, 0x53, 0x48, 0x42, 0x4d, 0x36, 0x50, 0x3d, 0x49, 0x44,
+ 0x4f, 0x4c, 0x4a, 0x42, 0x4d, 0x3e, 0x3d, 0x3f, 0x4e, 0x44, 0x4d, 0x4e,
+ 0x54, 0x3d, 0x42, 0x46, 0x49, 0x47, 0x4b, 0x53, 0x45, 0x46, 0x47, 0x4a,
+ 0x45, 0x3d, 0x4a, 0x5f, 0x51, 0x3e, 0x45, 0x45, 0x44, 0x3a, 0x4d, 0x57,
+ 0x45, 0x47, 0x4d, 0x45, 0x4e, 0x4b, 0x51, 0x48, 0x4b, 0x4a, 0x3c, 0x4e,
+ 0x51, 0x41, 0x4d, 0x36, 0x47, 0x4a, 0x46, 0x51, 0x4e, 0x4c, 0x52, 0x41,
+ 0x55, 0x47, 0x41, 0x47, 0x4d, 0x47, 0x4b, 0x3d, 0x4a, 0x4a, 0x46, 0x49,
+ 0x4d, 0x48, 0x46, 0x46, 0x4d, 0x52, 0x52, 0x48, 0x49, 0x3f, 0x4b, 0x4e,
+ 0x4c, 0x49, 0x45, 0x47, 0x41, 0x4b, 0x44, 0x48, 0x52, 0x4b, 0x53, 0x44,
+ 0x46, 0x4e, 0x44, 0x49, 0x52, 0x50, 0x46, 0x4b, 0x44, 0x43, 0x50, 0x49,
+ 0x4a, 0x53, 0x45, 0x49, 0x52, 0x3f, 0x4a, 0x4e, 0x49, 0x4c, 0x4d, 0x4d,
+ 0x40, 0x40, 0x3f, 0x4a, 0x47, 0x56, 0x51, 0x43, 0x40, 0x5a, 0x58, 0x52,
+ 0x4f, 0x3d, 0x3d, 0x45, 0x38, 0x29, 0x33, 0x59, 0x45, 0x54, 0x3c, 0x42,
+ 0x3f, 0x27, 0x3e, 0x49, 0x48, 0x06, 0x4a, 0x3f, 0x41, 0x49, 0x4c, 0x48,
+ 0x46, 0x2b, 0x4a, 0x4f, 0x44, 0x46, 0x4c, 0x46, 0x4a, 0x3b, 0x4d, 0x4a,
+ 0x40, 0x41, 0x45, 0x38, 0x51, 0x39, 0x46, 0x46, 0x41, 0x51, 0x4e, 0x41,
+ 0x49, 0x44, 0x48, 0x4a, 0x4b, 0x46, 0x47, 0x46, 0x4a, 0x4c, 0x47, 0x48,
+ 0x3d, 0x42, 0x50, 0x4f, 0x50, 0x4a, 0x4a, 0x48, 0x4a, 0x45, 0x45, 0x61,
+ 0x4a, 0x4c, 0x49, 0x3d, 0x4b, 0x4a, 0x4a, 0x5a, 0x48, 0x49, 0x50, 0x4f,
+ 0x42, 0x48, 0x3e, 0x44, 0x43, 0x3b, 0x4f, 0x54, 0x4b, 0x4a, 0x47, 0x31,
+ 0x4a, 0x49, 0x47, 0x4e, 0x48, 0x48, 0x46, 0x42, 0x4a, 0x45, 0x4c, 0x49,
+ 0x4b, 0x4e, 0x53, 0x43, 0x4c, 0x49, 0x4f, 0x4b, 0x46, 0x4c, 0x4b, 0x4e,
+ 0x51, 0x4b, 0x49, 0x52, 0x44, 0x55, 0x45, 0x49, 0x4b, 0x4a, 0x50, 0x4c,
+ 0x4d, 0x4a, 0x4b, 0x48, 0x41, 0x46, 0x47, 0x43, 0x4b, 0x3f, 0x54, 0x4a,
+ 0x46, 0x49, 0x51, 0x48, 0x4e, 0x4a, 0x41, 0x52, 0x52, 0x4e, 0x53, 0x47,
+ 0x42, 0x48, 0x43, 0x44, 0x54, 0x51, 0x40, 0x49, 0x4c, 0x48, 0x49, 0x44,
+ 0x4c, 0x56, 0x52, 0x49, 0x3d, 0x59, 0x4f, 0x56, 0x56, 0x42, 0x46, 0x45,
+ 0x3e, 0x28, 0x3f, 0x5b, 0x3f, 0x5a, 0x4c, 0x42, 0x44, 0x22, 0x3f, 0x46,
+ 0x47, 0x0d, 0x3e, 0x41, 0x45, 0x49, 0x4a, 0x3b, 0x45, 0x2d, 0x4d, 0x4a,
+ 0x44, 0x43, 0x49, 0x46, 0x4b, 0x47, 0x49, 0x45, 0x4e, 0x40, 0x4c, 0x3c,
+ 0x42, 0x3e, 0x4b, 0x50, 0x48, 0x49, 0x4c, 0x42, 0x3c, 0x43, 0x50, 0x43,
+ 0x49, 0x4e, 0x4e, 0x43, 0x46, 0x4c, 0x48, 0x4a, 0x43, 0x4c, 0x49, 0x4e,
+ 0x47, 0x44, 0x50, 0x4c, 0x4a, 0x48, 0x47, 0x5f, 0x3f, 0x3e, 0x48, 0x4f,
+ 0x4f, 0x49, 0x4a, 0x5f, 0x4e, 0x40, 0x4e, 0x48, 0x47, 0x44, 0x40, 0x4d,
+ 0x3f, 0x4a, 0x53, 0x45, 0x3e, 0x50, 0x3f, 0x39, 0x50, 0x45, 0x45, 0x4b,
+ 0x43, 0x41, 0x46, 0x41, 0x49, 0x47, 0x4b, 0x41, 0x3c, 0x4b, 0x46, 0x3f,
+ 0x41, 0x4a, 0x4e, 0x4c, 0x49, 0x4c, 0x3f, 0x44, 0x53, 0x4c, 0x45, 0x49,
+ 0x48, 0x4d, 0x48, 0x4a, 0x48, 0x4f, 0x45, 0x4d, 0x48, 0x4c, 0x41, 0x49,
+ 0x42, 0x48, 0x53, 0x46, 0x4a, 0x46, 0x4b, 0x4f, 0x4c, 0x52, 0x4c, 0x51,
+ 0x41, 0x4d, 0x49, 0x41, 0x49, 0x4f, 0x49, 0x42, 0x4a, 0x48, 0x51, 0x4a,
+ 0x44, 0x4d, 0x55, 0x48, 0x47, 0x4d, 0x4d, 0x45, 0x42, 0x60, 0x4a, 0x51,
+ 0x42, 0x54, 0x56, 0x56, 0x50, 0x4a, 0x3f, 0x4a, 0x40, 0x25, 0x3a, 0x59,
+ 0x46, 0x58, 0x52, 0x46, 0x41, 0x28, 0x3d, 0x3e, 0x45, 0x13, 0x47, 0x41,
+ 0x3d, 0x44, 0x48, 0x45, 0x49, 0x26, 0x46, 0x4c, 0x3b, 0x4a, 0x42, 0x47,
+ 0x46, 0x41, 0x44, 0x52, 0x50, 0x4a, 0x4f, 0x40, 0x4b, 0x39, 0x42, 0x45,
+ 0x4a, 0x4d, 0x4f, 0x3f, 0x42, 0x4f, 0x49, 0x45, 0x42, 0x4a, 0x46, 0x47,
+ 0x48, 0x40, 0x4a, 0x46, 0x41, 0x3b, 0x48, 0x55, 0x4b, 0x4e, 0x4e, 0x48,
+ 0x4b, 0x44, 0x46, 0x53, 0x48, 0x45, 0x4b, 0x53, 0x49, 0x43, 0x4a, 0x5c,
+ 0x46, 0x45, 0x45, 0x49, 0x49, 0x49, 0x4c, 0x43, 0x4e, 0x4a, 0x41, 0x4a,
+ 0x42, 0x43, 0x4a, 0x38, 0x44, 0x4a, 0x4b, 0x3f, 0x45, 0x49, 0x45, 0x38,
+ 0x43, 0x40, 0x45, 0x4c, 0x47, 0x42, 0x3f, 0x42, 0x3e, 0x4a, 0x43, 0x50,
+ 0x4a, 0x4e, 0x4f, 0x47, 0x4d, 0x49, 0x49, 0x47, 0x4a, 0x4d, 0x46, 0x4c,
+ 0x4f, 0x3d, 0x52, 0x4a, 0x41, 0x44, 0x4b, 0x50, 0x4c, 0x52, 0x49, 0x50,
+ 0x4b, 0x45, 0x49, 0x4d, 0x48, 0x55, 0x50, 0x47, 0x4e, 0x50, 0x4f, 0x48,
+ 0x46, 0x4d, 0x4d, 0x41, 0x48, 0x51, 0x4b, 0x4c, 0x47, 0x51, 0x42, 0x42,
+ 0x4d, 0x47, 0x43, 0x4c, 0x4c, 0x5a, 0x4e, 0x47, 0x3b, 0x59, 0x51, 0x57,
+ 0x4c, 0x40, 0x46, 0x4c, 0x37, 0x2a, 0x35, 0x58, 0x44, 0x5b, 0x4c, 0x44,
+ 0x3e, 0x2e, 0x3f, 0x43, 0x46, 0x23, 0x49, 0x3e, 0x41, 0x3f, 0x4b, 0x3e,
+ 0x4e, 0x2f, 0x4d, 0x4a, 0x4e, 0x40, 0x4e, 0x41, 0x40, 0x3f, 0x4a, 0x42,
+ 0x4d, 0x4c, 0x44, 0x47, 0x4e, 0x44, 0x40, 0x43, 0x4d, 0x49, 0x4f, 0x3d,
+ 0x49, 0x3f, 0x51, 0x48, 0x42, 0x4a, 0x49, 0x47, 0x49, 0x46, 0x4a, 0x45,
+ 0x45, 0x49, 0x53, 0x4d, 0x4c, 0x4e, 0x44, 0x50, 0x4b, 0x43, 0x4e, 0x5f,
+ 0x3c, 0x40, 0x44, 0x46, 0x48, 0x4b, 0x42, 0x62, 0x4e, 0x50, 0x4c, 0x49,
+ 0x4a, 0x4f, 0x44, 0x53, 0x42, 0x43, 0x49, 0x48, 0x4b, 0x3c, 0x4a, 0x37,
+ 0x4c, 0x41, 0x49, 0x46, 0x46, 0x47, 0x43, 0x40, 0x4d, 0x4d, 0x4a, 0x48,
+ 0x50, 0x4b, 0x50, 0x41, 0x44, 0x3e, 0x51, 0x47, 0x44, 0x4a, 0x44, 0x45,
+ 0x48, 0x4d, 0x52, 0x4e, 0x44, 0x48, 0x4d, 0x43, 0x42, 0x45, 0x48, 0x52,
+ 0x44, 0x42, 0x50, 0x42, 0x4d, 0x45, 0x48, 0x4d, 0x4f, 0x4e, 0x45, 0x49,
+ 0x51, 0x48, 0x4f, 0x53, 0x4d, 0x4c, 0x48, 0x50, 0x4e, 0x4d, 0x50, 0x48,
+ 0x49, 0x42, 0x4c, 0x42, 0x4b, 0x4b, 0x49, 0x48, 0x48, 0x49, 0x4a, 0x54,
+ 0x44, 0x57, 0x4d, 0x4b, 0x3f, 0x56, 0x53, 0x5c, 0x50, 0x4e, 0x46, 0x49,
+ 0x40, 0x24, 0x44, 0x58, 0x49, 0x54, 0x48, 0x49, 0x41, 0x22, 0x44, 0x3f,
+ 0x48, 0x1c, 0x4d, 0x39, 0x3e, 0x4c, 0x3d, 0x4a, 0x48, 0x2d, 0x48, 0x3e,
+ 0x3f, 0x3a, 0x46, 0x4e, 0x44, 0x43, 0x49, 0x51, 0x4d, 0x3c, 0x44, 0x41,
+ 0x4e, 0x44, 0x42, 0x4c, 0x45, 0x48, 0x45, 0x46, 0x42, 0x46, 0x47, 0x42,
+ 0x4f, 0x45, 0x47, 0x44, 0x48, 0x47, 0x4a, 0x42, 0x4d, 0x48, 0x3e, 0x53,
+ 0x47, 0x4b, 0x44, 0x4b, 0x45, 0x4a, 0x50, 0x55, 0x4c, 0x45, 0x48, 0x43,
+ 0x53, 0x3d, 0x4e, 0x5f, 0x42, 0x44, 0x4a, 0x4f, 0x3f, 0x48, 0x4e, 0x4b,
+ 0x43, 0x48, 0x43, 0x41, 0x4a, 0x4b, 0x51, 0x39, 0x52, 0x46, 0x44, 0x49,
+ 0x48, 0x45, 0x4c, 0x40, 0x45, 0x49, 0x51, 0x48, 0x45, 0x42, 0x45, 0x48,
+ 0x40, 0x43, 0x3d, 0x47, 0x53, 0x54, 0x4d, 0x4a, 0x4a, 0x47, 0x48, 0x43,
+ 0x4c, 0x46, 0x43, 0x4f, 0x49, 0x4c, 0x3f, 0x3d, 0x4b, 0x41, 0x40, 0x48,
+ 0x4e, 0x4c, 0x4b, 0x40, 0x4c, 0x43, 0x49, 0x4d, 0x47, 0x4f, 0x47, 0x42,
+ 0x47, 0x4a, 0x4d, 0x4f, 0x46, 0x4d, 0x51, 0x49, 0x48, 0x4d, 0x4e, 0x46,
+ 0x47, 0x41, 0x44, 0x4d, 0x4b, 0x55, 0x4b, 0x4c, 0x41, 0x5e, 0x50, 0x45,
+ 0x40, 0x55, 0x4b, 0x60, 0x55, 0x47, 0x3d, 0x4a, 0x42, 0x22, 0x46, 0x5a,
+ 0x47, 0x53, 0x49, 0x44, 0x44, 0x27, 0x41, 0x4f, 0x3e, 0x22, 0x4a, 0x44,
+ 0x49, 0x3e, 0x4e, 0x4d, 0x3f, 0x3a, 0x4c, 0x44, 0x4a, 0x44, 0x46, 0x51,
+ 0x4f, 0x42, 0x4c, 0x4e, 0x39, 0x4b, 0x42, 0x39, 0x4b, 0x3e, 0x4f, 0x47,
+ 0x4a, 0x4f, 0x3f, 0x4d, 0x43, 0x4c, 0x4a, 0x4b, 0x4b, 0x3d, 0x51, 0x46,
+ 0x49, 0x4c, 0x47, 0x44, 0x43, 0x3d, 0x3c, 0x54, 0x4a, 0x47, 0x4d, 0x50,
+ 0x4a, 0x46, 0x51, 0x62, 0x46, 0x4d, 0x4b, 0x46, 0x49, 0x3c, 0x50, 0x57,
+ 0x47, 0x40, 0x3e, 0x4c, 0x4b, 0x3f, 0x55, 0x46, 0x3d, 0x45, 0x42, 0x4e,
+ 0x50, 0x49, 0x46, 0x3a, 0x4c, 0x47, 0x4a, 0x49, 0x42, 0x42, 0x4a, 0x44,
+ 0x42, 0x40, 0x49, 0x54, 0x46, 0x4b, 0x47, 0x45, 0x51, 0x47, 0x41, 0x42,
+ 0x49, 0x50, 0x4e, 0x48, 0x4b, 0x4b, 0x47, 0x4a, 0x47, 0x49, 0x4b, 0x45,
+ 0x4b, 0x54, 0x48, 0x54, 0x4b, 0x49, 0x51, 0x4a, 0x4a, 0x40, 0x46, 0x42,
+ 0x44, 0x44, 0x4d, 0x4b, 0x47, 0x43, 0x45, 0x41, 0x3e, 0x49, 0x43, 0x51,
+ 0x3e, 0x4b, 0x52, 0x46, 0x48, 0x3f, 0x4e, 0x51, 0x51, 0x49, 0x3f, 0x48,
+ 0x4c, 0x4c, 0x52, 0x47, 0x43, 0x57, 0x44, 0x42, 0x40, 0x52, 0x50, 0x5d,
+ 0x4f, 0x40, 0x42, 0x45, 0x46, 0x26, 0x3c, 0x51, 0x4b, 0x4e, 0x4b, 0x49,
+ 0x46, 0x35, 0x49, 0x53, 0x49, 0x2b, 0x4d, 0x3e, 0x50, 0x44, 0x4f, 0x54,
+ 0x46, 0x34, 0x49, 0x4d, 0x42, 0x45, 0x44, 0x4b, 0x52, 0x44, 0x52, 0x41,
+ 0x4d, 0x4c, 0x52, 0x41, 0x49, 0x3a, 0x4e, 0x49, 0x40, 0x4b, 0x45, 0x4d,
+ 0x4b, 0x4a, 0x47, 0x49, 0x45, 0x49, 0x4d, 0x50, 0x3e, 0x47, 0x44, 0x51,
+ 0x4c, 0x41, 0x45, 0x50, 0x47, 0x41, 0x4a, 0x52, 0x4b, 0x3d, 0x4b, 0x5b,
+ 0x4c, 0x4c, 0x4d, 0x3f, 0x47, 0x44, 0x49, 0x5d, 0x4a, 0x53, 0x44, 0x45,
+ 0x45, 0x46, 0x3d, 0x4f, 0x50, 0x3b, 0x44, 0x4e, 0x40, 0x41, 0x4c, 0x3a,
+ 0x4a, 0x45, 0x49, 0x48, 0x45, 0x4a, 0x45, 0x36, 0x45, 0x4d, 0x4c, 0x49,
+ 0x3f, 0x47, 0x4d, 0x40, 0x53, 0x48, 0x49, 0x4c, 0x47, 0x4f, 0x42, 0x44,
+ 0x45, 0x40, 0x4a, 0x4c, 0x49, 0x4f, 0x4b, 0x4d, 0x42, 0x45, 0x3e, 0x4a,
+ 0x48, 0x4a, 0x49, 0x50, 0x4c, 0x53, 0x50, 0x45, 0x4b, 0x4c, 0x46, 0x4f,
+ 0x44, 0x43, 0x54, 0x50, 0x3f, 0x48, 0x42, 0x4b, 0x43, 0x3f, 0x4d, 0x4c,
+ 0x43, 0x49, 0x4a, 0x47, 0x54, 0x4b, 0x4f, 0x4d, 0x44, 0x47, 0x49, 0x4e,
+ 0x4e, 0x55, 0x40, 0x46, 0x44, 0x56, 0x4e, 0x65, 0x4f, 0x3f, 0x43, 0x48,
+ 0x39, 0x27, 0x43, 0x55, 0x4b, 0x4c, 0x44, 0x46, 0x42, 0x34, 0x44, 0x52,
+ 0x43, 0x22, 0x4e, 0x41, 0x49, 0x48, 0x49, 0x51, 0x3b, 0x37, 0x4b, 0x40,
+ 0x4f, 0x45, 0x53, 0x4c, 0x47, 0x46, 0x47, 0x4c, 0x3e, 0x44, 0x45, 0x49,
+ 0x48, 0x50, 0x45, 0x40, 0x46, 0x4c, 0x47, 0x4d, 0x44, 0x48, 0x49, 0x50,
+ 0x4f, 0x4a, 0x46, 0x55, 0x4e, 0x42, 0x4c, 0x4c, 0x50, 0x48, 0x3d, 0x55,
+ 0x46, 0x3e, 0x4a, 0x4b, 0x4f, 0x46, 0x46, 0x60, 0x50, 0x3f, 0x55, 0x40,
+ 0x42, 0x44, 0x48, 0x63, 0x50, 0x3d, 0x45, 0x4f, 0x4e, 0x41, 0x47, 0x48,
+ 0x4a, 0x3c, 0x3d, 0x46, 0x3f, 0x42, 0x43, 0x37, 0x4f, 0x4f, 0x50, 0x47,
+ 0x47, 0x4b, 0x52, 0x40, 0x3f, 0x44, 0x4a, 0x40, 0x4d, 0x44, 0x4e, 0x37,
+ 0x43, 0x48, 0x47, 0x3f, 0x51, 0x4d, 0x45, 0x42, 0x41, 0x46, 0x3d, 0x53,
+ 0x4f, 0x4b, 0x54, 0x45, 0x51, 0x40, 0x4a, 0x4a, 0x48, 0x4f, 0x43, 0x4a,
+ 0x4f, 0x4c, 0x4c, 0x4f, 0x48, 0x4c, 0x44, 0x4e, 0x43, 0x46, 0x4f, 0x4a,
+ 0x43, 0x41, 0x49, 0x49, 0x47, 0x53, 0x45, 0x49, 0x4e, 0x46, 0x4c, 0x4e,
+ 0x3c, 0x49, 0x44, 0x45, 0x4c, 0x42, 0x49, 0x41, 0x48, 0x58, 0x54, 0x4d,
+ 0x35, 0x52, 0x4e, 0x5b, 0x4f, 0x40, 0x3e, 0x46, 0x46, 0x36, 0x3d, 0x60,
+ 0x4d, 0x49, 0x4a, 0x43, 0x44, 0x36, 0x49, 0x67, 0x4a, 0x2d, 0x4b, 0x40,
+ 0x3f, 0x49, 0x43, 0x5f, 0x45, 0x3c, 0x49, 0x4c, 0x4a, 0x43, 0x48, 0x55,
+ 0x49, 0x46, 0x49, 0x46, 0x44, 0x4e, 0x42, 0x4e, 0x40, 0x45, 0x42, 0x52,
+ 0x4a, 0x40, 0x4a, 0x44, 0x40, 0x45, 0x54, 0x3d, 0x4c, 0x3e, 0x4c, 0x55,
+ 0x4d, 0x45, 0x4d, 0x51, 0x4a, 0x4b, 0x44, 0x5b, 0x48, 0x3d, 0x3e, 0x46,
+ 0x4f, 0x4d, 0x3f, 0x62, 0x4d, 0x45, 0x3f, 0x47, 0x47, 0x47, 0x44, 0x5b,
+ 0x4b, 0x4f, 0x51, 0x4c, 0x4a, 0x47, 0x48, 0x5b, 0x47, 0x40, 0x4a, 0x47,
+ 0x42, 0x44, 0x46, 0x46, 0x45, 0x48, 0x4a, 0x3f, 0x40, 0x4f, 0x48, 0x3a,
+ 0x49, 0x52, 0x4a, 0x53, 0x43, 0x4c, 0x4b, 0x4a, 0x4a, 0x4a, 0x4e, 0x42,
+ 0x4b, 0x46, 0x3d, 0x50, 0x51, 0x4b, 0x4b, 0x4f, 0x50, 0x4c, 0x4f, 0x4c,
+ 0x4d, 0x41, 0x41, 0x3c, 0x40, 0x43, 0x54, 0x51, 0x48, 0x3d, 0x48, 0x51,
+ 0x42, 0x42, 0x4c, 0x4e, 0x4d, 0x4b, 0x49, 0x43, 0x48, 0x47, 0x4b, 0x49,
+ 0x49, 0x4e, 0x4d, 0x46, 0x4c, 0x52, 0x49, 0x49, 0x51, 0x4e, 0x45, 0x47,
+ 0x44, 0x47, 0x42, 0x4a, 0x46, 0x59, 0x48, 0x48, 0x4b, 0x4f, 0x4c, 0x5e,
+ 0x5c, 0x45, 0x3f, 0x48, 0x3d, 0x3f, 0x37, 0x5a, 0x4b, 0x4b, 0x45, 0x49,
+ 0x3e, 0x42, 0x41, 0x6b, 0x49, 0x2d, 0x45, 0x43, 0x47, 0x45, 0x49, 0x61,
+ 0x3d, 0x3b, 0x49, 0x43, 0x49, 0x4b, 0x4b, 0x55, 0x4b, 0x47, 0x46, 0x46,
+ 0x48, 0x4d, 0x49, 0x4f, 0x4a, 0x4c, 0x42, 0x51, 0x41, 0x44, 0x45, 0x4f,
+ 0x4e, 0x44, 0x3f, 0x55, 0x3e, 0x4a, 0x45, 0x50, 0x46, 0x42, 0x41, 0x49,
+ 0x49, 0x47, 0x49, 0x61, 0x47, 0x40, 0x41, 0x4e, 0x4d, 0x4b, 0x4a, 0x5e,
+ 0x52, 0x49, 0x4b, 0x52, 0x51, 0x55, 0x42, 0x61, 0x53, 0x4c, 0x48, 0x4a,
+ 0x4e, 0x48, 0x48, 0x57, 0x4c, 0x40, 0x40, 0x48, 0x45, 0x43, 0x3e, 0x46,
+ 0x43, 0x4a, 0x45, 0x45, 0x44, 0x4f, 0x44, 0x40, 0x49, 0x48, 0x4e, 0x49,
+ 0x4a, 0x4e, 0x49, 0x51, 0x46, 0x4f, 0x47, 0x44, 0x42, 0x4d, 0x43, 0x4e,
+ 0x4f, 0x4d, 0x44, 0x51, 0x47, 0x49, 0x40, 0x57, 0x4b, 0x49, 0x47, 0x4c,
+ 0x4d, 0x4d, 0x3e, 0x47, 0x45, 0x41, 0x50, 0x4b, 0x4b, 0x45, 0x42, 0x4e,
+ 0x48, 0x47, 0x4e, 0x4b, 0x56, 0x4c, 0x4f, 0x52, 0x51, 0x49, 0x4d, 0x4a,
+ 0x4b, 0x52, 0x4d, 0x55, 0x4b, 0x4e, 0x4e, 0x4b, 0x51, 0x57, 0x47, 0x42,
+ 0x49, 0x48, 0x56, 0x44, 0x52, 0x56, 0x53, 0x5a, 0x63, 0x53, 0x4c, 0x4c,
+ 0x43, 0x56, 0x3c, 0x57, 0x47, 0x47, 0x4d, 0x52, 0x43, 0x48, 0x45, 0x5f,
+ 0x45, 0x29, 0x47, 0x45, 0x48, 0x40, 0x41, 0x4b, 0x3f, 0x39, 0x49, 0x4e,
+ 0x47, 0x55, 0x42, 0x56, 0x4d, 0x43, 0x48, 0x44, 0x45, 0x53, 0x43, 0x46,
+ 0x49, 0x43, 0x49, 0x4a, 0x40, 0x4e, 0x4a, 0x4a, 0x47, 0x43, 0x45, 0x4d,
+ 0x4a, 0x47, 0x3f, 0x53, 0x45, 0x43, 0x4b, 0x4c, 0x42, 0x47, 0x47, 0x5f,
+ 0x48, 0x48, 0x46, 0x44, 0x50, 0x47, 0x41, 0x64, 0x4e, 0x46, 0x49, 0x4a,
+ 0x4d, 0x55, 0x42, 0x55, 0x46, 0x3d, 0x49, 0x43, 0x52, 0x52, 0x47, 0x52,
+ 0x4e, 0x46, 0x47, 0x41, 0x49, 0x4d, 0x50, 0x47, 0x42, 0x49, 0x41, 0x42,
+ 0x4b, 0x48, 0x49, 0x42, 0x4d, 0x48, 0x51, 0x54, 0x43, 0x56, 0x4c, 0x52,
+ 0x53, 0x4d, 0x54, 0x4a, 0x51, 0x50, 0x48, 0x4c, 0x4e, 0x48, 0x4c, 0x4c,
+ 0x52, 0x49, 0x4a, 0x4e, 0x4e, 0x41, 0x4f, 0x53, 0x49, 0x52, 0x42, 0x4b,
+ 0x50, 0x46, 0x50, 0x4a, 0x53, 0x56, 0x46, 0x4f, 0x4b, 0x49, 0x3d, 0x41,
+ 0x4c, 0x52, 0x42, 0x50, 0x4d, 0x45, 0x4e, 0x51, 0x4b, 0x4c, 0x46, 0x42,
+ 0x41, 0x4b, 0x40, 0x4a, 0x42, 0x57, 0x4f, 0x43, 0x40, 0x50, 0x4c, 0x51,
+ 0x4f, 0x48, 0x3a, 0x4e, 0x51, 0x40, 0x49, 0x66, 0x4b, 0x42, 0x48, 0x3c,
+ 0x5b, 0x47, 0x53, 0x40, 0x4a, 0x48, 0x35, 0x44, 0x5f, 0x50, 0x4a, 0x3c,
+ 0x41, 0x45, 0x48, 0x3b, 0x42, 0x59, 0x43, 0x4b, 0x48, 0x49, 0x4a, 0x40,
+ 0x4f, 0x5c, 0x50, 0x54, 0x53, 0x55, 0x4c, 0x4a, 0x43, 0x46, 0x49, 0x47,
+ 0x49, 0x48, 0x4b, 0x43, 0x42, 0x44, 0x42, 0x46, 0x44, 0x3f, 0x4b, 0x42,
+ 0x4d, 0x49, 0x41, 0x46, 0x47, 0x51, 0x51, 0x44, 0x4c, 0x54, 0x4e, 0x4b,
+ 0x42, 0x52, 0x4e, 0x4c, 0x4b, 0x4a, 0x50, 0x4e, 0x44, 0x4b, 0x4e, 0x4e,
+ 0x4f, 0x42, 0x4b, 0x48, 0x46, 0x43, 0x48, 0x54, 0x4b, 0x4e, 0x48, 0x4f,
+ 0x4a, 0x4d, 0x43, 0x4e, 0x47, 0x50, 0x4a, 0x44, 0x47, 0x52, 0x46, 0x53,
+ 0x4a, 0x40, 0x46, 0x54, 0x50, 0x4a, 0x47, 0x51, 0x49, 0x45, 0x4b, 0x4e,
+ 0x4b, 0x46, 0x4c, 0x4c, 0x52, 0x47, 0x45, 0x45, 0x4a, 0x47, 0x4c, 0x52,
+ 0x44, 0x51, 0x47, 0x42, 0x47, 0x43, 0x43, 0x49, 0x52, 0x5a, 0x55, 0x3e,
+ 0x45, 0x4b, 0x4c, 0x46, 0x4f, 0x4b, 0x45, 0x49, 0x4a, 0x4e, 0x4a, 0x50,
+ 0x3e, 0x4e, 0x42, 0x4e, 0x44, 0x55, 0x3d, 0x4a, 0x4d, 0x49, 0x4d, 0x42,
+ 0x49, 0x4e, 0x50, 0x44, 0x4b, 0x3c, 0x41, 0x49, 0x51, 0x49, 0x3c, 0x4e,
+ 0x4c, 0x39, 0x4c, 0x72, 0x44, 0x4b, 0x49, 0x42, 0x5f, 0x48, 0x4a, 0x48,
+ 0x41, 0x4c, 0x43, 0x40, 0x62, 0x5e, 0x47, 0x3c, 0x4a, 0x4c, 0x55, 0x49,
+ 0x4b, 0x52, 0x4e, 0x4b, 0x4d, 0x48, 0x4c, 0x3c, 0x3f, 0x4f, 0x4e, 0x48,
+ 0x45, 0x55, 0x4a, 0x46, 0x48, 0x3d, 0x45, 0x44, 0x4b, 0x4a, 0x46, 0x3a,
+ 0x4e, 0x44, 0x4d, 0x49, 0x49, 0x49, 0x40, 0x3e, 0x40, 0x47, 0x48, 0x43,
+ 0x3f, 0x51, 0x46, 0x4c, 0x45, 0x4c, 0x49, 0x44, 0x3e, 0x57, 0x49, 0x4e,
+ 0x48, 0x3f, 0x48, 0x47, 0x53, 0x4d, 0x50, 0x51, 0x49, 0x42, 0x45, 0x44,
+ 0x49, 0x49, 0x46, 0x4b, 0x45, 0x49, 0x4f, 0x49, 0x46, 0x48, 0x4c, 0x55,
+ 0x46, 0x51, 0x48, 0x4a, 0x48, 0x54, 0x4b, 0x5a, 0x4c, 0x47, 0x40, 0x47,
+ 0x40, 0x55, 0x50, 0x52, 0x4a, 0x4b, 0x4f, 0x49, 0x4b, 0x50, 0x4b, 0x5b,
+ 0x51, 0x53, 0x4f, 0x4e, 0x49, 0x48, 0x44, 0x52, 0x46, 0x4e, 0x47, 0x48,
+ 0x44, 0x43, 0x49, 0x55, 0x48, 0x58, 0x4f, 0x46, 0x45, 0x53, 0x45, 0x4a,
+ 0x4c, 0x4c, 0x49, 0x46, 0x47, 0x4d, 0x41, 0x4d, 0x4f, 0x59, 0x4a, 0x49,
+ 0x46, 0x4e, 0x44, 0x49, 0x4d, 0x48, 0x54, 0x47, 0x48, 0x4e, 0x48, 0x43,
+ 0x46, 0x41, 0x46, 0x44, 0x52, 0x46, 0x42, 0x4c, 0x4c, 0x31, 0x4d, 0x6f,
+ 0x51, 0x4f, 0x4d, 0x43, 0x5c, 0x48, 0x49, 0x49, 0x46, 0x4c, 0x43, 0x3b,
+ 0x5d, 0x63, 0x58, 0x46, 0x49, 0x45, 0x4e, 0x48, 0x49, 0x5d, 0x45, 0x50,
+ 0x56, 0x4d, 0x57, 0x37, 0x40, 0x55, 0x43, 0x4b, 0x4e, 0x46, 0x4c, 0x3b,
+ 0x3d, 0x4b, 0x49, 0x4b, 0x52, 0x47, 0x4d, 0x34, 0x4c, 0x4c, 0x47, 0x4e,
+ 0x4d, 0x4c, 0x3d, 0x3f, 0x4a, 0x49, 0x44, 0x45, 0x4a, 0x54, 0x43, 0x44,
+ 0x50, 0x4b, 0x4d, 0x4c, 0x4e, 0x48, 0x46, 0x51, 0x43, 0x48, 0x48, 0x48,
+ 0x42, 0x44, 0x4e, 0x48, 0x47, 0x45, 0x48, 0x51, 0x53, 0x4a, 0x4f, 0x58,
+ 0x42, 0x4d, 0x48, 0x4f, 0x4c, 0x45, 0x4a, 0x57, 0x4b, 0x43, 0x4d, 0x4b,
+ 0x4a, 0x4e, 0x4c, 0x5f, 0x3f, 0x4f, 0x4a, 0x42, 0x4b, 0x48, 0x4d, 0x62,
+ 0x4f, 0x4b, 0x50, 0x4c, 0x45, 0x49, 0x44, 0x53, 0x4a, 0x4f, 0x45, 0x56,
+ 0x4b, 0x44, 0x41, 0x53, 0x49, 0x48, 0x4d, 0x49, 0x47, 0x4b, 0x46, 0x4c,
+ 0x49, 0x4b, 0x4c, 0x54, 0x4f, 0x4b, 0x47, 0x49, 0x44, 0x4a, 0x4e, 0x53,
+ 0x4f, 0x49, 0x54, 0x4e, 0x4a, 0x48, 0x42, 0x54, 0x51, 0x46, 0x4b, 0x52,
+ 0x45, 0x48, 0x51, 0x4a, 0x40, 0x4a, 0x50, 0x45, 0x4a, 0x46, 0x49, 0x46,
+ 0x54, 0x46, 0x42, 0x48, 0x50, 0x36, 0x4a, 0x6b, 0x46, 0x59, 0x51, 0x47,
+ 0x5f, 0x4d, 0x43, 0x4d, 0x44, 0x4d, 0x42, 0x3b, 0x65, 0x6a, 0x56, 0x48,
+ 0x4d, 0x4c, 0x52, 0x4a, 0x4d, 0x61, 0x52, 0x4b, 0x47, 0x4f, 0x48, 0x49,
+ 0x3f, 0x5b, 0x45, 0x51, 0x48, 0x48, 0x4b, 0x3c, 0x3b, 0x4c, 0x54, 0x52,
+ 0x4f, 0x51, 0x53, 0x31, 0x47, 0x4c, 0x45, 0x4a, 0x42, 0x4b, 0x47, 0x40,
+ 0x41, 0x49, 0x4c, 0x46, 0x4b, 0x53, 0x46, 0x49, 0x44, 0x4b, 0x4e, 0x4b,
+ 0x48, 0x51, 0x49, 0x4d, 0x4b, 0x3f, 0x42, 0x44, 0x45, 0x43, 0x46, 0x56,
+ 0x42, 0x4b, 0x49, 0x4e, 0x4e, 0x53, 0x42, 0x5c, 0x4b, 0x46, 0x49, 0x46,
+ 0x4e, 0x41, 0x42, 0x67, 0x41, 0x49, 0x4d, 0x48, 0x49, 0x4e, 0x3f, 0x61,
+ 0x48, 0x4a, 0x40, 0x42, 0x4c, 0x51, 0x50, 0x63, 0x49, 0x44, 0x49, 0x47,
+ 0x45, 0x4d, 0x49, 0x61, 0x3f, 0x48, 0x40, 0x41, 0x49, 0x49, 0x45, 0x57,
+ 0x45, 0x46, 0x4d, 0x46, 0x4c, 0x4a, 0x4d, 0x4b, 0x43, 0x54, 0x4b, 0x49,
+ 0x4c, 0x49, 0x41, 0x49, 0x4b, 0x47, 0x45, 0x4b, 0x44, 0x43, 0x46, 0x3f,
+ 0x47, 0x47, 0x43, 0x4c, 0x49, 0x4c, 0x3d, 0x4d, 0x4b, 0x54, 0x4a, 0x4f,
+ 0x44, 0x4c, 0x4b, 0x47, 0x4c, 0x45, 0x3d, 0x52, 0x58, 0x4b, 0x45, 0x4e,
+ 0x48, 0x39, 0x53, 0x70, 0x4a, 0x5d, 0x4c, 0x4e, 0x5a, 0x4f, 0x46, 0x4b,
+ 0x3e, 0x4f, 0x44, 0x3d, 0x66, 0x6b, 0x50, 0x4d, 0x4d, 0x57, 0x52, 0x4a,
+ 0x4c, 0x5b, 0x4e, 0x53, 0x4d, 0x54, 0x50, 0x42, 0x3c, 0x5d, 0x4a, 0x4c,
+ 0x56, 0x52, 0x50, 0x40, 0x48, 0x4c, 0x4d, 0x49, 0x49, 0x4f, 0x51, 0x38,
+ 0x42, 0x49, 0x4d, 0x4f, 0x45, 0x40, 0x4d, 0x41, 0x4b, 0x4a, 0x47, 0x51,
+ 0x4b, 0x53, 0x4c, 0x4a, 0x51, 0x4c, 0x42, 0x56, 0x48, 0x4a, 0x47, 0x58,
+ 0x49, 0x46, 0x52, 0x4a, 0x45, 0x47, 0x51, 0x54, 0x4f, 0x50, 0x50, 0x53,
+ 0x49, 0x4a, 0x4d, 0x56, 0x56, 0x4b, 0x4d, 0x45, 0x40, 0x4d, 0x48, 0x60,
+ 0x4e, 0x56, 0x48, 0x4b, 0x47, 0x45, 0x47, 0x62, 0x4e, 0x4f, 0x41, 0x49,
+ 0x48, 0x57, 0x44, 0x64, 0x4f, 0x4f, 0x49, 0x44, 0x49, 0x4c, 0x3f, 0x53,
+ 0x40, 0x41, 0x4e, 0x4b, 0x4d, 0x54, 0x42, 0x53, 0x4e, 0x41, 0x49, 0x44,
+ 0x41, 0x45, 0x4d, 0x4f, 0x47, 0x51, 0x45, 0x4a, 0x42, 0x45, 0x4e, 0x40,
+ 0x4b, 0x52, 0x48, 0x47, 0x4e, 0x4f, 0x47, 0x41, 0x48, 0x53, 0x47, 0x47,
+ 0x46, 0x42, 0x48, 0x4b, 0x42, 0x4c, 0x49, 0x4c, 0x45, 0x4c, 0x54, 0x45,
+ 0x4c, 0x43, 0x4e, 0x49, 0x56, 0x47, 0x45, 0x4f, 0x4d, 0x3a, 0x58, 0x74,
+ 0x49, 0x5b, 0x4c, 0x4f, 0x64, 0x4e, 0x45, 0x43, 0x44, 0x5b, 0x43, 0x41,
+ 0x63, 0x70, 0x55, 0x45, 0x4a, 0x4a, 0x4d, 0x51, 0x4b, 0x5a, 0x51, 0x57,
+ 0x54, 0x5b, 0x55, 0x44, 0x38, 0x57, 0x4e, 0x50, 0x4e, 0x56, 0x57, 0x3a,
+ 0x3a, 0x4b, 0x57, 0x4c, 0x51, 0x53, 0x4d, 0x3b, 0x44, 0x43, 0x47, 0x4c,
+ 0x48, 0x59, 0x51, 0x41, 0x43, 0x44, 0x51, 0x51, 0x4a, 0x54, 0x51, 0x4b,
+ 0x4e, 0x45, 0x51, 0x4a, 0x49, 0x4a, 0x4f, 0x52, 0x4c, 0x3e, 0x4e, 0x55,
+ 0x42, 0x46, 0x46, 0x4a, 0x42, 0x52, 0x49, 0x47, 0x4a, 0x56, 0x4f, 0x50,
+ 0x46, 0x4f, 0x43, 0x51, 0x53, 0x46, 0x40, 0x60, 0x44, 0x4d, 0x46, 0x54,
+ 0x3d, 0x49, 0x43, 0x64, 0x45, 0x4d, 0x50, 0x49, 0x4f, 0x4d, 0x53, 0x60,
+ 0x4a, 0x52, 0x49, 0x47, 0x48, 0x5a, 0x48, 0x58, 0x4e, 0x4f, 0x43, 0x4f,
+ 0x50, 0x51, 0x41, 0x52, 0x4c, 0x4d, 0x45, 0x42, 0x41, 0x4c, 0x44, 0x54,
+ 0x4e, 0x4d, 0x4a, 0x47, 0x40, 0x4a, 0x3e, 0x47, 0x4c, 0x58, 0x46, 0x46,
+ 0x55, 0x4c, 0x4d, 0x45, 0x49, 0x51, 0x53, 0x46, 0x46, 0x43, 0x43, 0x48,
+ 0x52, 0x3d, 0x4b, 0x4e, 0x49, 0x47, 0x3f, 0x3d, 0x4f, 0x45, 0x44, 0x3f,
+ 0x5a, 0x43, 0x4b, 0x4d, 0x51, 0x35, 0x54, 0x76, 0x4f, 0x5e, 0x4c, 0x50,
+ 0x5a, 0x51, 0x46, 0x49, 0x44, 0x61, 0x4f, 0x41, 0x67, 0x72, 0x56, 0x4f,
+ 0x42, 0x48, 0x4b, 0x52, 0x46, 0x60, 0x50, 0x4e, 0x4a, 0x5b, 0x5f, 0x46,
+ 0x31, 0x5b, 0x4a, 0x48, 0x4b, 0x58, 0x51, 0x41, 0x37, 0x4e, 0x4f, 0x55,
+ 0x51, 0x5c, 0x4f, 0x42, 0x4b, 0x4e, 0x4f, 0x54, 0x4f, 0x52, 0x43, 0x43,
+ 0x48, 0x53, 0x53, 0x41, 0x4b, 0x49, 0x4e, 0x50, 0x46, 0x4c, 0x4f, 0x49,
+ 0x42, 0x49, 0x4c, 0x4c, 0x4c, 0x41, 0x4e, 0x48, 0x47, 0x4c, 0x49, 0x53,
+ 0x44, 0x46, 0x51, 0x53, 0x45, 0x52, 0x4e, 0x53, 0x50, 0x58, 0x42, 0x45,
+ 0x44, 0x42, 0x48, 0x58, 0x4e, 0x4d, 0x54, 0x56, 0x4c, 0x46, 0x4a, 0x58,
+ 0x48, 0x4f, 0x47, 0x51, 0x47, 0x4f, 0x4f, 0x5b, 0x41, 0x4e, 0x45, 0x45,
+ 0x4a, 0x50, 0x3e, 0x57, 0x48, 0x4e, 0x41, 0x4c, 0x45, 0x51, 0x46, 0x4c,
+ 0x46, 0x4f, 0x42, 0x45, 0x4b, 0x4c, 0x49, 0x4c, 0x44, 0x4f, 0x4e, 0x4d,
+ 0x48, 0x56, 0x43, 0x48, 0x42, 0x54, 0x48, 0x43, 0x3e, 0x51, 0x43, 0x47,
+ 0x47, 0x47, 0x49, 0x4d, 0x46, 0x4e, 0x52, 0x42, 0x48, 0x4e, 0x4c, 0x4a,
+ 0x4d, 0x3e, 0x43, 0x40, 0x48, 0x41, 0x47, 0x4f, 0x5e, 0x49, 0x40, 0x4c,
+ 0x50, 0x42, 0x56, 0x75, 0x51, 0x5e, 0x51, 0x4e, 0x62, 0x58, 0x49, 0x47,
+ 0x51, 0x59, 0x46, 0x46, 0x6c, 0x72, 0x55, 0x44, 0x4c, 0x4a, 0x4d, 0x59,
+ 0x53, 0x64, 0x4d, 0x51, 0x55, 0x5e, 0x59, 0x50, 0x30, 0x58, 0x50, 0x4c,
+ 0x4c, 0x60, 0x59, 0x42, 0x32, 0x53, 0x50, 0x55, 0x4d, 0x53, 0x59, 0x43,
+ 0x3e, 0x49, 0x4f, 0x52, 0x4d, 0x51, 0x47, 0x45, 0x4d, 0x4e, 0x53, 0x4e,
+ 0x54, 0x4f, 0x4d, 0x4d, 0x4e, 0x40, 0x47, 0x53, 0x53, 0x49, 0x56, 0x4d,
+ 0x4d, 0x3a, 0x4c, 0x4e, 0x45, 0x4a, 0x47, 0x45, 0x53, 0x4a, 0x4e, 0x52,
+ 0x4d, 0x4e, 0x48, 0x56, 0x4e, 0x4a, 0x4d, 0x52, 0x49, 0x4e, 0x4e, 0x58,
+ 0x47, 0x50, 0x4c, 0x54, 0x49, 0x42, 0x46, 0x54, 0x50, 0x54, 0x54, 0x46,
+ 0x40, 0x49, 0x4b, 0x57, 0x4b, 0x59, 0x44, 0x46, 0x52, 0x55, 0x51, 0x55,
+ 0x4f, 0x50, 0x4d, 0x4d, 0x48, 0x50, 0x4e, 0x49, 0x4e, 0x42, 0x45, 0x3f,
+ 0x4d, 0x4f, 0x51, 0x47, 0x4a, 0x4c, 0x4b, 0x4b, 0x46, 0x4d, 0x44, 0x52,
+ 0x4d, 0x44, 0x40, 0x4d, 0x54, 0x46, 0x54, 0x44, 0x4b, 0x46, 0x47, 0x45,
+ 0x50, 0x45, 0x45, 0x4b, 0x4c, 0x48, 0x3f, 0x55, 0x4a, 0x45, 0x49, 0x4e,
+ 0x40, 0x49, 0x4a, 0x41, 0x56, 0x4b, 0x49, 0x4e, 0x4a, 0x41, 0x50, 0x70,
+ 0x56, 0x59, 0x4b, 0x55, 0x58, 0x59, 0x49, 0x47, 0x4a, 0x5a, 0x4c, 0x46,
+ 0x62, 0x7b, 0x58, 0x51, 0x44, 0x47, 0x44, 0x57, 0x4f, 0x65, 0x4e, 0x50,
+ 0x4d, 0x67, 0x5c, 0x4a, 0x2b, 0x61, 0x48, 0x4b, 0x4b, 0x5d, 0x5c, 0x48,
+ 0x39, 0x50, 0x45, 0x4d, 0x53, 0x60, 0x53, 0x46, 0x42, 0x46, 0x50, 0x45,
+ 0x4f, 0x4e, 0x46, 0x4a, 0x4d, 0x51, 0x54, 0x47, 0x59, 0x4b, 0x58, 0x4a,
+ 0x50, 0x3d, 0x59, 0x48, 0x45, 0x4e, 0x4e, 0x47, 0x4f, 0x47, 0x4d, 0x4b,
+ 0x52, 0x42, 0x4c, 0x48, 0x4a, 0x4f, 0x47, 0x43, 0x4e, 0x4c, 0x4d, 0x51,
+ 0x49, 0x4f, 0x4c, 0x47, 0x47, 0x48, 0x47, 0x59, 0x4f, 0x4f, 0x53, 0x49,
+ 0x4e, 0x4b, 0x4f, 0x5a, 0x50, 0x42, 0x47, 0x50, 0x4a, 0x54, 0x47, 0x5a,
+ 0x43, 0x49, 0x47, 0x4e, 0x49, 0x4d, 0x43, 0x54, 0x4c, 0x53, 0x4e, 0x4e,
+ 0x42, 0x43, 0x48, 0x46, 0x4f, 0x43, 0x43, 0x45, 0x51, 0x47, 0x4b, 0x4f,
+ 0x56, 0x48, 0x48, 0x49, 0x46, 0x45, 0x4d, 0x52, 0x47, 0x4b, 0x46, 0x50,
+ 0x3e, 0x4e, 0x4c, 0x43, 0x45, 0x4d, 0x53, 0x43, 0x46, 0x45, 0x44, 0x52,
+ 0x45, 0x49, 0x49, 0x51, 0x3d, 0x4a, 0x4d, 0x46, 0x42, 0x41, 0x4e, 0x48,
+ 0x5a, 0x49, 0x49, 0x49, 0x4f, 0x3d, 0x56, 0x68, 0x56, 0x67, 0x4b, 0x57,
+ 0x5f, 0x5c, 0x40, 0x4a, 0x4a, 0x54, 0x4c, 0x47, 0x64, 0x7a, 0x54, 0x48,
+ 0x46, 0x45, 0x46, 0x57, 0x4e, 0x61, 0x4f, 0x50, 0x4d, 0x64, 0x5b, 0x43,
+ 0x2d, 0x60, 0x55, 0x51, 0x4c, 0x54, 0x4f, 0x4e, 0x2f, 0x50, 0x4f, 0x52,
+ 0x50, 0x61, 0x54, 0x4b, 0x3d, 0x4c, 0x47, 0x51, 0x4a, 0x54, 0x4b, 0x42,
+ 0x3b, 0x55, 0x47, 0x50, 0x4f, 0x49, 0x4a, 0x46, 0x43, 0x44, 0x45, 0x47,
+ 0x46, 0x4b, 0x4f, 0x46, 0x43, 0x47, 0x4a, 0x4e, 0x51, 0x43, 0x55, 0x47,
+ 0x4d, 0x46, 0x4c, 0x4c, 0x49, 0x4d, 0x43, 0x51, 0x47, 0x51, 0x52, 0x4a,
+ 0x46, 0x4f, 0x49, 0x52, 0x50, 0x4a, 0x43, 0x53, 0x46, 0x4e, 0x50, 0x54,
+ 0x45, 0x3a, 0x4a, 0x4a, 0x4c, 0x50, 0x4b, 0x54, 0x43, 0x4f, 0x4e, 0x45,
+ 0x49, 0x4f, 0x46, 0x53, 0x4d, 0x51, 0x52, 0x53, 0x3d, 0x4a, 0x47, 0x4e,
+ 0x43, 0x4a, 0x53, 0x48, 0x4a, 0x4c, 0x4a, 0x4a, 0x42, 0x53, 0x3e, 0x43,
+ 0x4f, 0x4c, 0x47, 0x48, 0x54, 0x4d, 0x48, 0x48, 0x4e, 0x4c, 0x43, 0x51,
+ 0x42, 0x49, 0x44, 0x3e, 0x49, 0x51, 0x4a, 0x4d, 0x4f, 0x49, 0x45, 0x44,
+ 0x4e, 0x41, 0x48, 0x4b, 0x4c, 0x49, 0x46, 0x47, 0x5d, 0x4c, 0x4d, 0x50,
+ 0x45, 0x40, 0x4e, 0x6a, 0x4f, 0x62, 0x53, 0x50, 0x5c, 0x5e, 0x4a, 0x4c,
+ 0x50, 0x56, 0x52, 0x42, 0x60, 0x7e, 0x5b, 0x4b, 0x43, 0x41, 0x4c, 0x56,
+ 0x46, 0x5f, 0x4d, 0x49, 0x43, 0x65, 0x5c, 0x4d, 0x2c, 0x61, 0x48, 0x4c,
+ 0x44, 0x55, 0x5c, 0x49, 0x37, 0x54, 0x4e, 0x57, 0x52, 0x5c, 0x50, 0x49,
+ 0x3e, 0x4d, 0x4f, 0x4f, 0x51, 0x4c, 0x48, 0x43, 0x4a, 0x5a, 0x4d, 0x4b,
+ 0x4e, 0x58, 0x54, 0x49, 0x51, 0x42, 0x49, 0x4f, 0x46, 0x45, 0x52, 0x3d,
+ 0x4b, 0x4b, 0x43, 0x54, 0x47, 0x47, 0x4c, 0x42, 0x4b, 0x49, 0x45, 0x46,
+ 0x46, 0x4a, 0x51, 0x47, 0x47, 0x4f, 0x48, 0x4a, 0x3f, 0x4c, 0x4b, 0x57,
+ 0x4a, 0x3f, 0x52, 0x4a, 0x56, 0x52, 0x4b, 0x54, 0x4c, 0x3e, 0x3f, 0x4f,
+ 0x4b, 0x50, 0x4c, 0x53, 0x4a, 0x49, 0x46, 0x4e, 0x50, 0x48, 0x4f, 0x4b,
+ 0x4a, 0x4e, 0x3e, 0x49, 0x45, 0x42, 0x42, 0x41, 0x47, 0x4b, 0x4f, 0x42,
+ 0x49, 0x4c, 0x55, 0x4c, 0x4e, 0x42, 0x47, 0x42, 0x4b, 0x48, 0x46, 0x41,
+ 0x46, 0x4e, 0x4d, 0x3f, 0x4f, 0x46, 0x4f, 0x4b, 0x4b, 0x4d, 0x50, 0x3e,
+ 0x42, 0x43, 0x44, 0x4a, 0x49, 0x40, 0x4e, 0x43, 0x3e, 0x52, 0x3e, 0x44,
+ 0x49, 0x43, 0x4d, 0x44, 0x62, 0x51, 0x42, 0x53, 0x51, 0x40, 0x4c, 0x64,
+ 0x4f, 0x63, 0x4e, 0x5c, 0x5b, 0x5c, 0x48, 0x4d, 0x4a, 0x57, 0x4f, 0x42,
+ 0x65, 0xfe, 0x5c, 0x4e, 0x47, 0x43, 0x4a, 0x58, 0x4e, 0x5e, 0x48, 0x4c,
+ 0x51, 0x5e, 0x60, 0x56, 0x2f, 0x62, 0x54, 0x58, 0x51, 0x52, 0x55, 0x51,
+ 0x36, 0x4b, 0x46, 0x51, 0x53, 0x5f, 0x46, 0x4c, 0x37, 0x4d, 0x4a, 0x45,
+ 0x4b, 0x3f, 0x41, 0x42, 0x3f, 0x53, 0x4a, 0x48, 0x49, 0x4a, 0x4a, 0x45,
+ 0x52, 0x3f, 0x52, 0x52, 0x45, 0x4d, 0x4f, 0x45, 0x46, 0x4a, 0x51, 0x48,
+ 0x56, 0x47, 0x50, 0x3e, 0x46, 0x49, 0x4c, 0x51, 0x49, 0x54, 0x45, 0x4f,
+ 0x4b, 0x4b, 0x49, 0x46, 0x4b, 0x4d, 0x49, 0x5c, 0x4d, 0x43, 0x47, 0x49,
+ 0x48, 0x52, 0x46, 0x50, 0x51, 0x37, 0x50, 0x52, 0x4c, 0x4d, 0x4f, 0x51,
+ 0x4f, 0x42, 0x50, 0x47, 0x48, 0x4e, 0x4d, 0x4c, 0x48, 0x48, 0x4a, 0x51,
+ 0x49, 0x42, 0x50, 0x4f, 0x43, 0x4e, 0x47, 0x4b, 0x47, 0x4a, 0x44, 0x44,
+ 0x4c, 0x51, 0x49, 0x44, 0x45, 0x45, 0x45, 0x48, 0x3f, 0x4a, 0x43, 0x49,
+ 0x46, 0x49, 0x4c, 0x4d, 0x45, 0x50, 0x44, 0x45, 0x44, 0x55, 0x4a, 0x45,
+ 0x48, 0x47, 0x4c, 0x43, 0x3f, 0x48, 0x42, 0x43, 0x43, 0x43, 0x48, 0x46,
+ 0x5c, 0x51, 0x47, 0x51, 0x48, 0x40, 0x54, 0x66, 0x4e, 0x67, 0x4d, 0x5a,
+ 0x60, 0x57, 0x47, 0x4d, 0x4d, 0x58, 0x53, 0x46, 0x66, 0x7e, 0x56, 0x48,
+ 0x44, 0x4f, 0x49, 0x5c, 0x4a, 0x63, 0x50, 0x4c, 0x49, 0x56, 0x61, 0x50,
+ 0x2c, 0x68, 0x4d, 0x51, 0x46, 0x4e, 0x5b, 0x51, 0x2e, 0x53, 0x54, 0x50,
+ 0x46, 0x58, 0x44, 0x4f, 0x37, 0x48, 0x55, 0x50, 0x49, 0x49, 0x4e, 0x46,
+ 0x43, 0x56, 0x52, 0x4e, 0x50, 0x4b, 0x50, 0x4c, 0x49, 0x40, 0x4d, 0x4f,
+ 0x50, 0x41, 0x44, 0x39, 0x4b, 0x4d, 0x4b, 0x41, 0x51, 0x4d, 0x4c, 0x41,
+ 0x3f, 0x52, 0x4e, 0x4b, 0x49, 0x53, 0x45, 0x43, 0x4d, 0x4f, 0x44, 0x4d,
+ 0x4b, 0x53, 0x50, 0x4e, 0x45, 0x3f, 0x4e, 0x51, 0x50, 0x55, 0x4f, 0x51,
+ 0x4d, 0x3d, 0x58, 0x3f, 0x46, 0x50, 0x50, 0x50, 0x56, 0x42, 0x49, 0x49,
+ 0x50, 0x4f, 0x42, 0x4b, 0x4c, 0x45, 0x52, 0x41, 0x46, 0x43, 0x4c, 0x4a,
+ 0x4c, 0x51, 0x4d, 0x4d, 0x4a, 0x49, 0x54, 0x49, 0x58, 0x53, 0x49, 0x45,
+ 0x47, 0x4c, 0x4c, 0x44, 0x4e, 0x51, 0x4c, 0x4c, 0x47, 0x48, 0x4c, 0x4e,
+ 0x49, 0x54, 0x4c, 0x51, 0x49, 0x48, 0x47, 0x45, 0x42, 0x49, 0x42, 0x51,
+ 0x4e, 0x3f, 0x49, 0x41, 0x50, 0x3e, 0x4d, 0x50, 0x5c, 0x51, 0x4d, 0x56,
+ 0x47, 0x48, 0x58, 0x65, 0x51, 0x6b, 0x56, 0x5b, 0x56, 0x55, 0x46, 0x49,
+ 0x4b, 0x58, 0x59, 0x4a, 0x68, 0x79, 0x53, 0x46, 0x45, 0x4b, 0x53, 0x5d,
+ 0x4b, 0x6f, 0x4e, 0x4f, 0x4c, 0x53, 0x5b, 0x52, 0x30, 0x63, 0x46, 0x57,
+ 0x46, 0x50, 0x4b, 0x48, 0x2e, 0x4c, 0x46, 0x48, 0x44, 0x51, 0x46, 0x4a,
+ 0x35, 0x55, 0x43, 0x4c, 0x43, 0x4d, 0x4e, 0x3e, 0x47, 0x56, 0x50, 0x4d,
+ 0x44, 0x59, 0x4c, 0x51, 0x46, 0x42, 0x4e, 0x43, 0x4c, 0x44, 0x42, 0x3a,
+ 0x40, 0x48, 0x46, 0x44, 0x45, 0x4a, 0x46, 0x3a, 0x53, 0x4c, 0x4d, 0x4c,
+ 0x4a, 0x4f, 0x53, 0x40, 0x4b, 0x48, 0x54, 0x4b, 0x44, 0x59, 0x41, 0x50,
+ 0x4e, 0x50, 0x55, 0x4d, 0x55, 0x41, 0x4a, 0x4f, 0x47, 0x43, 0x4e, 0x50,
+ 0x52, 0x4c, 0x50, 0x4d, 0x47, 0x42, 0x4f, 0x4b, 0x47, 0x43, 0x41, 0x4a,
+ 0x55, 0x3e, 0x50, 0x4b, 0x41, 0x49, 0x47, 0x49, 0x53, 0x4d, 0x48, 0x4b,
+ 0x43, 0x43, 0x51, 0x44, 0x4d, 0x4c, 0x44, 0x50, 0x4d, 0x42, 0x49, 0x4e,
+ 0x50, 0x50, 0x4c, 0x49, 0x49, 0x51, 0x46, 0x43, 0x4a, 0x4e, 0x53, 0x47,
+ 0x43, 0x46, 0x40, 0x49, 0x47, 0x44, 0x44, 0x4d, 0x4b, 0x4b, 0x51, 0x4b,
+ 0x45, 0x49, 0x47, 0x43, 0x56, 0x49, 0x4c, 0x54, 0x50, 0x3c, 0x4c, 0x5e,
+ 0x51, 0x67, 0x4f, 0x57, 0x57, 0x53, 0x3e, 0x4e, 0x4e, 0x5e, 0x4b, 0x48,
+ 0x5a, 0x78, 0x55, 0x4a, 0x3f, 0x4b, 0x4c, 0x5b, 0x53, 0x64, 0x4d, 0x53,
+ 0x49, 0x57, 0x57, 0x58, 0x37, 0x62, 0x4f, 0x56, 0x44, 0x4e, 0x58, 0x4a,
+ 0x30, 0x4f, 0x40, 0x4e, 0x47, 0x58, 0x52, 0x50, 0x35, 0x4d, 0x49, 0x52,
+ 0x4e, 0x42, 0x46, 0x47, 0x44, 0x57, 0x54, 0x43, 0x4e, 0x56, 0x43, 0x49,
+ 0x44, 0x40, 0x44, 0x41, 0x50, 0x49, 0x4b, 0x44, 0x4d, 0x52, 0x49, 0x43,
+ 0x52, 0x54, 0x49, 0x3f, 0x49, 0x42, 0x49, 0x4a, 0x43, 0x3e, 0x50, 0x40,
+ 0x46, 0x4b, 0x50, 0x4b, 0x53, 0x4b, 0x47, 0x52, 0x51, 0x4b, 0x47, 0x3f,
+ 0x46, 0x4b, 0x4c, 0x57, 0x49, 0x47, 0x54, 0x49, 0x50, 0x50, 0x4d, 0x4a,
+ 0x42, 0x4e, 0x51, 0x4c, 0x47, 0x47, 0x42, 0x43, 0x54, 0x43, 0x46, 0x47,
+ 0x4d, 0x43, 0x54, 0x47, 0x43, 0x58, 0x48, 0x45, 0x4b, 0x46, 0x48, 0x3d,
+ 0x47, 0x3f, 0x44, 0x4f, 0x4e, 0x46, 0x41, 0x40, 0x4d, 0x4d, 0x4d, 0x52,
+ 0x54, 0x47, 0x4f, 0x51, 0x4f, 0x45, 0x45, 0x48, 0x4b, 0x4d, 0x44, 0x52,
+ 0x51, 0x4b, 0x48, 0x4f, 0x49, 0x49, 0x46, 0x50, 0x54, 0x42, 0x44, 0x51,
+ 0x58, 0x4e, 0x43, 0x58, 0x55, 0x40, 0x53, 0x5a, 0x51, 0x61, 0x51, 0x60,
+ 0x53, 0x57, 0x45, 0x4f, 0x45, 0x5e, 0x51, 0x42, 0x61, 0x7a, 0x55, 0x47,
+ 0x41, 0x4b, 0x4a, 0x5b, 0x4c, 0x65, 0x4f, 0x55, 0x46, 0x54, 0x65, 0x59,
+ 0x36, 0x61, 0x54, 0x55, 0x48, 0x57, 0x52, 0x4e, 0x24, 0x4b, 0x49, 0x4d,
+ 0x43, 0x57, 0x44, 0x51, 0x3b, 0x4f, 0x45, 0x40, 0x47, 0x4a, 0x43, 0x47,
+ 0x46, 0x58, 0x50, 0x54, 0x4d, 0x50, 0x44, 0x42, 0x4a, 0x46, 0x4b, 0x4d,
+ 0x4f, 0x4f, 0x4d, 0x40, 0x48, 0x4a, 0x53, 0x48, 0x49, 0x48, 0x4d, 0x39,
+ 0x47, 0x4e, 0x44, 0x4c, 0x4b, 0x49, 0x44, 0x42, 0x4a, 0x45, 0x46, 0x46,
+ 0x53, 0x4d, 0x49, 0x4f, 0x4e, 0x48, 0x50, 0x4a, 0x4c, 0x46, 0x56, 0x4b,
+ 0x4b, 0x57, 0x4c, 0x49, 0x4a, 0x4a, 0x43, 0x4e, 0x56, 0x45, 0x50, 0x4c,
+ 0x47, 0x55, 0x48, 0x46, 0x4e, 0x46, 0x45, 0x3f, 0x4a, 0x4c, 0x4c, 0x47,
+ 0x4a, 0x51, 0x4e, 0x50, 0x40, 0x52, 0x45, 0x45, 0x4b, 0x46, 0x4f, 0x44,
+ 0x51, 0x4a, 0x4e, 0x4d, 0x4c, 0x46, 0x42, 0x47, 0x4a, 0x4e, 0x46, 0x42,
+ 0x4b, 0x4f, 0x4b, 0x4e, 0x4e, 0x46, 0x42, 0x50, 0x53, 0x51, 0x4f, 0x54,
+ 0x45, 0x4f, 0x45, 0x42, 0x4c, 0x45, 0x40, 0x48, 0x59, 0x49, 0x49, 0x53,
+ 0x4c, 0x43, 0x4b, 0x57, 0x54, 0x64, 0x4e, 0x5f, 0x5c, 0x59, 0x4b, 0x56,
+ 0x49, 0x5d, 0x4f, 0x4b, 0x62, 0x73, 0x54, 0x45, 0x49, 0x50, 0x48, 0x5a,
+ 0x50, 0x6d, 0x4a, 0x4e, 0x48, 0x55, 0x5d, 0x57, 0x38, 0x68, 0x52, 0x5a,
+ 0x46, 0x56, 0x4c, 0x5a, 0x2e, 0x55, 0x49, 0x4f, 0x4a, 0x57, 0x4f, 0x54,
+ 0x41, 0x53, 0x46, 0x43, 0x45, 0x47, 0x53, 0x4a, 0x42, 0x4f, 0x4d, 0x48,
+ 0x4c, 0x49, 0x47, 0x48, 0x45, 0x49, 0x48, 0x53, 0x48, 0x52, 0x4a, 0x44,
+ 0x4c, 0x49, 0x52, 0x4b, 0x47, 0x51, 0x42, 0x47, 0x49, 0x51, 0x3f, 0x45,
+ 0x47, 0x4e, 0x53, 0x33, 0x55, 0x51, 0x55, 0x48, 0x4b, 0x51, 0x56, 0x47,
+ 0x43, 0x55, 0x47, 0x42, 0x47, 0x4f, 0x47, 0x51, 0x46, 0x55, 0x4a, 0x4b,
+ 0x50, 0x52, 0x4f, 0x43, 0x4b, 0x53, 0x4d, 0x3f, 0x4e, 0x56, 0x50, 0x49,
+ 0x4d, 0x47, 0x51, 0x49, 0x4a, 0x52, 0x44, 0x43, 0x4d, 0x4e, 0x41, 0x51,
+ 0x4c, 0x4d, 0x47, 0x48, 0x4f, 0x40, 0x50, 0x46, 0x43, 0x4d, 0x4e, 0x50,
+ 0x43, 0x47, 0x4e, 0x46, 0x4f, 0x4b, 0x51, 0x4b, 0x4a, 0x57, 0x42, 0x51,
+ 0x4c, 0x54, 0x52, 0x42, 0x4c, 0x42, 0x47, 0x54, 0x4a, 0x4a, 0x47, 0x4a,
+ 0x3f, 0x46, 0x4e, 0x4c, 0x53, 0x50, 0x47, 0x53, 0x49, 0x44, 0x52, 0x5a,
+ 0x4b, 0x65, 0x50, 0x5b, 0x57, 0x59, 0x4a, 0x48, 0x48, 0x5f, 0x55, 0x48,
+ 0x5c, 0x78, 0x55, 0x48, 0x4a, 0x4b, 0x49, 0x4c, 0x46, 0x6b, 0x54, 0x57,
+ 0x55, 0x4b, 0x59, 0x52, 0x38, 0x5b, 0x57, 0x56, 0x4b, 0x4f, 0x48, 0x4e,
+ 0x34, 0x5a, 0x4e, 0x4f, 0x43, 0x4e, 0x4b, 0x4e, 0x36, 0x4d, 0x52, 0x48,
+ 0x4d, 0x4c, 0x4c, 0x49, 0x51, 0x54, 0x45, 0x54, 0x4a, 0x4e, 0x52, 0x41,
+ 0x4c, 0x45, 0x4a, 0x53, 0x55, 0x4b, 0x50, 0x47, 0x4e, 0x4d, 0x43, 0x51,
+ 0x4e, 0x4a, 0x51, 0x46, 0x4e, 0x4d, 0x48, 0x3f, 0x43, 0x52, 0x56, 0x38,
+ 0x52, 0x46, 0x43, 0x49, 0x40, 0x49, 0x53, 0x41, 0x47, 0x41, 0x41, 0x42,
+ 0x4f, 0x4b, 0x46, 0x4b, 0x4a, 0x57, 0x4a, 0x45, 0x4b, 0x46, 0x47, 0x3c,
+ 0x43, 0x46, 0x4f, 0x50, 0x4c, 0x53, 0x4f, 0x41, 0x4a, 0x4a, 0x40, 0x4a,
+ 0x3e, 0x4e, 0x4d, 0x41, 0x4a, 0x42, 0x49, 0x4c, 0x51, 0x46, 0x4f, 0x43,
+ 0x4b, 0x41, 0x50, 0x48, 0x4a, 0x40, 0x52, 0x45, 0x40, 0x40, 0x46, 0x48,
+ 0x48, 0x52, 0x52, 0x41, 0x43, 0x49, 0x49, 0x4c, 0x44, 0x48, 0x50, 0x4a,
+ 0x47, 0x48, 0x4c, 0x42, 0x49, 0x48, 0x52, 0x56, 0x4b, 0x41, 0x4e, 0x47,
+ 0x52, 0x56, 0x4e, 0x56, 0x4b, 0x38, 0x50, 0x55, 0x5a, 0x63, 0x51, 0x5a,
+ 0x54, 0x52, 0x44, 0x45, 0x47, 0x5e, 0x4c, 0x4a, 0x5e, 0x71, 0x56, 0x44,
+ 0x4c, 0x4b, 0x4c, 0x4e, 0x49, 0x69, 0x50, 0x53, 0x4d, 0x5c, 0x59, 0x50,
+ 0x36, 0x5d, 0x46, 0x5b, 0x51, 0x55, 0x55, 0x51, 0x36, 0x5a, 0x53, 0x56,
+ 0x54, 0x4a, 0x55, 0x53, 0x3c, 0x52, 0x4a, 0x45, 0x4c, 0x56, 0x49, 0x46,
+ 0x4f, 0x5b, 0x43, 0x4b, 0x49, 0x4c, 0x4b, 0x41, 0x44, 0x4b, 0x47, 0x4b,
+ 0x4b, 0x54, 0x4a, 0x4c, 0x49, 0x44, 0x46, 0x46, 0x48, 0x49, 0x47, 0x4a,
+ 0x40, 0x4e, 0x47, 0x53, 0x4a, 0x47, 0x4a, 0x3b, 0x48, 0x4b, 0x50, 0x51,
+ 0x50, 0x44, 0x4d, 0x49, 0x42, 0x4b, 0x43, 0x48, 0x4a, 0x43, 0x4d, 0x4d,
+ 0x49, 0x4d, 0x43, 0x4f, 0x50, 0x49, 0x47, 0x48, 0x48, 0x4f, 0x49, 0x41,
+ 0x4c, 0x46, 0x47, 0x3e, 0x51, 0x4d, 0x4e, 0x42, 0x3d, 0x53, 0x4d, 0x3b,
+ 0x53, 0x52, 0x4c, 0x4c, 0x43, 0x46, 0x43, 0x3d, 0x53, 0x48, 0x43, 0x4e,
+ 0x45, 0x52, 0x4d, 0x4a, 0x44, 0x49, 0x47, 0x4c, 0x4e, 0x4c, 0x4a, 0x4e,
+ 0x41, 0x48, 0x4b, 0x44, 0x4d, 0x4a, 0x4d, 0x44, 0x4a, 0x45, 0x4f, 0x52,
+ 0x45, 0x3f, 0x4b, 0x48, 0x43, 0x41, 0x3d, 0x53, 0x53, 0x50, 0x4a, 0x56,
+ 0x4d, 0x3e, 0x55, 0x4e, 0x56, 0x5e, 0x52, 0x52, 0x54, 0x50, 0x42, 0x4a,
+ 0x4d, 0x5f, 0x4f, 0x49, 0x5d, 0x6f, 0x55, 0x4a, 0x47, 0x49, 0x4e, 0x4a,
+ 0x43, 0x6e, 0x4e, 0x4f, 0x52, 0x59, 0x62, 0x4b, 0x3e, 0x5c, 0x4c, 0x4e,
+ 0x45, 0x52, 0x43, 0x4d, 0x3c, 0x58, 0x52, 0x49, 0x48, 0x55, 0x53, 0x4e,
+ 0x3d, 0x4e, 0x4c, 0x4b, 0x4b, 0x50, 0x4a, 0x47, 0x45, 0x62, 0x50, 0x49,
+ 0x48, 0x4b, 0x55, 0x45, 0x46, 0x51, 0x41, 0x55, 0x54, 0x55, 0x50, 0x47,
+ 0x46, 0x4d, 0x46, 0x4b, 0x41, 0x49, 0x4c, 0x40, 0x45, 0x4f, 0x52, 0x54,
+ 0x45, 0x4d, 0x53, 0x3a, 0x4c, 0x55, 0x4e, 0x48, 0x44, 0x45, 0x56, 0x3c,
+ 0x48, 0x46, 0x4b, 0x51, 0x53, 0x43, 0x41, 0x49, 0x4c, 0x52, 0x48, 0x42,
+ 0x48, 0x3f, 0x4c, 0x38, 0x46, 0x50, 0x4a, 0x44, 0x50, 0x54, 0x4e, 0x38,
+ 0x48, 0x42, 0x43, 0x4a, 0x4c, 0x44, 0x47, 0x42, 0x42, 0x46, 0x4a, 0x50,
+ 0x47, 0x4b, 0x43, 0x40, 0x44, 0x46, 0x46, 0x4d, 0x50, 0x4a, 0x4e, 0x51,
+ 0x44, 0x40, 0x50, 0x43, 0x52, 0x4d, 0x42, 0x4c, 0x50, 0x41, 0x4a, 0x4e,
+ 0x45, 0x49, 0x4d, 0x40, 0x46, 0x51, 0x43, 0x4b, 0x48, 0x47, 0x42, 0x55,
+ 0x4a, 0x41, 0x4f, 0x49, 0x4f, 0x4e, 0x47, 0x4c, 0x4a, 0x48, 0x50, 0x4e,
+ 0x50, 0x57, 0x4e, 0x56, 0x56, 0x4e, 0x44, 0x48, 0x4a, 0x5b, 0x55, 0x49,
+ 0x59, 0x67, 0x54, 0x46, 0x4f, 0x41, 0x4d, 0x4e, 0x4a, 0x63, 0x4d, 0x44,
+ 0x53, 0x5b, 0x59, 0x4f, 0x43, 0x55, 0x56, 0x4e, 0x55, 0x4c, 0x4b, 0x54,
+ 0x3c, 0x56, 0x4d, 0x50, 0x4f, 0x4a, 0x5a, 0x47, 0x48, 0x56, 0x4f, 0x4f,
+ 0x50, 0x51, 0x48, 0x4e, 0x4d, 0x50, 0x4e, 0x45, 0x4b, 0x48, 0x4e, 0x44,
+ 0x46, 0x4d, 0x43, 0x46, 0x41, 0x59, 0x53, 0x4b, 0x4a, 0x3e, 0x51, 0x47,
+ 0x43, 0x48, 0x52, 0x3f, 0x43, 0x50, 0x4b, 0x4f, 0x41, 0x48, 0x43, 0x2e,
+ 0x4d, 0x4e, 0x4c, 0x45, 0x45, 0x46, 0x4b, 0x43, 0x46, 0x49, 0x46, 0x4d,
+ 0x47, 0x4e, 0x4d, 0x3c, 0x47, 0x4a, 0x52, 0x4e, 0x41, 0x50, 0x43, 0x3a,
+ 0x50, 0x47, 0x4a, 0x45, 0x52, 0x4a, 0x4c, 0x3f, 0x42, 0x3d, 0x49, 0x48,
+ 0x48, 0x4c, 0x42, 0x3a, 0x40, 0x47, 0x46, 0x4e, 0x44, 0x52, 0x46, 0x44,
+ 0x4a, 0x44, 0x43, 0x49, 0x42, 0x45, 0x3f, 0x50, 0x4c, 0x44, 0x48, 0x43,
+ 0x47, 0x4a, 0x48, 0x48, 0x3e, 0x45, 0x43, 0x48, 0x4a, 0x48, 0x53, 0x4b,
+ 0x50, 0x49, 0x43, 0x4d, 0x53, 0x4f, 0x4b, 0x4b, 0x40, 0x42, 0x50, 0x4d,
+ 0x53, 0x4e, 0x44, 0x4d, 0x45, 0x3d, 0x51, 0x51, 0x4f, 0x59, 0x4b, 0x51,
+ 0x4a, 0x4e, 0x42, 0x40, 0x49, 0x5b, 0x4b, 0x43, 0x53, 0x60, 0x47, 0x49,
+ 0x4a, 0x44, 0x44, 0x48, 0x4b, 0x60, 0x51, 0x3f, 0x4b, 0x5b, 0x4f, 0x4a,
+ 0x4a, 0x50, 0x49, 0x46, 0x55, 0x50, 0x4b, 0x4c, 0x40, 0x4e, 0x51, 0x4f,
+ 0x4b, 0x51, 0x54, 0x50, 0x48, 0x4e, 0x4a, 0x4f, 0x4d, 0x4e, 0x54, 0x4d,
+ 0x41, 0x50, 0x4e, 0x47, 0x47, 0x47, 0x54, 0x3b, 0x51, 0x54, 0x50, 0x49,
+ 0x48, 0x4c, 0x4e, 0x47, 0x3f, 0x3c, 0x4c, 0x43, 0x45, 0x42, 0x45, 0x37,
+ 0x41, 0x52, 0x49, 0x47, 0x4e, 0x4a, 0x4b, 0x37, 0x48, 0x4d, 0x4e, 0x4a,
+ 0x42, 0x56, 0x3d, 0x35, 0x48, 0x42, 0x4b, 0x4a, 0x44, 0x52, 0x40, 0x48,
+ 0x4f, 0x49, 0x4f, 0x4c, 0x4d, 0x43, 0x49, 0x38, 0x4b, 0x42, 0x48, 0x42,
+ 0x45, 0x45, 0x54, 0x3a, 0x47, 0x47, 0x52, 0x45, 0x4a, 0x48, 0x47, 0x39,
+ 0x4d, 0x45, 0x54, 0x4b, 0x4e, 0x4f, 0x4e, 0x38, 0x4a, 0x4b, 0x48, 0x45,
+ 0x4e, 0x43, 0x4e, 0x4e, 0x46, 0x4e, 0x4e, 0x50, 0x46, 0x4c, 0x42, 0x45,
+ 0x4b, 0x46, 0x47, 0x4d, 0x49, 0x3f, 0x4f, 0x50, 0x46, 0x4a, 0x47, 0x4e,
+ 0x4a, 0x3e, 0x50, 0x46, 0x47, 0x40, 0x4f, 0x47, 0x51, 0x4b, 0x43, 0x46,
+ 0x4a, 0x42, 0x55, 0x4d, 0x46, 0x63, 0x49, 0x4e, 0x4f, 0x4f, 0x42, 0x45,
+ 0x50, 0x57, 0x49, 0x3e, 0x57, 0x63, 0x45, 0x4a, 0x49, 0x50, 0x41, 0x4a,
+ 0x48, 0x64, 0x4f, 0x42, 0x47, 0x58, 0x4b, 0x45, 0x43, 0x57, 0x49, 0x58,
+ 0x51, 0x51, 0x47, 0x43, 0x51, 0x4b, 0x4a, 0x45, 0x50, 0x54, 0x4d, 0x4d,
+ 0x3e, 0x4a, 0x50, 0x40, 0x51, 0x4f, 0x52, 0x48, 0x53, 0x49, 0x44, 0x4b,
+ 0x51, 0x4b, 0x50, 0x42, 0x4d, 0x49, 0x4a, 0x46, 0x44, 0x50, 0x47, 0x3f,
+ 0x48, 0x47, 0x41, 0x4a, 0x42, 0x52, 0x4a, 0x33, 0x50, 0x50, 0x54, 0x3f,
+ 0x44, 0x4e, 0x51, 0x3c, 0x4e, 0x51, 0x48, 0x4b, 0x47, 0x49, 0x3f, 0x3d,
+ 0x4e, 0x46, 0x4a, 0x41, 0x40, 0x50, 0x49, 0x40, 0x4a, 0x4b, 0x45, 0x50,
+ 0x4e, 0x4d, 0x4b, 0x39, 0x4e, 0x4b, 0x48, 0x3c, 0x47, 0x44, 0x4c, 0x42,
+ 0x45, 0x50, 0x3e, 0x54, 0x4d, 0x49, 0x48, 0x3c, 0x45, 0x42, 0x55, 0x4a,
+ 0x41, 0x4f, 0x40, 0x3f, 0x47, 0x46, 0x46, 0x44, 0x4f, 0x47, 0x46, 0x44,
+ 0x41, 0x40, 0x44, 0x48, 0x3e, 0x3c, 0x46, 0x3e, 0x4a, 0x45, 0x4c, 0x52,
+ 0x47, 0x42, 0x47, 0x3f, 0x47, 0x4e, 0x4b, 0x53, 0x4a, 0x3d, 0x4d, 0x47,
+ 0x4f, 0x3d, 0x4e, 0x43, 0x4f, 0x46, 0x43, 0x43, 0x46, 0x41, 0x4f, 0x42,
+ 0x46, 0x57, 0x4d, 0x51, 0x49, 0x51, 0x4c, 0x44, 0x51, 0x4f, 0x46, 0x44,
+ 0x54, 0x5d, 0x4f, 0x40, 0x59, 0x46, 0x53, 0x46, 0x48, 0x54, 0x43, 0x45,
+ 0x4d, 0x51, 0x4f, 0x44, 0x44, 0x53, 0x49, 0x4e, 0x48, 0x46, 0x44, 0x4a,
+ 0x4a, 0x42, 0x4c, 0x46, 0x54, 0x4f, 0x52, 0x47, 0x46, 0x44, 0x4c, 0x4d,
+ 0x4c, 0x47, 0x4d, 0x40, 0x55, 0x58, 0x46, 0x46, 0x3f, 0x3e, 0x47, 0x36,
+ 0x3f, 0x4d, 0x4b, 0x4d, 0x4f, 0x4f, 0x48, 0x34, 0x4d, 0x46, 0x46, 0x50,
+ 0x50, 0x4b, 0x47, 0x45, 0x4e, 0x49, 0x50, 0x4f, 0x4a, 0x48, 0x4f, 0x39,
+ 0x53, 0x4c, 0x4b, 0x56, 0x45, 0x4f, 0x55, 0x3a, 0x40, 0x53, 0x43, 0x4b,
+ 0x47, 0x3d, 0x4c, 0x34, 0x4b, 0x4e, 0x4a, 0x4b, 0x4d, 0x49, 0x4e, 0x40,
+ 0x4d, 0x48, 0x40, 0x4a, 0x4a, 0x4b, 0x4a, 0x42, 0x4c, 0x52, 0x43, 0x42,
+ 0x44, 0x3f, 0x4e, 0x42, 0x44, 0x45, 0x40, 0x3d, 0x4b, 0x45, 0x4a, 0x43,
+ 0x4b, 0x4b, 0x4e, 0x46, 0x55, 0x43, 0x44, 0x3f, 0x44, 0x43, 0x4b, 0x4b,
+ 0x45, 0x51, 0x48, 0x49, 0x3d, 0x44, 0x4a, 0x4a, 0x50, 0x50, 0x47, 0x44,
+ 0x4f, 0x3e, 0x3f, 0x43, 0x4c, 0x46, 0x4a, 0x4e, 0x4c, 0x52, 0x48, 0x4e,
+ 0x48, 0x46, 0x45, 0x48, 0x41, 0x4f, 0x51, 0x48, 0x40, 0x4d, 0x4a, 0x4b,
+ 0x4c, 0x51, 0x49, 0x50, 0x4e, 0x4b, 0x4a, 0x42, 0x49, 0x54, 0x4e, 0x43,
+ 0x52, 0x47, 0x4a, 0x41, 0x42, 0x51, 0x48, 0x4a, 0x46, 0x45, 0x4a, 0x43,
+ 0x4e, 0x4f, 0x41, 0x49, 0x4b, 0x42, 0x40, 0x4a, 0x50, 0x41, 0x42, 0x3f,
+ 0x49, 0x4a, 0x40, 0x3e, 0x3f, 0x42, 0x4d, 0x51, 0x4e, 0x4e, 0x47, 0x41,
+ 0x4e, 0x4e, 0x49, 0x4b, 0x41, 0x45, 0x51, 0x40, 0x45, 0x4c, 0x3f, 0x42,
+ 0x4c, 0x45, 0x4d, 0x39, 0x46, 0x52, 0x4a, 0x4e, 0x4c, 0x49, 0x4e, 0x43,
+ 0x43, 0x4c, 0x48, 0x46, 0x48, 0x49, 0x50, 0x3a, 0x3f, 0x49, 0x42, 0x4f,
+ 0x42, 0x4d, 0x4e, 0x3f, 0x51, 0x4b, 0x4e, 0x4b, 0x51, 0x44, 0x43, 0x4a,
+ 0x4a, 0x4c, 0x50, 0x48, 0x45, 0x47, 0x4d, 0x41, 0x47, 0x45, 0x51, 0x41,
+ 0x42, 0x48, 0x4c, 0x39, 0x51, 0x45, 0x46, 0x53, 0x4b, 0x50, 0x46, 0x45,
+ 0x4b, 0x4d, 0x42, 0x4b, 0x3f, 0x45, 0x4b, 0x4e, 0x50, 0x50, 0x47, 0x4a,
+ 0x45, 0x40, 0x4b, 0x43, 0x3f, 0x4a, 0x41, 0x42, 0x51, 0x41, 0x4d, 0x42,
+ 0x53, 0x48, 0x48, 0x49, 0x4b, 0x40, 0x42, 0x3d, 0x4f, 0x53, 0x49, 0x46,
+ 0x46, 0x43, 0x42, 0x44, 0x46, 0x48, 0x3f, 0x46, 0x31, 0x43, 0x4d, 0x4b,
+ 0x48, 0x4d, 0x4c, 0x43, 0x45, 0x53, 0x50, 0x40, 0x4a, 0x48, 0x45, 0x3b,
+ 0x4f, 0x4d, 0x53, 0x4c, 0x44, 0x54, 0x50, 0x66, 0x3f, 0x45, 0x4c, 0x4c,
+ 0x4a, 0x49, 0x49, 0x4a, 0x40, 0x52, 0x3e, 0x4c, 0x49, 0x40, 0x44, 0x49,
+ 0x48, 0x3f, 0x45, 0x5b, 0x49, 0x4b, 0x4c, 0x44, 0x50, 0x4e, 0x4a, 0x4a,
+ 0x49, 0x4e, 0x4f, 0x47, 0x46, 0x4b, 0x44, 0x3b, 0x4e, 0x4b, 0x48, 0x46,
+ 0x45, 0x45, 0x3d, 0x35, 0x4c, 0x49, 0x54, 0x42, 0x51, 0x46, 0x49, 0x2d,
+ 0x43, 0x4a, 0x53, 0x49, 0x49, 0x42, 0x4f, 0x40, 0x4e, 0x50, 0x54, 0x51,
+ 0x4b, 0x45, 0x48, 0x35, 0x4d, 0x41, 0x51, 0x40, 0x41, 0x49, 0x4a, 0x3b,
+ 0x45, 0x50, 0x48, 0x51, 0x51, 0x4d, 0x4c, 0x36, 0x47, 0x4a, 0x44, 0x45,
+ 0x4d, 0x47, 0x43, 0x3a, 0x48, 0x40, 0x42, 0x4f, 0x4f, 0x4f, 0x4f, 0x43,
+ 0x4a, 0x41, 0x4b, 0x53, 0x43, 0x46, 0x4f, 0x39, 0x46, 0x4a, 0x4d, 0x53,
+ 0x41, 0x44, 0x4e, 0x44, 0x3f, 0x47, 0x4c, 0x4d, 0x4d, 0x43, 0x45, 0x3d,
+ 0x43, 0x4b, 0x3e, 0x48, 0x42, 0x4c, 0x47, 0x42, 0x42, 0x50, 0x49, 0x4b,
+ 0x43, 0x4e, 0x44, 0x44, 0x4c, 0x3d, 0x4c, 0x47, 0x4e, 0x42, 0x4b, 0x44,
+ 0x4b, 0x44, 0x3f, 0x49, 0x33, 0x46, 0x4a, 0x4a, 0x42, 0x57, 0x5e, 0x4a,
+ 0x46, 0x4f, 0x55, 0x3c, 0x4a, 0x4b, 0x4c, 0x43, 0x51, 0x59, 0x64, 0x51,
+ 0x45, 0x60, 0x4b, 0x65, 0x46, 0x4a, 0x4e, 0x49, 0x41, 0x4b, 0x50, 0x5c,
+ 0x48, 0x4b, 0x3e, 0x52, 0x4f, 0x2f, 0x4e, 0x4a, 0x45, 0x53, 0x48, 0x59,
+ 0x4c, 0x4e, 0x4a, 0x4d, 0x49, 0x40, 0x52, 0x44, 0x49, 0x46, 0x4e, 0x46,
+ 0x42, 0x4b, 0x4a, 0x4b, 0x4b, 0x4b, 0x4f, 0x52, 0x46, 0x50, 0x4d, 0x3d,
+ 0x46, 0x4b, 0x4b, 0x40, 0x4d, 0x3f, 0x43, 0x33, 0x4e, 0x53, 0x4b, 0x4a,
+ 0x45, 0x48, 0x4c, 0x2e, 0x48, 0x4f, 0x49, 0x42, 0x54, 0x4f, 0x4b, 0x2b,
+ 0x55, 0x4e, 0x43, 0x4d, 0x4d, 0x47, 0x42, 0x3e, 0x48, 0x48, 0x4d, 0x54,
+ 0x52, 0x4f, 0x43, 0x37, 0x4b, 0x42, 0x4b, 0x4e, 0x49, 0x49, 0x4b, 0x2e,
+ 0x45, 0x4e, 0x48, 0x4e, 0x44, 0x49, 0x48, 0x30, 0x4c, 0x4b, 0x3f, 0x42,
+ 0x4f, 0x4f, 0x4e, 0x38, 0x4f, 0x42, 0x54, 0x49, 0x41, 0x42, 0x45, 0x3a,
+ 0x47, 0x43, 0x43, 0x4b, 0x49, 0x40, 0x4d, 0x38, 0x52, 0x4c, 0x3d, 0x4d,
+ 0x43, 0x54, 0x4e, 0x41, 0x4a, 0x47, 0x44, 0x51, 0x47, 0x48, 0x41, 0x47,
+ 0x4d, 0x41, 0x46, 0x4c, 0x4d, 0x46, 0x51, 0x4a, 0x49, 0x46, 0x4a, 0x42,
+ 0x3a, 0x43, 0x4a, 0x4b, 0x43, 0x4c, 0x68, 0x44, 0x4b, 0x52, 0x50, 0x37,
+ 0x4d, 0x4c, 0x57, 0x4c, 0x68, 0x62, 0x64, 0x4a, 0x3e, 0x64, 0x4b, 0x66,
+ 0x48, 0x4d, 0x54, 0x57, 0x4b, 0x52, 0x49, 0x5c, 0x4d, 0x55, 0x51, 0x57,
+ 0x4c, 0x3a, 0x48, 0x43, 0x3b, 0x43, 0x52, 0x5d, 0x45, 0x4e, 0x51, 0x4d,
+ 0x4a, 0x55, 0x4e, 0x4c, 0x44, 0x51, 0x4c, 0x4f, 0x41, 0x4f, 0x4a, 0x43,
+ 0x53, 0x48, 0x47, 0x49, 0x46, 0x52, 0x48, 0x3e, 0x4b, 0x4e, 0x4a, 0x50,
+ 0x4f, 0x47, 0x3e, 0x2e, 0x4b, 0x51, 0x4a, 0x44, 0x4c, 0x49, 0x4f, 0x26,
+ 0x48, 0x4f, 0x44, 0x51, 0x48, 0x3f, 0x4c, 0x30, 0x4e, 0x48, 0x4d, 0x48,
+ 0x48, 0x44, 0x4b, 0x2f, 0x50, 0x41, 0x4d, 0x50, 0x52, 0x42, 0x45, 0x33,
+ 0x4c, 0x48, 0x48, 0x3d, 0x46, 0x41, 0x43, 0x38, 0x45, 0x4f, 0x48, 0x4b,
+ 0x41, 0x49, 0x4c, 0x2f, 0x53, 0x4c, 0x48, 0x4a, 0x47, 0x40, 0x4a, 0x31,
+ 0x52, 0x40, 0x49, 0x4c, 0x3f, 0x48, 0x48, 0x39, 0x48, 0x3f, 0x45, 0x43,
+ 0x40, 0x48, 0x3c, 0x40, 0x4c, 0x48, 0x48, 0x4d, 0x3e, 0x42, 0x4a, 0x3d,
+ 0x4c, 0x45, 0x44, 0x46, 0x44, 0x45, 0x4a, 0x47, 0x52, 0x48, 0x4a, 0x4d,
+ 0x3f, 0x49, 0x4c, 0x4c, 0x48, 0x44, 0x4c, 0x44, 0x3d, 0x41, 0x47, 0x45,
+ 0x43, 0x4a, 0x5a, 0x3f, 0x48, 0x5d, 0x50, 0x35, 0x47, 0x4f, 0x5b, 0x46,
+ 0x6e, 0x50, 0x6d, 0x44, 0x49, 0x6a, 0x53, 0x6b, 0x4b, 0x4b, 0x4f, 0x62,
+ 0x45, 0x57, 0x48, 0x5b, 0x40, 0x4b, 0x4f, 0x63, 0x48, 0x3a, 0x4b, 0x42,
+ 0x43, 0x53, 0x41, 0x5f, 0x54, 0x3e, 0x4d, 0x43, 0x3d, 0x4c, 0x46, 0x46,
+ 0x49, 0x56, 0x4b, 0x45, 0x47, 0x45, 0x4e, 0x4f, 0x4c, 0x4d, 0x4f, 0x47,
+ 0x49, 0x4b, 0x51, 0x33, 0x4b, 0x45, 0x4d, 0x41, 0x51, 0x4a, 0x43, 0x2a,
+ 0x50, 0x4b, 0x4a, 0x4b, 0x4c, 0x52, 0x4c, 0x3b, 0x45, 0x4c, 0x51, 0x44,
+ 0x4c, 0x48, 0x43, 0x35, 0x51, 0x50, 0x48, 0x49, 0x3f, 0x48, 0x3d, 0x3b,
+ 0x52, 0x3f, 0x42, 0x4b, 0x49, 0x49, 0x47, 0x38, 0x4a, 0x4a, 0x41, 0x52,
+ 0x41, 0x3e, 0x4b, 0x2f, 0x46, 0x4d, 0x49, 0x44, 0x46, 0x3b, 0x47, 0x36,
+ 0x46, 0x3f, 0x49, 0x48, 0x47, 0x42, 0x42, 0x35, 0x44, 0x4b, 0x4d, 0x56,
+ 0x50, 0x49, 0x43, 0x42, 0x4b, 0x3e, 0x53, 0x44, 0x4a, 0x43, 0x47, 0x38,
+ 0x4a, 0x45, 0x4d, 0x3f, 0x46, 0x4a, 0x47, 0x3a, 0x4c, 0x3e, 0x47, 0x45,
+ 0x46, 0x4b, 0x45, 0x49, 0x4a, 0x4b, 0x54, 0x49, 0x4a, 0x53, 0x4a, 0x4c,
+ 0x45, 0x48, 0x53, 0x42, 0x4b, 0x47, 0x4e, 0x50, 0x3d, 0x51, 0x60, 0x3e,
+ 0x53, 0x5d, 0x51, 0x30, 0x45, 0x50, 0x59, 0x4e, 0x62, 0x52, 0x68, 0x51,
+ 0x45, 0x6c, 0x4c, 0x64, 0x4d, 0x47, 0x55, 0x61, 0x44, 0x57, 0x44, 0x58,
+ 0x44, 0x4a, 0x53, 0x58, 0x47, 0x31, 0x3f, 0x4c, 0x43, 0x45, 0x48, 0x5e,
+ 0x41, 0x43, 0x3f, 0x43, 0x51, 0x46, 0x48, 0x4b, 0x4d, 0x5b, 0x45, 0x4b,
+ 0x48, 0x46, 0x3f, 0x45, 0x47, 0x45, 0x40, 0x4a, 0x51, 0x51, 0x3d, 0x3f,
+ 0x43, 0x45, 0x4d, 0x4a, 0x47, 0x50, 0x49, 0x32, 0x4c, 0x5a, 0x55, 0x4f,
+ 0x4c, 0x51, 0x43, 0x37, 0x40, 0x59, 0x49, 0x49, 0x4e, 0x4f, 0x47, 0x34,
+ 0x40, 0x4c, 0x4a, 0x41, 0x4a, 0x47, 0x4a, 0x42, 0x4e, 0x4a, 0x48, 0x4e,
+ 0x4e, 0x4e, 0x45, 0x39, 0x4e, 0x45, 0x45, 0x4e, 0x4c, 0x48, 0x4a, 0x35,
+ 0x45, 0x4c, 0x49, 0x4f, 0x51, 0x43, 0x3c, 0x3a, 0x4a, 0x4a, 0x46, 0x48,
+ 0x49, 0x42, 0x4e, 0x2f, 0x42, 0x4e, 0x45, 0x50, 0x51, 0x40, 0x45, 0x32,
+ 0x4a, 0x4d, 0x44, 0x4e, 0x48, 0x48, 0x47, 0x2f, 0x48, 0x4b, 0x49, 0x44,
+ 0x48, 0x4d, 0x46, 0x3b, 0x46, 0x4a, 0x41, 0x4e, 0x4e, 0x47, 0x54, 0x4b,
+ 0x45, 0x49, 0x45, 0x44, 0x45, 0x48, 0x4a, 0x46, 0x55, 0x49, 0x47, 0x49,
+ 0x4b, 0x42, 0x48, 0x4f, 0x3f, 0x52, 0x60, 0x39, 0x4b, 0x5e, 0x55, 0x2e,
+ 0x48, 0x50, 0x59, 0x4f, 0x68, 0x5f, 0x64, 0x4f, 0x3b, 0x71, 0x50, 0x63,
+ 0x4f, 0x50, 0x50, 0x6c, 0x4b, 0x55, 0x47, 0x5b, 0x4c, 0x40, 0x48, 0x59,
+ 0x4f, 0x2e, 0x4b, 0x4c, 0x4e, 0x4e, 0x46, 0x61, 0x50, 0x41, 0x4c, 0x4a,
+ 0x44, 0x3e, 0x3f, 0x47, 0x4b, 0x4f, 0x47, 0x4b, 0x47, 0x3d, 0x41, 0x49,
+ 0x49, 0x3f, 0x4d, 0x44, 0x4a, 0x4d, 0x45, 0x41, 0x4d, 0x43, 0x49, 0x3c,
+ 0x49, 0x57, 0x49, 0x3b, 0x49, 0x59, 0x3f, 0x4f, 0x4e, 0x49, 0x4e, 0x46,
+ 0x52, 0x4e, 0x4c, 0x54, 0x4a, 0x48, 0x48, 0x3a, 0x44, 0x4a, 0x4f, 0x4a,
+ 0x44, 0x4b, 0x43, 0x4d, 0x51, 0x42, 0x53, 0x4d, 0x52, 0x41, 0x4d, 0x43,
+ 0x4e, 0x54, 0x4b, 0x42, 0x4b, 0x3f, 0x53, 0x45, 0x3f, 0x4a, 0x45, 0x50,
+ 0x3f, 0x4c, 0x4f, 0x43, 0x46, 0x42, 0x4b, 0x4d, 0x4c, 0x3b, 0x48, 0x40,
+ 0x4e, 0x4e, 0x49, 0x46, 0x4d, 0x4d, 0x52, 0x40, 0x4e, 0x4f, 0x46, 0x4a,
+ 0x40, 0x4b, 0x4c, 0x40, 0x4f, 0x4a, 0x44, 0x41, 0x46, 0x3c, 0x40, 0x3d,
+ 0x44, 0x48, 0x4a, 0x50, 0x46, 0x53, 0x46, 0x40, 0x44, 0x3e, 0x47, 0x43,
+ 0x48, 0x3d, 0x4e, 0x3e, 0x48, 0x49, 0x4b, 0x49, 0x4c, 0x3e, 0x4c, 0x4a,
+ 0x46, 0x4e, 0x62, 0x3c, 0x59, 0x60, 0x51, 0x29, 0x47, 0x52, 0x59, 0x4c,
+ 0x67, 0x68, 0x68, 0x4e, 0x3b, 0x72, 0x4d, 0x68, 0x44, 0x4f, 0x53, 0x63,
+ 0x47, 0x5a, 0x45, 0x4f, 0x4b, 0x37, 0x43, 0x5b, 0x4b, 0x3d, 0x44, 0x41,
+ 0x4a, 0x4b, 0x3c, 0x64, 0x48, 0x38, 0x42, 0x3f, 0x48, 0x46, 0x4b, 0x46,
+ 0x46, 0x4f, 0x46, 0x46, 0x44, 0x3c, 0x4b, 0x4f, 0x4d, 0x4a, 0x4b, 0x46,
+ 0x4d, 0x4f, 0x4f, 0x3f, 0x3a, 0x4b, 0x55, 0x3c, 0x51, 0x56, 0x4d, 0x42,
+ 0x52, 0x5a, 0x3e, 0x4b, 0x54, 0x57, 0x4e, 0x4d, 0x4e, 0x5b, 0x4e, 0x49,
+ 0x4e, 0x3c, 0x40, 0x41, 0x40, 0x4d, 0x48, 0x42, 0x49, 0x4e, 0x4f, 0x47,
+ 0x47, 0x48, 0x50, 0x49, 0x51, 0x46, 0x44, 0x45, 0x49, 0x46, 0x43, 0x48,
+ 0x48, 0x49, 0x4d, 0x4c, 0x45, 0x4f, 0x4c, 0x45, 0x44, 0x40, 0x49, 0x45,
+ 0x49, 0x51, 0x4b, 0x4b, 0x50, 0x4b, 0x48, 0x3d, 0x4e, 0x52, 0x4a, 0x47,
+ 0x49, 0x41, 0x55, 0x3d, 0x48, 0x4d, 0x49, 0x48, 0x4e, 0x4c, 0x48, 0x3d,
+ 0x3f, 0x4c, 0x4e, 0x53, 0x3e, 0x48, 0x4a, 0x3f, 0x54, 0x4d, 0x54, 0x4b,
+ 0x47, 0x4e, 0x44, 0x48, 0x49, 0x4b, 0x4c, 0x49, 0x4d, 0x42, 0x52, 0x4b,
+ 0x40, 0x3e, 0x54, 0x49, 0x55, 0x45, 0x47, 0x4d, 0x45, 0x5c, 0x60, 0x40,
+ 0x57, 0x60, 0x5b, 0x27, 0x4a, 0x5a, 0x64, 0x53, 0x6a, 0x5a, 0x5f, 0x52,
+ 0x3a, 0x72, 0x4b, 0x5f, 0x45, 0x56, 0x5f, 0x5f, 0x54, 0x5f, 0x39, 0x52,
+ 0x51, 0x3e, 0x3b, 0x5a, 0x44, 0x32, 0x46, 0x50, 0x3a, 0x4f, 0x44, 0x5d,
+ 0x4c, 0x41, 0x39, 0x3f, 0x45, 0x46, 0x3b, 0x43, 0x46, 0x51, 0x3c, 0x4c,
+ 0x4b, 0x43, 0x4b, 0x51, 0x43, 0x48, 0x4d, 0x43, 0x38, 0x46, 0x46, 0x43,
+ 0x44, 0x4a, 0x46, 0x49, 0x48, 0x50, 0x4e, 0x4a, 0x4e, 0x58, 0x4a, 0x49,
+ 0x48, 0x4f, 0x4a, 0x49, 0x41, 0x57, 0x51, 0x50, 0x4b, 0x48, 0x47, 0x4b,
+ 0x53, 0x3d, 0x4b, 0x4c, 0x4b, 0x4b, 0x55, 0x56, 0x45, 0x49, 0x46, 0x4c,
+ 0x45, 0x51, 0x47, 0x50, 0x40, 0x4b, 0x4f, 0x4b, 0x4d, 0x4a, 0x4f, 0x50,
+ 0x49, 0x53, 0x50, 0x46, 0x40, 0x48, 0x4a, 0x4a, 0x49, 0x4a, 0x42, 0x45,
+ 0x4b, 0x45, 0x42, 0x45, 0x4e, 0x4e, 0x44, 0x41, 0x4b, 0x4a, 0x49, 0x3f,
+ 0x41, 0x51, 0x48, 0x4c, 0x40, 0x41, 0x51, 0x42, 0x49, 0x49, 0x48, 0x42,
+ 0x48, 0x4c, 0x4b, 0x3c, 0x49, 0x45, 0x42, 0x49, 0x4c, 0x46, 0x45, 0x43,
+ 0x43, 0x48, 0x48, 0x41, 0x43, 0x42, 0x4c, 0x4b, 0x40, 0x45, 0x44, 0x46,
+ 0x4c, 0x4b, 0x4e, 0x4d, 0x3f, 0x59, 0x55, 0x41, 0x56, 0x5a, 0x51, 0x30,
+ 0x49, 0x5a, 0x63, 0x4d, 0x61, 0x5b, 0x64, 0x55, 0x34, 0x7a, 0x4c, 0x62,
+ 0x3e, 0x5d, 0x56, 0x60, 0x48, 0x61, 0x3f, 0x54, 0x46, 0x40, 0x42, 0x56,
+ 0x52, 0x35, 0x4c, 0x59, 0x45, 0x4c, 0x42, 0x60, 0x49, 0x3f, 0x4c, 0x3c,
+ 0x52, 0x36, 0x46, 0x3d, 0x58, 0x4b, 0x41, 0x48, 0x3e, 0x45, 0x4e, 0x54,
+ 0x4c, 0x56, 0x47, 0x44, 0x39, 0x4a, 0x4a, 0x4a, 0x46, 0x48, 0x4a, 0x48,
+ 0x51, 0x4f, 0x4b, 0x49, 0x45, 0x4b, 0x44, 0x4c, 0x3e, 0x4c, 0x42, 0x59,
+ 0x47, 0x55, 0x47, 0x47, 0x41, 0x44, 0x44, 0x4a, 0x44, 0x4b, 0x44, 0x46,
+ 0x49, 0x5a, 0x48, 0x5d, 0x4f, 0x4a, 0x47, 0x50, 0x48, 0x4e, 0x44, 0x57,
+ 0x49, 0x46, 0x42, 0x4d, 0x3d, 0x4a, 0x4a, 0x58, 0x41, 0x4d, 0x3c, 0x47,
+ 0x42, 0x4e, 0x4d, 0x49, 0x44, 0x4b, 0x4c, 0x4b, 0x53, 0x42, 0x4a, 0x46,
+ 0x4e, 0x56, 0x4b, 0x47, 0x50, 0x43, 0x4f, 0x48, 0x49, 0x50, 0x48, 0x50,
+ 0x42, 0x4c, 0x4e, 0x3c, 0x41, 0x4f, 0x4a, 0x41, 0x44, 0x47, 0x4c, 0x42,
+ 0x51, 0x4f, 0x53, 0x46, 0x4c, 0x4b, 0x48, 0x51, 0x47, 0x4b, 0x4c, 0x4d,
+ 0x4d, 0x49, 0x3d, 0x44, 0x4b, 0x42, 0x43, 0x49, 0x51, 0x47, 0x4c, 0x4b,
+ 0x4a, 0x50, 0x5b, 0x43, 0x5b, 0x68, 0x54, 0x31, 0x4c, 0x5d, 0x5c, 0x54,
+ 0x63, 0x5a, 0x61, 0x54, 0x3d, 0x7a, 0x51, 0x5b, 0x40, 0x59, 0x5a, 0x62,
+ 0x4c, 0x5e, 0x42, 0x58, 0x49, 0x3c, 0x38, 0x50, 0x54, 0x37, 0x42, 0x51,
+ 0x4d, 0x4f, 0x42, 0x68, 0x4a, 0x40, 0x4e, 0x40, 0x3f, 0x3e, 0x3f, 0x40,
+ 0x54, 0x52, 0x3e, 0x43, 0x46, 0x4a, 0x48, 0x51, 0x4e, 0x4d, 0x42, 0x47,
+ 0x3f, 0x51, 0x47, 0x44, 0x3f, 0x4c, 0x46, 0x47, 0x4f, 0x55, 0x4b, 0x4e,
+ 0x4c, 0x51, 0x40, 0x51, 0x47, 0x4a, 0x44, 0x5c, 0x48, 0x54, 0x4b, 0x46,
+ 0x49, 0x4b, 0x53, 0x59, 0x43, 0x3e, 0x45, 0x4e, 0x4f, 0x58, 0x4b, 0x64,
+ 0x41, 0x4b, 0x45, 0x4a, 0x4c, 0x51, 0x47, 0x57, 0x45, 0x46, 0x43, 0x4f,
+ 0x4d, 0x4d, 0x49, 0x58, 0x4b, 0x52, 0x43, 0x4b, 0x45, 0x4c, 0x50, 0x4c,
+ 0x4e, 0x4b, 0x40, 0x4c, 0x44, 0x4e, 0x4c, 0x47, 0x41, 0x55, 0x45, 0x4a,
+ 0x4c, 0x48, 0x46, 0x41, 0x47, 0x52, 0x44, 0x4f, 0x48, 0x49, 0x4b, 0x47,
+ 0x50, 0x4f, 0x42, 0x4a, 0x44, 0x4b, 0x52, 0x43, 0x45, 0x4e, 0x46, 0x49,
+ 0x45, 0x52, 0x51, 0x45, 0x44, 0x41, 0x4c, 0x46, 0x4c, 0x4b, 0x44, 0x4d,
+ 0x4f, 0x48, 0x44, 0x4d, 0x56, 0x48, 0x50, 0x4f, 0x3b, 0x4e, 0x55, 0x43,
+ 0x52, 0x62, 0x57, 0x2c, 0x4d, 0x5e, 0x5e, 0x50, 0x64, 0x5b, 0x6a, 0x55,
+ 0x39, 0x7d, 0x4b, 0x5e, 0x43, 0x54, 0x5d, 0x5c, 0x4d, 0x5c, 0x42, 0x51,
+ 0x4c, 0x3d, 0x46, 0x51, 0x4c, 0x2a, 0x3e, 0x54, 0x47, 0x48, 0x46, 0x64,
+ 0x42, 0x3d, 0x47, 0x3f, 0x42, 0x45, 0x49, 0x3b, 0x59, 0x50, 0x4c, 0x46,
+ 0x4d, 0x44, 0x47, 0x4d, 0x4a, 0x50, 0x41, 0x48, 0x43, 0x50, 0x3e, 0x44,
+ 0x4b, 0x53, 0x48, 0x49, 0x51, 0x51, 0x4d, 0x57, 0x49, 0x4f, 0x53, 0x50,
+ 0x46, 0x4f, 0x41, 0x5d, 0x47, 0x46, 0x49, 0x51, 0x45, 0x41, 0x4a, 0x56,
+ 0x4f, 0x4e, 0x4d, 0x4a, 0x3e, 0x55, 0x47, 0x65, 0x48, 0x51, 0x4d, 0x4e,
+ 0x46, 0x43, 0x48, 0x5b, 0x48, 0x4f, 0x4f, 0x48, 0x4b, 0x4d, 0x4e, 0x5c,
+ 0x4f, 0x4c, 0x54, 0x48, 0x4a, 0x4d, 0x4e, 0x4e, 0x44, 0x48, 0x43, 0x52,
+ 0x41, 0x52, 0x48, 0x4f, 0x46, 0x4f, 0x51, 0x41, 0x44, 0x45, 0x41, 0x4b,
+ 0x43, 0x4e, 0x4e, 0x42, 0x48, 0x41, 0x45, 0x43, 0x44, 0x43, 0x4c, 0x4c,
+ 0x51, 0x54, 0x4c, 0x32, 0x46, 0x52, 0x4e, 0x49, 0x40, 0x4d, 0x43, 0x4f,
+ 0x4a, 0x4d, 0x4d, 0x49, 0x46, 0x4c, 0x41, 0x4d, 0x41, 0x3a, 0x50, 0x4c,
+ 0x5a, 0x4e, 0x49, 0x53, 0x4d, 0x53, 0x53, 0x3d, 0x52, 0x64, 0x55, 0x2a,
+ 0x47, 0x5d, 0x61, 0x51, 0x5b, 0x5d, 0x66, 0x52, 0x3f, 0xfd, 0x55, 0x5a,
+ 0x4b, 0x54, 0x5b, 0x60, 0x49, 0x5d, 0x43, 0x57, 0x47, 0x41, 0x45, 0x5e,
+ 0x4c, 0x28, 0x3e, 0x40, 0x49, 0x4e, 0x40, 0x69, 0x4a, 0x44, 0x45, 0x43,
+ 0x45, 0x3d, 0x39, 0x40, 0x4c, 0x53, 0x4b, 0x3d, 0x4e, 0x43, 0x48, 0x55,
+ 0x4d, 0x50, 0x4d, 0x49, 0x4f, 0x48, 0x3e, 0x46, 0x47, 0x56, 0x40, 0x48,
+ 0x46, 0x53, 0x50, 0x5d, 0x43, 0x54, 0x49, 0x47, 0x49, 0x4c, 0x48, 0x5d,
+ 0x49, 0x51, 0x50, 0x3d, 0x41, 0x47, 0x48, 0x64, 0x4b, 0x44, 0x49, 0x41,
+ 0x54, 0x48, 0x3d, 0x6b, 0x4c, 0x5a, 0x48, 0x4e, 0x40, 0x4c, 0x52, 0x5f,
+ 0x54, 0x4a, 0x3f, 0x48, 0x43, 0x43, 0x44, 0x66, 0x49, 0x47, 0x43, 0x46,
+ 0x47, 0x54, 0x42, 0x54, 0x4b, 0x4e, 0x49, 0x49, 0x49, 0x4b, 0x52, 0x4f,
+ 0x43, 0x46, 0x4b, 0x49, 0x54, 0x4b, 0x40, 0x48, 0x47, 0x4a, 0x46, 0x47,
+ 0x44, 0x47, 0x4c, 0x37, 0x3f, 0x49, 0x45, 0x44, 0x50, 0x49, 0x44, 0x36,
+ 0x4d, 0x40, 0x45, 0x49, 0x53, 0x55, 0x44, 0x42, 0x47, 0x48, 0x46, 0x40,
+ 0x4f, 0x4c, 0x41, 0x42, 0x52, 0x3a, 0x43, 0x46, 0x55, 0x51, 0x4e, 0x4f,
+ 0x48, 0x51, 0x55, 0x48, 0x52, 0x66, 0x4e, 0x33, 0x49, 0x5b, 0x5f, 0x4b,
+ 0x5f, 0x5b, 0x66, 0x52, 0x41, 0x7c, 0x4a, 0x59, 0x47, 0x59, 0x58, 0x67,
+ 0x49, 0x5e, 0x44, 0x57, 0x49, 0x4c, 0x43, 0x56, 0x41, 0x27, 0x4c, 0x44,
+ 0x51, 0x44, 0x42, 0x65, 0x49, 0x44, 0x40, 0x3d, 0x4d, 0x3e, 0x4c, 0x3c,
+ 0x4f, 0x4b, 0x45, 0x44, 0x4d, 0x48, 0x47, 0x54, 0x4d, 0x4e, 0x44, 0x42,
+ 0x47, 0x44, 0x3d, 0x49, 0x4e, 0x50, 0x49, 0x45, 0x58, 0x4a, 0x54, 0x5c,
+ 0x41, 0x49, 0x4f, 0x42, 0x44, 0x4f, 0x4a, 0x62, 0x48, 0x50, 0x48, 0x43,
+ 0x51, 0x53, 0x47, 0x6c, 0x40, 0x46, 0x3d, 0x46, 0x4a, 0x50, 0x43, 0x69,
+ 0x49, 0x4f, 0x4a, 0x4c, 0x49, 0x46, 0x43, 0x6a, 0x48, 0x50, 0x49, 0x48,
+ 0x48, 0x51, 0x4b, 0x65, 0x42, 0x4b, 0x4d, 0x48, 0x44, 0x4e, 0x49, 0x60,
+ 0x44, 0x52, 0x42, 0x42, 0x47, 0x48, 0x4b, 0x51, 0x50, 0x4b, 0x3c, 0x4d,
+ 0x4c, 0x44, 0x48, 0x55, 0x51, 0x4c, 0x55, 0x4e, 0x52, 0x4c, 0x4b, 0x39,
+ 0x48, 0x42, 0x49, 0x49, 0x49, 0x50, 0x49, 0x32, 0x4e, 0x4b, 0x45, 0x4f,
+ 0x42, 0x4b, 0x47, 0x50, 0x48, 0x45, 0x54, 0x49, 0x4c, 0x46, 0x40, 0x46,
+ 0x43, 0x3d, 0x51, 0x44, 0x53, 0x4f, 0x54, 0x55, 0x43, 0x4f, 0x5b, 0x47,
+ 0x53, 0x6c, 0x57, 0x2e, 0x50, 0x55, 0x5a, 0x4d, 0x57, 0x5d, 0x70, 0x50,
+ 0x3f, 0x79, 0x4a, 0x5a, 0x4c, 0x58, 0x59, 0x63, 0x45, 0x69, 0x48, 0x58,
+ 0x42, 0x4b, 0x43, 0x5c, 0x46, 0x28, 0x48, 0x49, 0x4c, 0x3f, 0x45, 0x58,
+ 0x45, 0x44, 0x47, 0x40, 0x4c, 0x42, 0x3e, 0x37, 0x45, 0x54, 0x48, 0x3b,
+ 0x4e, 0x48, 0x43, 0x4a, 0x50, 0x4a, 0x49, 0x46, 0x4c, 0x54, 0x3f, 0x4b,
+ 0x4e, 0x56, 0x48, 0x49, 0x49, 0x4c, 0x51, 0x5f, 0x4d, 0x4b, 0x43, 0x4d,
+ 0x47, 0x51, 0x43, 0x59, 0x45, 0x4e, 0x4f, 0x45, 0x44, 0x54, 0x44, 0x6d,
+ 0x47, 0x51, 0x43, 0x4e, 0x4c, 0x4f, 0x43, 0x6d, 0x48, 0x53, 0x4b, 0x47,
+ 0x49, 0x48, 0x46, 0x6a, 0x51, 0x4c, 0x4d, 0x45, 0x4e, 0x47, 0x46, 0x62,
+ 0x4a, 0x54, 0x51, 0x4c, 0x47, 0x4d, 0x4a, 0x61, 0x3d, 0x50, 0x4c, 0x4c,
+ 0x45, 0x3f, 0x3e, 0x54, 0x3d, 0x53, 0x48, 0x47, 0x52, 0x4b, 0x47, 0x51,
+ 0x4f, 0x45, 0x4b, 0x4a, 0x4c, 0x46, 0x44, 0x37, 0x42, 0x50, 0x49, 0x4f,
+ 0x51, 0x41, 0x44, 0x38, 0x54, 0x40, 0x51, 0x52, 0x3e, 0x43, 0x44, 0x47,
+ 0x49, 0x4b, 0x4b, 0x46, 0x53, 0x54, 0x55, 0x4b, 0x4a, 0x37, 0x43, 0x4a,
+ 0x51, 0x47, 0x51, 0x54, 0x43, 0x46, 0x56, 0x3d, 0x54, 0x66, 0x4f, 0x30,
+ 0x45, 0x52, 0x5a, 0x43, 0x5c, 0x65, 0x5d, 0x52, 0x32, 0x77, 0x53, 0x5f,
+ 0x4a, 0x5a, 0x4f, 0x5e, 0x4e, 0x61, 0x4b, 0x5b, 0x4a, 0x53, 0x3e, 0x61,
+ 0x47, 0x24, 0x3e, 0x48, 0x4d, 0x43, 0x40, 0x53, 0x4e, 0x41, 0x43, 0x3d,
+ 0x50, 0x49, 0x41, 0x3a, 0x4e, 0x4b, 0x48, 0x49, 0x48, 0x49, 0x46, 0x50,
+ 0x4f, 0x4b, 0x47, 0x4b, 0x48, 0x52, 0x3e, 0x4d, 0x4d, 0x59, 0x4c, 0x3e,
+ 0x52, 0x49, 0x4f, 0x5e, 0x54, 0x59, 0x47, 0x4d, 0x40, 0x4c, 0x4b, 0x64,
+ 0x42, 0x4c, 0x53, 0x46, 0x4e, 0x50, 0x46, 0x6a, 0x41, 0x59, 0x44, 0x4b,
+ 0x4f, 0x44, 0x52, 0x6c, 0x54, 0x4e, 0x46, 0x48, 0x42, 0x3d, 0x44, 0x67,
+ 0x44, 0x4f, 0x47, 0x54, 0x4c, 0x4f, 0x43, 0x61, 0x4c, 0x54, 0x4f, 0x43,
+ 0x49, 0x40, 0x4a, 0x5f, 0x4a, 0x52, 0x47, 0x43, 0x4c, 0x43, 0x49, 0x53,
+ 0x4c, 0x4b, 0x43, 0x3d, 0x4e, 0x45, 0x49, 0x50, 0x44, 0x53, 0x4f, 0x48,
+ 0x4b, 0x46, 0x44, 0x3c, 0x50, 0x42, 0x43, 0x40, 0x47, 0x43, 0x42, 0x34,
+ 0x47, 0x42, 0x3f, 0x4a, 0x48, 0x42, 0x48, 0x4c, 0x42, 0x4c, 0x4e, 0x47,
+ 0x48, 0x47, 0x51, 0x51, 0x4d, 0x3d, 0x3e, 0x4b, 0x54, 0x4c, 0x4c, 0x59,
+ 0x4f, 0x50, 0x57, 0x3c, 0x54, 0x62, 0x54, 0x35, 0x3d, 0x5a, 0x5b, 0x47,
+ 0x59, 0x63, 0x66, 0x4d, 0x3c, 0x79, 0x50, 0x5f, 0x45, 0x58, 0x4e, 0x5d,
+ 0x48, 0x61, 0x43, 0x54, 0x47, 0x54, 0x4d, 0x54, 0x4b, 0x25, 0x41, 0x44,
+ 0x4c, 0x4a, 0x3b, 0x52, 0x47, 0x3c, 0x45, 0x3c, 0x53, 0x44, 0x44, 0x40,
+ 0x50, 0x4c, 0x45, 0x3a, 0x4c, 0x51, 0x44, 0x49, 0x4d, 0x52, 0x4d, 0x4b,
+ 0x45, 0x52, 0x3d, 0x50, 0x4a, 0x58, 0x4a, 0x47, 0x4d, 0x47, 0x4e, 0x52,
+ 0x4f, 0x4d, 0x4f, 0x49, 0x52, 0x52, 0x4c, 0x5e, 0x47, 0x4d, 0x46, 0x4d,
+ 0x4c, 0x48, 0x50, 0x70, 0x41, 0x4a, 0x48, 0x3d, 0x45, 0x48, 0x45, 0x74,
+ 0x47, 0x4c, 0x43, 0x4f, 0x4a, 0x4a, 0x40, 0x68, 0x52, 0x49, 0x3e, 0x3e,
+ 0x4e, 0x4b, 0x4b, 0x69, 0x42, 0x4f, 0x45, 0x47, 0x3f, 0x45, 0x46, 0x56,
+ 0x45, 0x4a, 0x47, 0x44, 0x52, 0x4b, 0x53, 0x4e, 0x4e, 0x46, 0x45, 0x40,
+ 0x47, 0x4b, 0x53, 0x52, 0x53, 0x51, 0x4f, 0x46, 0x42, 0x43, 0x50, 0x3e,
+ 0x48, 0x4e, 0x41, 0x53, 0x4d, 0x48, 0x48, 0x33, 0x40, 0x43, 0x4b, 0x42,
+ 0x52, 0x4c, 0x42, 0x4e, 0x41, 0x4e, 0x4f, 0x50, 0x43, 0x49, 0x4d, 0x47,
+ 0x4a, 0x3a, 0x3f, 0x51, 0x51, 0x44, 0x4e, 0x54, 0x40, 0x55, 0x59, 0x3c,
+ 0x57, 0x67, 0x4e, 0x2e, 0x4c, 0x5b, 0x5b, 0x51, 0x58, 0x63, 0x62, 0x52,
+ 0x3c, 0x72, 0x51, 0x5a, 0x4e, 0x53, 0x4a, 0x5c, 0x51, 0x69, 0x42, 0x51,
+ 0x48, 0x54, 0x48, 0x57, 0x3e, 0x37, 0x3f, 0x4d, 0x4d, 0x4a, 0x35, 0x57,
+ 0x4e, 0x40, 0x45, 0x4a, 0x45, 0x4e, 0x49, 0x40, 0x49, 0x53, 0x51, 0x44,
+ 0x4a, 0x50, 0x4b, 0x4b, 0x50, 0x4f, 0x3e, 0x44, 0x45, 0x44, 0x4c, 0x51,
+ 0x47, 0x51, 0x46, 0x42, 0x48, 0x50, 0x49, 0x4d, 0x43, 0x54, 0x52, 0x4d,
+ 0x4e, 0x4f, 0x3f, 0x63, 0x54, 0x57, 0x41, 0x44, 0x4e, 0x50, 0x4e, 0x66,
+ 0x41, 0x53, 0x4b, 0x4d, 0x4e, 0x4f, 0x43, 0x6d, 0x4e, 0x51, 0x49, 0x4f,
+ 0x49, 0x4a, 0x4a, 0x6c, 0x4b, 0x4f, 0x3d, 0x47, 0x4d, 0x51, 0x3c, 0x66,
+ 0x4b, 0x56, 0x3e, 0x4c, 0x41, 0x46, 0x45, 0x68, 0x47, 0x4b, 0x4a, 0x54,
+ 0x53, 0x48, 0x51, 0x59, 0x45, 0x43, 0x50, 0x45, 0x4f, 0x45, 0x42, 0x55,
+ 0x48, 0x52, 0x4c, 0x46, 0x52, 0x49, 0x47, 0x3d, 0x55, 0x48, 0x52, 0x52,
+ 0x40, 0x4e, 0x47, 0x31, 0x45, 0x4f, 0x42, 0x4a, 0x4e, 0x50, 0x42, 0x4a,
+ 0x49, 0x57, 0x46, 0x4b, 0x45, 0x4e, 0x4d, 0x46, 0x47, 0x43, 0x50, 0x4e,
+ 0x4f, 0x4c, 0x53, 0x55, 0x45, 0x51, 0x5b, 0x3a, 0x52, 0x64, 0x54, 0x2d,
+ 0x42, 0x59, 0x59, 0x45, 0x59, 0x67, 0x69, 0x53, 0x3f, 0x78, 0x50, 0x60,
+ 0x4c, 0x4c, 0x5b, 0x53, 0x45, 0x63, 0x49, 0x63, 0x51, 0x4c, 0x41, 0x4e,
+ 0x4b, 0x37, 0x45, 0x4e, 0x48, 0x4c, 0x39, 0x55, 0x44, 0x37, 0x3c, 0x49,
+ 0x44, 0x56, 0x3e, 0x40, 0x4d, 0x45, 0x4c, 0x43, 0x42, 0x41, 0x40, 0x42,
+ 0x57, 0x4f, 0x43, 0x3f, 0x52, 0x53, 0x51, 0x4b, 0x4b, 0x55, 0x46, 0x40,
+ 0x49, 0x45, 0x40, 0x4f, 0x47, 0x58, 0x4b, 0x53, 0x4e, 0x52, 0x54, 0x5e,
+ 0x4b, 0x51, 0x50, 0x44, 0x50, 0x4b, 0x4f, 0x70, 0x49, 0x4f, 0x4c, 0x50,
+ 0x45, 0x56, 0x4b, 0x6b, 0x49, 0x52, 0x4a, 0x3f, 0x44, 0x4b, 0x48, 0x72,
+ 0x4c, 0x47, 0x4e, 0x43, 0x46, 0x4c, 0x4f, 0x61, 0x4a, 0x52, 0x52, 0x46,
+ 0x4a, 0x4d, 0x46, 0x65, 0x48, 0x4e, 0x4d, 0x4e, 0x46, 0x4e, 0x53, 0x59,
+ 0x43, 0x49, 0x43, 0x47, 0x45, 0x47, 0x53, 0x50, 0x3e, 0x4d, 0x41, 0x46,
+ 0x4c, 0x4a, 0x4c, 0x35, 0x3f, 0x4f, 0x50, 0x48, 0x47, 0x4d, 0x4c, 0x32,
+ 0x45, 0x53, 0x43, 0x4d, 0x4e, 0x4a, 0x3e, 0x4b, 0x55, 0x4f, 0x53, 0x4c,
+ 0x4a, 0x4d, 0x48, 0x53, 0x4f, 0x3a, 0x47, 0x4b, 0x4e, 0x4e, 0x51, 0x59,
+ 0x41, 0x50, 0x57, 0x38, 0x5d, 0x63, 0x59, 0x2b, 0x45, 0x53, 0x5a, 0x4e,
+ 0x5c, 0x60, 0x5e, 0x4c, 0x41, 0x6f, 0x53, 0x5c, 0x48, 0x53, 0x56, 0x54,
+ 0x4b, 0x62, 0x46, 0x63, 0x47, 0x4e, 0x40, 0x51, 0x43, 0x36, 0x44, 0x42,
+ 0x46, 0x51, 0x41, 0x54, 0x4e, 0x36, 0x40, 0x4b, 0x55, 0x49, 0x40, 0x3f,
+ 0x4b, 0x42, 0x4a, 0x4a, 0x48, 0x47, 0x40, 0x43, 0x4d, 0x4f, 0x55, 0x3f,
+ 0x53, 0x42, 0x4d, 0x56, 0x49, 0x51, 0x4f, 0x41, 0x3b, 0x48, 0x43, 0x4e,
+ 0x4b, 0x5c, 0x4f, 0x45, 0x4a, 0x4c, 0x46, 0x66, 0x43, 0x45, 0x46, 0x48,
+ 0x4f, 0x4e, 0x40, 0x71, 0x4b, 0x4e, 0x3e, 0x42, 0x4d, 0x52, 0x42, 0x71,
+ 0x4c, 0x54, 0x4f, 0x3f, 0x4c, 0x43, 0x4a, 0x73, 0x48, 0x48, 0x4c, 0x4b,
+ 0x4c, 0x4d, 0x40, 0x72, 0x3e, 0x51, 0x49, 0x48, 0x52, 0x53, 0x45, 0x65,
+ 0x52, 0x4e, 0x4f, 0x44, 0x4c, 0x43, 0x4a, 0x5e, 0x3e, 0x56, 0x46, 0x55,
+ 0x55, 0x43, 0x49, 0x51, 0x4f, 0x52, 0x49, 0x4d, 0x46, 0x47, 0x49, 0x3e,
+ 0x51, 0x49, 0x41, 0x53, 0x42, 0x47, 0x46, 0x3b, 0x4d, 0x4e, 0x48, 0x44,
+ 0x42, 0x48, 0x4c, 0x47, 0x42, 0x4e, 0x4a, 0x3e, 0x44, 0x54, 0x4a, 0x4d,
+ 0x49, 0x41, 0x41, 0x53, 0x52, 0x4c, 0x4c, 0x56, 0x49, 0x4a, 0x5a, 0x3f,
+ 0x5b, 0x5c, 0x59, 0x2f, 0x49, 0x52, 0x5a, 0x4e, 0x5a, 0x61, 0x67, 0x4c,
+ 0x41, 0x6f, 0x5a, 0x5a, 0x40, 0x5a, 0x54, 0x4e, 0x49, 0x66, 0x45, 0x5a,
+ 0x4a, 0x45, 0x44, 0x4b, 0x44, 0x36, 0x41, 0x4c, 0x45, 0x44, 0x3d, 0x51,
+ 0x3f, 0x35, 0x3c, 0x46, 0x53, 0x5c, 0x3f, 0x3e, 0x50, 0x43, 0x46, 0x4b,
+ 0x40, 0x54, 0x41, 0x47, 0x4b, 0x51, 0x41, 0x46, 0x4a, 0x4d, 0x51, 0x52,
+ 0x43, 0x58, 0x45, 0x46, 0x4e, 0x46, 0x4a, 0x4b, 0x44, 0x54, 0x4c, 0x4c,
+ 0x43, 0x59, 0x48, 0x61, 0x4e, 0x4f, 0x4d, 0x4d, 0x4a, 0x52, 0x4c, 0x6e,
+ 0x49, 0x57, 0x48, 0x4d, 0x46, 0x46, 0x4d, 0x72, 0x4a, 0x4e, 0x47, 0x44,
+ 0x49, 0x4f, 0x48, 0x73, 0x42, 0x40, 0x4d, 0x44, 0x4d, 0x57, 0x3e, 0x69,
+ 0x50, 0x52, 0x4c, 0x55, 0x46, 0x4c, 0x44, 0x5f, 0x4b, 0x4d, 0x55, 0x4c,
+ 0x48, 0x49, 0x4a, 0x5e, 0x47, 0x4b, 0x45, 0x53, 0x55, 0x53, 0x4d, 0x53,
+ 0x47, 0x5c, 0x45, 0x4e, 0x4e, 0x52, 0x4c, 0x39, 0x4b, 0x4c, 0x49, 0x46,
+ 0x4a, 0x4e, 0x4b, 0x33, 0x46, 0x47, 0x52, 0x41, 0x49, 0x4b, 0x4c, 0x48,
+ 0x51, 0x53, 0x44, 0x4c, 0x4a, 0x45, 0x46, 0x49, 0x49, 0x4b, 0x50, 0x47,
+ 0x4d, 0x4b, 0x4c, 0x4f, 0x44, 0x45, 0x58, 0x3c, 0x56, 0x5a, 0x56, 0x23,
+ 0x4f, 0x4d, 0x5c, 0x4e, 0x59, 0x5a, 0x65, 0x43, 0x45, 0x66, 0x54, 0x5f,
+ 0x45, 0x5e, 0x54, 0x4f, 0x48, 0x5f, 0x44, 0x59, 0x48, 0x46, 0x47, 0x49,
+ 0x4d, 0x3c, 0x49, 0x54, 0x3e, 0x48, 0x43, 0x5b, 0x4a, 0x35, 0x41, 0x43,
+ 0x4b, 0x55, 0x43, 0x38, 0x46, 0x42, 0x4a, 0x4e, 0x54, 0x4b, 0x4d, 0x46,
+ 0x43, 0x4e, 0x44, 0x47, 0x56, 0x4c, 0x51, 0x57, 0x41, 0x4d, 0x43, 0x41,
+ 0x51, 0x47, 0x41, 0x51, 0x51, 0x4f, 0x46, 0x50, 0x52, 0x4e, 0x4d, 0x60,
+ 0x41, 0x49, 0x46, 0x50, 0x48, 0x56, 0x42, 0x6d, 0x40, 0x45, 0x44, 0x55,
+ 0x40, 0x4e, 0x40, 0x7c, 0x47, 0x5a, 0x44, 0x44, 0x45, 0x56, 0x55, 0x71,
+ 0x47, 0x4b, 0x4b, 0x45, 0x4f, 0x54, 0x4c, 0x73, 0x48, 0x55, 0x44, 0x4d,
+ 0x4a, 0x47, 0x49, 0x5e, 0x4d, 0x52, 0x4e, 0x4c, 0x48, 0x52, 0x48, 0x58,
+ 0x4c, 0x5a, 0x49, 0x4b, 0x53, 0x46, 0x4d, 0x4b, 0x48, 0x53, 0x41, 0x49,
+ 0x4a, 0x56, 0x51, 0x3a, 0x4c, 0x4e, 0x4f, 0x51, 0x4c, 0x59, 0x47, 0x45,
+ 0x4f, 0x50, 0x4a, 0x4f, 0x4d, 0x3f, 0x44, 0x4e, 0x42, 0x4a, 0x4a, 0x43,
+ 0x46, 0x4e, 0x4c, 0x4f, 0x47, 0x47, 0x4c, 0x4b, 0x52, 0x50, 0x50, 0x4b,
+ 0x42, 0x45, 0x54, 0x44, 0x54, 0x59, 0x4c, 0x2b, 0x4d, 0x4c, 0x55, 0x4e,
+ 0x5c, 0x5b, 0x5a, 0x42, 0x47, 0x5e, 0x56, 0x59, 0x47, 0x65, 0x55, 0x4c,
+ 0x4c, 0x59, 0x42, 0x5a, 0x4e, 0x46, 0x4e, 0x4b, 0x53, 0x46, 0x49, 0x56,
+ 0x48, 0x58, 0x4b, 0x4f, 0x45, 0x38, 0x40, 0x44, 0x49, 0x51, 0x4a, 0x3b,
+ 0x53, 0x40, 0x40, 0x48, 0x51, 0x49, 0x44, 0x46, 0x52, 0x4b, 0x4e, 0x45,
+ 0x48, 0x5a, 0x4e, 0x57, 0x44, 0x53, 0x49, 0x40, 0x4c, 0x47, 0x41, 0x4f,
+ 0x49, 0x55, 0x46, 0x50, 0x57, 0x5b, 0x48, 0x66, 0x50, 0x49, 0x51, 0x55,
+ 0x55, 0x4f, 0x47, 0x72, 0x49, 0x4f, 0x41, 0x4c, 0x49, 0x42, 0x48, 0x75,
+ 0x4a, 0x55, 0x45, 0x4a, 0x41, 0x51, 0x41, 0x70, 0x47, 0x49, 0x42, 0x52,
+ 0x4f, 0x47, 0x46, 0x63, 0x4f, 0x53, 0x46, 0x4f, 0x49, 0x53, 0x52, 0x63,
+ 0x4c, 0x59, 0x46, 0x41, 0x49, 0x51, 0x3e, 0x53, 0x45, 0x52, 0x51, 0x40,
+ 0x4f, 0x4c, 0x41, 0x4c, 0x47, 0x4a, 0x46, 0x47, 0x53, 0x47, 0x48, 0x39,
+ 0x53, 0x4b, 0x46, 0x4b, 0x50, 0x4c, 0x41, 0x40, 0x48, 0x4e, 0x49, 0x4e,
+ 0x44, 0x53, 0x44, 0x4e, 0x53, 0x49, 0x49, 0x4e, 0x46, 0x3f, 0x45, 0x42,
+ 0x4c, 0x47, 0x42, 0x4e, 0x49, 0x4a, 0x49, 0x44, 0x51, 0x48, 0x57, 0x4c,
+ 0x4d, 0x60, 0x4e, 0x2d, 0x46, 0x4d, 0x58, 0x53, 0x5c, 0x56, 0x5e, 0x41,
+ 0x3e, 0x66, 0x53, 0x5b, 0x49, 0x59, 0x5a, 0x55, 0x4e, 0x59, 0x46, 0x4a,
+ 0x44, 0x42, 0x45, 0x3d, 0x4d, 0x45, 0x44, 0x4f, 0x4d, 0x53, 0x42, 0x5a,
+ 0x43, 0x3c, 0x48, 0x4f, 0x44, 0x59, 0x3f, 0x33, 0x45, 0x48, 0x43, 0x45,
+ 0x4d, 0x56, 0x48, 0x44, 0x3e, 0x48, 0x46, 0x4d, 0x44, 0x53, 0x46, 0x4e,
+ 0x45, 0x52, 0x40, 0x46, 0x4c, 0x50, 0x4e, 0x4b, 0x4d, 0x46, 0x48, 0x46,
+ 0x50, 0x52, 0x4e, 0x57, 0x3f, 0x4a, 0x49, 0x50, 0x53, 0x4e, 0x41, 0x66,
+ 0x49, 0x4f, 0x40, 0x4b, 0x50, 0x4c, 0x4a, 0x70, 0x42, 0x51, 0x41, 0x4c,
+ 0x50, 0x4f, 0x46, 0x60, 0x45, 0x47, 0x54, 0x4c, 0x49, 0x59, 0x52, 0x61,
+ 0x4a, 0x53, 0x52, 0x4f, 0x4b, 0x4c, 0x46, 0x56, 0x4b, 0x54, 0x4f, 0x47,
+ 0x53, 0x49, 0x4f, 0x50, 0x4a, 0x54, 0x45, 0x4e, 0x47, 0x48, 0x47, 0x42,
+ 0x49, 0x44, 0x46, 0x46, 0x55, 0x4c, 0x4f, 0x36, 0x4c, 0x49, 0x3f, 0x4e,
+ 0x45, 0x4b, 0x4b, 0x36, 0x48, 0x4f, 0x4b, 0x50, 0x45, 0x47, 0x49, 0x3f,
+ 0x50, 0x4b, 0x52, 0x48, 0x4c, 0x41, 0x49, 0x43, 0x4e, 0x3c, 0x43, 0x45,
+ 0x3e, 0x45, 0x48, 0x44, 0x4d, 0x48, 0x56, 0x47, 0x4b, 0x54, 0x52, 0x2b,
+ 0x4d, 0x4e, 0x57, 0x4f, 0x57, 0x4f, 0x56, 0x43, 0x48, 0x5f, 0x4c, 0x51,
+ 0x4d, 0x58, 0x4f, 0x4e, 0x50, 0x50, 0x48, 0x4a, 0x4d, 0x3f, 0x47, 0x40,
+ 0x4b, 0x4a, 0x4e, 0x4b, 0x4a, 0x58, 0x42, 0x49, 0x3f, 0x42, 0x3d, 0x4d,
+ 0x46, 0x53, 0x45, 0x3e, 0x4e, 0x49, 0x4f, 0x4a, 0x47, 0x46, 0x40, 0x3e,
+ 0x4c, 0x4d, 0x4d, 0x45, 0x4a, 0x56, 0x40, 0x4a, 0x47, 0x57, 0x4f, 0x48,
+ 0x4f, 0x48, 0x47, 0x49, 0x4e, 0x52, 0x50, 0x48, 0x42, 0x52, 0x43, 0x5a,
+ 0x49, 0x42, 0x4f, 0x4f, 0x51, 0x51, 0x50, 0x5c, 0x4b, 0x43, 0x4b, 0x48,
+ 0x50, 0x51, 0x4b, 0x6d, 0x53, 0x4e, 0x44, 0x4c, 0x4c, 0x51, 0x46, 0x5b,
+ 0x44, 0x48, 0x4d, 0x4c, 0x46, 0x4f, 0x54, 0x54, 0x4e, 0x54, 0x42, 0x4e,
+ 0x4c, 0x49, 0x49, 0x58, 0x49, 0x53, 0x53, 0x4a, 0x4e, 0x4b, 0x47, 0x53,
+ 0x43, 0x55, 0x46, 0x51, 0x3d, 0x3d, 0x4c, 0x47, 0x4e, 0x51, 0x47, 0x48,
+ 0x4b, 0x4c, 0x42, 0x3b, 0x43, 0x4f, 0x44, 0x4d, 0x54, 0x4b, 0x4a, 0x47,
+ 0x4c, 0x42, 0x4b, 0x43, 0x41, 0x4e, 0x4d, 0x50, 0x45, 0x46, 0x41, 0x4a,
+ 0x49, 0x49, 0x54, 0x47, 0x4c, 0x4b, 0x50, 0x4e, 0x3f, 0x43, 0x40, 0x41,
+ 0x44, 0x54, 0x51, 0x47, 0x4c, 0x4b, 0x4f, 0x34, 0x4d, 0x4c, 0x4f, 0x49,
+ 0x56, 0x4e, 0x4b, 0x3e, 0x48, 0x53, 0x4e, 0x56, 0x49, 0x4e, 0x4c, 0x40,
+ 0x55, 0x4a, 0x46, 0x4f, 0x48, 0x4a, 0x55, 0x41, 0x55, 0x3d, 0x47, 0x51,
+ 0x50, 0x51, 0x45, 0x51, 0x4b, 0x4e, 0x4a, 0x4f, 0x4b, 0x45, 0x42, 0x3c,
+ 0x4e, 0x46, 0x47, 0x49, 0x4a, 0x4c, 0x48, 0x41, 0x4f, 0x4a, 0x44, 0x45,
+ 0x4e, 0x4e, 0x43, 0x41, 0x4c, 0x47, 0x48, 0x49, 0x4c, 0x48, 0x4f, 0x4a,
+ 0x4f, 0x4a, 0x4b, 0x45, 0x42, 0x40, 0x52, 0x55, 0x4f, 0x49, 0x44, 0x54,
+ 0x49, 0x48, 0x51, 0x4d, 0x44, 0x4a, 0x4d, 0x49, 0x4e, 0x4e, 0x51, 0x5d,
+ 0x42, 0x4d, 0x49, 0x3f, 0x48, 0x58, 0x40, 0x5e, 0x48, 0x4f, 0x49, 0x53,
+ 0x45, 0x47, 0x4f, 0x53, 0x4d, 0x4f, 0x4d, 0x4d, 0x46, 0x55, 0x43, 0x51,
+ 0x4f, 0x51, 0x4a, 0x4e, 0x49, 0x42, 0x49, 0x50, 0x47, 0x4d, 0x42, 0x47,
+ 0x46, 0x50, 0x55, 0x47, 0x4d, 0x47, 0x3e, 0x51, 0x4d, 0x43, 0x44, 0x39,
+ 0x4e, 0x4b, 0x41, 0x48, 0x52, 0x53, 0x4d, 0x39, 0x4d, 0x51, 0x4c, 0x46,
+ 0x4e, 0x47, 0x49, 0x41, 0x45, 0x4a, 0x4a, 0x45, 0x50, 0x4a, 0x40, 0x48,
+ 0x43, 0x47, 0x44, 0x50, 0x4d, 0x47, 0x4a, 0x47, 0x45, 0x57, 0x41, 0x34,
+ 0x51, 0x40, 0x45, 0x44, 0x3c, 0x47, 0x46, 0x47, 0x44, 0x48, 0x42, 0x40,
+ 0x37, 0x53, 0x4a, 0x43, 0x49, 0x4b, 0x43, 0x44, 0x4f, 0x4f, 0x48, 0x48,
+ 0x53, 0x49, 0x4b, 0x48, 0x4e, 0x4c, 0x42, 0x45, 0x4c, 0x4a, 0x4a, 0x46,
+ 0x47, 0x57, 0x3e, 0x46, 0x46, 0x45, 0x4a, 0x43, 0x46, 0x49, 0x43, 0x52,
+ 0x3e, 0x48, 0x4a, 0x4b, 0x47, 0x47, 0x48, 0x4a, 0x4b, 0x4b, 0x4e, 0x44,
+ 0x42, 0x44, 0x50, 0x41, 0x49, 0x49, 0x4d, 0x4b, 0x44, 0x46, 0x4a, 0x52,
+ 0x4d, 0x47, 0x49, 0x4b, 0x4d, 0x49, 0x41, 0x48, 0x4b, 0x3f, 0x45, 0x4f,
+ 0x51, 0x41, 0x55, 0x42, 0x49, 0x4b, 0x4b, 0x51, 0x4f, 0x4f, 0x42, 0x4e,
+ 0x4e, 0x4a, 0x52, 0x41, 0x4f, 0x42, 0x48, 0x3d, 0x4a, 0x44, 0x50, 0x4b,
+ 0x49, 0x45, 0x51, 0x46, 0x51, 0x44, 0x4d, 0x47, 0x4a, 0x4a, 0x4d, 0x49,
+ 0x4d, 0x48, 0x4d, 0x4f, 0x4d, 0x44, 0x48, 0x4e, 0x4a, 0x4b, 0x40, 0x4f,
+ 0x47, 0x3a, 0x41, 0x47, 0x4a, 0x4a, 0x4a, 0x48, 0x42, 0x41, 0x4d, 0x56,
+ 0x3f, 0x52, 0x4d, 0x4c, 0x44, 0x48, 0x47, 0x4e, 0x51, 0x4c, 0x49, 0x47,
+ 0x44, 0x4c, 0x4b, 0x47, 0x48, 0x46, 0x47, 0x4f, 0x43, 0x41, 0x3e, 0x47,
+ 0x53, 0x4a, 0x46, 0x42, 0x46, 0x61, 0x43, 0x30, 0x4e, 0x52, 0x43, 0x45,
+ 0x32, 0x4a, 0x45, 0x48, 0x51, 0x3e, 0x44, 0x3b, 0x3a, 0x63, 0x4c, 0x46,
+ 0x4c, 0x49, 0x3d, 0x41, 0x52, 0x53, 0x43, 0x43, 0x45, 0x3d, 0x48, 0x40,
+ 0x4b, 0x4a, 0x49, 0x48, 0x4d, 0x49, 0x4b, 0x4c, 0x3f, 0x4e, 0x4b, 0x47,
+ 0x45, 0x4d, 0x3f, 0x4d, 0x43, 0x50, 0x48, 0x4b, 0x54, 0x3e, 0x44, 0x4e,
+ 0x3e, 0x4c, 0x43, 0x4b, 0x4c, 0x4b, 0x3e, 0x49, 0x50, 0x52, 0x4a, 0x4a,
+ 0x50, 0x50, 0x43, 0x4e, 0x49, 0x48, 0x51, 0x50, 0x47, 0x3d, 0x45, 0x4b,
+ 0x47, 0x46, 0x4d, 0x4c, 0x45, 0x4d, 0x4a, 0x4d, 0x42, 0x4d, 0x47, 0x4f,
+ 0x40, 0x43, 0x46, 0x51, 0x47, 0x4b, 0x43, 0x49, 0x49, 0x50, 0x4b, 0x4b,
+ 0x46, 0x4a, 0x4c, 0x48, 0x49, 0x47, 0x4b, 0x56, 0x55, 0x4f, 0x49, 0x4f,
+ 0x4f, 0x4e, 0x4b, 0x49, 0x4a, 0x4a, 0x49, 0x47, 0x44, 0x4b, 0x47, 0x50,
+ 0x46, 0x4c, 0x46, 0x4c, 0x4b, 0x4e, 0x49, 0x57, 0x4d, 0x3e, 0x46, 0x47,
+ 0x50, 0x45, 0x4f, 0x52, 0x3e, 0x4d, 0x49, 0x4a, 0x40, 0x49, 0x4f, 0x5c,
+ 0x3e, 0x4a, 0x47, 0x45, 0x47, 0x41, 0x44, 0x3f, 0x4b, 0x4a, 0x52, 0x43,
+ 0x41, 0x43, 0x43, 0x47, 0x55, 0x49, 0x42, 0x4c, 0x58, 0x4b, 0x42, 0x48,
+ 0x4b, 0x5a, 0x36, 0x33, 0x53, 0x57, 0x4d, 0x4a, 0x37, 0x4c, 0x3e, 0x48,
+ 0x43, 0x46, 0x39, 0x3c, 0x34, 0x65, 0x47, 0x3d, 0x47, 0x42, 0x3c, 0x3e,
+ 0x45, 0x5b, 0x44, 0x3e, 0x45, 0x43, 0x46, 0x43, 0x59, 0x4e, 0x48, 0x46,
+ 0x43, 0x3f, 0x46, 0x47, 0x4e, 0x53, 0x50, 0x4b, 0x4a, 0x3f, 0x4a, 0x54,
+ 0x4c, 0x4a, 0x43, 0x50, 0x4c, 0x42, 0x4d, 0x55, 0x4d, 0x51, 0x51, 0x46,
+ 0x49, 0x41, 0x50, 0x44, 0x4a, 0x4b, 0x4b, 0x43, 0x4b, 0x4e, 0x47, 0x4b,
+ 0x3e, 0x4e, 0x44, 0x4d, 0x49, 0x41, 0x49, 0x44, 0x50, 0x4d, 0x45, 0x4e,
+ 0x4b, 0x50, 0x45, 0x4c, 0x46, 0x4a, 0x46, 0x42, 0x50, 0x45, 0x48, 0x53,
+ 0x4d, 0x44, 0x42, 0x50, 0x4c, 0x49, 0x45, 0x55, 0x4d, 0x42, 0x43, 0x41,
+ 0x4c, 0x41, 0x4e, 0x4d, 0x42, 0x4e, 0x3f, 0x44, 0x4d, 0x4c, 0x4b, 0x4a,
+ 0x47, 0x47, 0x4e, 0x54, 0x43, 0x40, 0x41, 0x55, 0x49, 0x49, 0x4e, 0x49,
+ 0x52, 0x4e, 0x46, 0x58, 0x4b, 0x3d, 0x4a, 0x44, 0x4e, 0x47, 0x53, 0x58,
+ 0x47, 0x42, 0x52, 0x46, 0x49, 0x4b, 0x47, 0x5a, 0x4c, 0x46, 0x46, 0x49,
+ 0x4b, 0x4d, 0x3d, 0x48, 0x40, 0x54, 0x48, 0x4c, 0x4c, 0x44, 0x4c, 0x46,
+ 0x47, 0x4b, 0x4d, 0x44, 0x5a, 0x4a, 0x3e, 0x46, 0x48, 0x53, 0x39, 0x30,
+ 0x51, 0x60, 0x4d, 0x47, 0x35, 0x4f, 0x45, 0x45, 0x4a, 0x4b, 0x42, 0x3f,
+ 0x38, 0x6c, 0x3d, 0x40, 0x44, 0x48, 0x3a, 0x3b, 0x46, 0x5e, 0x45, 0x3b,
+ 0x47, 0x47, 0x45, 0x42, 0x53, 0x55, 0x44, 0x45, 0x46, 0x43, 0x48, 0x48,
+ 0x52, 0x5d, 0x3e, 0x41, 0x53, 0x42, 0x48, 0x55, 0x49, 0x4d, 0x4a, 0x46,
+ 0x52, 0x46, 0x51, 0x48, 0x44, 0x46, 0x48, 0x41, 0x49, 0x49, 0x49, 0x49,
+ 0x41, 0x4d, 0x40, 0x4f, 0x45, 0x46, 0x45, 0x3f, 0x53, 0x40, 0x46, 0x43,
+ 0x47, 0x4d, 0x50, 0x4c, 0x55, 0x48, 0x45, 0x47, 0x4f, 0x46, 0x42, 0x4d,
+ 0x41, 0x48, 0x46, 0x4e, 0x42, 0x48, 0x48, 0x45, 0x41, 0x45, 0x48, 0x4a,
+ 0x40, 0x49, 0x43, 0x4b, 0x48, 0x4a, 0x4c, 0x45, 0x4b, 0x48, 0x48, 0x4f,
+ 0x40, 0x4b, 0x4a, 0x44, 0x50, 0x4a, 0x43, 0x50, 0x4c, 0x44, 0x46, 0x4c,
+ 0x42, 0x44, 0x4e, 0x55, 0x47, 0x49, 0x48, 0x47, 0x52, 0x4e, 0x44, 0x59,
+ 0x4e, 0x44, 0x4a, 0x48, 0x49, 0x4a, 0x42, 0x4e, 0x3e, 0x39, 0x51, 0x45,
+ 0x4d, 0x49, 0x4f, 0x54, 0x51, 0x4b, 0x50, 0x44, 0x53, 0x4f, 0x4d, 0x48,
+ 0x42, 0x45, 0x4e, 0x40, 0x4a, 0x48, 0x43, 0x48, 0x52, 0x54, 0x4d, 0x49,
+ 0x5f, 0x53, 0x46, 0x4e, 0x3f, 0x5a, 0x36, 0x31, 0x52, 0x60, 0x4b, 0x4a,
+ 0x32, 0x51, 0x40, 0x44, 0x46, 0x52, 0x44, 0x41, 0x3a, 0x6e, 0x41, 0x3e,
+ 0x47, 0x3e, 0x3a, 0x2a, 0x44, 0x5a, 0x40, 0x3c, 0x4d, 0x48, 0x46, 0x3b,
+ 0x5e, 0x58, 0x4d, 0x47, 0x51, 0x3a, 0x4b, 0x48, 0x5b, 0x5a, 0x54, 0x43,
+ 0x50, 0x4c, 0x54, 0x54, 0x49, 0x47, 0x4f, 0x48, 0x50, 0x40, 0x4f, 0x4a,
+ 0x42, 0x42, 0x3c, 0x41, 0x43, 0x4e, 0x53, 0x49, 0x4b, 0x4d, 0x49, 0x41,
+ 0x4c, 0x3e, 0x40, 0x49, 0x40, 0x44, 0x49, 0x4f, 0x50, 0x4a, 0x42, 0x3a,
+ 0x49, 0x4b, 0x47, 0x50, 0x49, 0x41, 0x52, 0x46, 0x3d, 0x44, 0x46, 0x43,
+ 0x4b, 0x4b, 0x4d, 0x4b, 0x4e, 0x40, 0x45, 0x43, 0x48, 0x44, 0x55, 0x51,
+ 0x4a, 0x46, 0x4e, 0x40, 0x53, 0x4a, 0x45, 0x41, 0x48, 0x48, 0x45, 0x4e,
+ 0x4a, 0x48, 0x40, 0x4c, 0x54, 0x44, 0x42, 0x4d, 0x49, 0x43, 0x45, 0x4c,
+ 0x43, 0x4f, 0x46, 0x3f, 0x46, 0x4f, 0x4b, 0x59, 0x46, 0x49, 0x54, 0x47,
+ 0x49, 0x46, 0x45, 0x53, 0x4a, 0x49, 0x54, 0x45, 0x41, 0x45, 0x4c, 0x5e,
+ 0x50, 0x3d, 0x4d, 0x49, 0x55, 0x4b, 0x49, 0x47, 0x4c, 0x4f, 0x43, 0x3d,
+ 0x41, 0x4b, 0x43, 0x46, 0x4f, 0x4a, 0x4c, 0x54, 0x5e, 0x4e, 0x40, 0x4d,
+ 0x3d, 0x59, 0x40, 0x28, 0x54, 0x5f, 0x4d, 0x4b, 0x36, 0x51, 0x3a, 0x47,
+ 0x4a, 0x55, 0x42, 0x43, 0x3b, 0x72, 0x3b, 0x3d, 0x51, 0x42, 0x3f, 0x2d,
+ 0x4b, 0x5a, 0x48, 0x44, 0x49, 0x49, 0x3d, 0x39, 0x56, 0x55, 0x46, 0x46,
+ 0x4b, 0x43, 0x40, 0x4a, 0x52, 0x56, 0x4d, 0x45, 0x4b, 0x48, 0x40, 0x5a,
+ 0x4e, 0x3a, 0x53, 0x48, 0x4c, 0x44, 0x49, 0x4e, 0x42, 0x47, 0x46, 0x40,
+ 0x51, 0x42, 0x50, 0x4b, 0x43, 0x53, 0x44, 0x44, 0x46, 0x4c, 0x4c, 0x3c,
+ 0x42, 0x45, 0x42, 0x45, 0x44, 0x4b, 0x52, 0x3d, 0x47, 0x4b, 0x4c, 0x4e,
+ 0x52, 0x4a, 0x4e, 0x41, 0x3f, 0x46, 0x43, 0x54, 0x44, 0x53, 0x4e, 0x48,
+ 0x40, 0x41, 0x4f, 0x45, 0x43, 0x3c, 0x52, 0x49, 0x40, 0x44, 0x4a, 0x3f,
+ 0x4d, 0x4c, 0x4f, 0x47, 0x44, 0x47, 0x55, 0x47, 0x50, 0x4d, 0x4a, 0x4c,
+ 0x50, 0x48, 0x47, 0x55, 0x4b, 0x4a, 0x52, 0x49, 0x3d, 0x3f, 0x4f, 0x51,
+ 0x48, 0x4e, 0x42, 0x4e, 0x42, 0x48, 0x4e, 0x49, 0x4a, 0x50, 0x45, 0x54,
+ 0x41, 0x43, 0x45, 0x4d, 0x48, 0x48, 0x48, 0x51, 0x53, 0x3e, 0x55, 0x44,
+ 0x52, 0x56, 0x44, 0x4d, 0x4e, 0x48, 0x4b, 0x43, 0x48, 0x53, 0x48, 0x44,
+ 0x49, 0x45, 0x4e, 0x50, 0x5d, 0x4a, 0x45, 0x4c, 0x45, 0x55, 0x43, 0x2e,
+ 0x59, 0x60, 0x4e, 0x4d, 0x32, 0x53, 0x3e, 0x3f, 0x40, 0x63, 0x41, 0x48,
+ 0x38, 0x73, 0x38, 0x46, 0x50, 0x3e, 0x3c, 0x23, 0x48, 0x61, 0x45, 0x3c,
+ 0x41, 0x41, 0x36, 0x3b, 0x58, 0x56, 0x4a, 0x40, 0x4f, 0x44, 0x45, 0x4c,
+ 0x5a, 0x56, 0x47, 0x3f, 0x4d, 0x4b, 0x46, 0x5d, 0x52, 0x47, 0x45, 0x4c,
+ 0x4a, 0x52, 0x4f, 0x4f, 0x4f, 0x43, 0x4f, 0x47, 0x43, 0x46, 0x3c, 0x4c,
+ 0x46, 0x55, 0x40, 0x53, 0x43, 0x3e, 0x42, 0x35, 0x51, 0x41, 0x42, 0x3f,
+ 0x45, 0x3d, 0x41, 0x31, 0x4e, 0x47, 0x48, 0x42, 0x41, 0x45, 0x43, 0x38,
+ 0x42, 0x40, 0x4a, 0x47, 0x4e, 0x43, 0x40, 0x43, 0x48, 0x49, 0x45, 0x4f,
+ 0x44, 0x42, 0x4d, 0x42, 0x42, 0x3f, 0x46, 0x52, 0x3c, 0x3c, 0x47, 0x43,
+ 0x46, 0x47, 0x45, 0x40, 0x4c, 0x44, 0x43, 0x4a, 0x4b, 0x4d, 0x4e, 0x46,
+ 0x51, 0x45, 0x47, 0x4b, 0x45, 0x50, 0x40, 0x42, 0x4c, 0x4c, 0x4c, 0x4f,
+ 0x44, 0x3c, 0x49, 0x3c, 0x3f, 0x45, 0x3f, 0x5c, 0x42, 0x3e, 0x4b, 0x4e,
+ 0x50, 0x45, 0x42, 0x5c, 0x4c, 0x48, 0x50, 0x52, 0x50, 0x47, 0x4b, 0x44,
+ 0x3d, 0x50, 0x55, 0x4c, 0x48, 0x3f, 0x4b, 0x44, 0x4a, 0x51, 0x42, 0x4c,
+ 0x60, 0x51, 0x41, 0x4b, 0x46, 0x5c, 0x42, 0x2c, 0x55, 0x61, 0x50, 0x52,
+ 0x37, 0x5a, 0x3f, 0x43, 0x43, 0x58, 0x3a, 0x4d, 0x3e, 0x72, 0x35, 0x3f,
+ 0x58, 0x41, 0x40, 0x1f, 0x55, 0x63, 0x3f, 0x49, 0x41, 0x3e, 0x35, 0x41,
+ 0x65, 0x54, 0x42, 0x45, 0x45, 0x3c, 0x44, 0x45, 0x59, 0x5a, 0x4d, 0x41,
+ 0x51, 0x46, 0x49, 0x59, 0x4c, 0x41, 0x42, 0x44, 0x4a, 0x45, 0x3f, 0x4a,
+ 0x4a, 0x44, 0x48, 0x48, 0x52, 0x40, 0x4a, 0x4a, 0x4d, 0x54, 0x44, 0x48,
+ 0x54, 0x46, 0x49, 0x3b, 0x42, 0x4a, 0x4e, 0x46, 0x4a, 0x45, 0x4f, 0x30,
+ 0x46, 0x41, 0x47, 0x46, 0x4b, 0x47, 0x46, 0x38, 0x4c, 0x3a, 0x4b, 0x46,
+ 0x52, 0x48, 0x4f, 0x3e, 0x48, 0x4a, 0x48, 0x4b, 0x44, 0x45, 0x4a, 0x46,
+ 0x3f, 0x4f, 0x40, 0x44, 0x43, 0x43, 0x4b, 0x39, 0x46, 0x43, 0x49, 0x49,
+ 0x49, 0x4a, 0x44, 0x48, 0x4c, 0x41, 0x4d, 0x52, 0x4c, 0x4a, 0x46, 0x3d,
+ 0x41, 0x4b, 0x41, 0x48, 0x45, 0x3b, 0x51, 0x54, 0x4a, 0x39, 0x4d, 0x41,
+ 0x54, 0x46, 0x4c, 0x53, 0x48, 0x3e, 0x4a, 0x3d, 0x41, 0x52, 0x54, 0x63,
+ 0x44, 0x4d, 0x4a, 0x43, 0x52, 0x4b, 0x52, 0x52, 0x4e, 0x41, 0x48, 0x42,
+ 0x48, 0x4d, 0x49, 0x45, 0x51, 0x48, 0x3e, 0x47, 0x5a, 0x52, 0x4a, 0x4e,
+ 0x3e, 0x59, 0x3c, 0x2e, 0x5c, 0x5b, 0x4c, 0x56, 0x30, 0x59, 0x3a, 0x48,
+ 0x3d, 0x5c, 0x44, 0x49, 0x40, 0x7c, 0x3a, 0x48, 0x54, 0x40, 0x41, 0x28,
+ 0x4d, 0x64, 0x46, 0x47, 0x49, 0x40, 0x30, 0x3a, 0x5f, 0x5b, 0x42, 0x37,
+ 0x49, 0x45, 0x40, 0x43, 0x5b, 0x54, 0x48, 0x4d, 0x4a, 0x47, 0x51, 0x58,
+ 0x4b, 0x3c, 0x4d, 0x46, 0x4b, 0x52, 0x4c, 0x58, 0x53, 0x46, 0x42, 0x45,
+ 0x4c, 0x4a, 0x4d, 0x4e, 0x52, 0x4d, 0x46, 0x44, 0x46, 0x3f, 0x46, 0x34,
+ 0x4f, 0x42, 0x44, 0x46, 0x44, 0x50, 0x47, 0x30, 0x44, 0x3c, 0x42, 0x46,
+ 0x4f, 0x4a, 0x52, 0x30, 0x55, 0x4f, 0x45, 0x4a, 0x48, 0x4c, 0x4e, 0x35,
+ 0x4e, 0x3c, 0x45, 0x4a, 0x45, 0x4a, 0x44, 0x3c, 0x4e, 0x4a, 0x51, 0x44,
+ 0x49, 0x40, 0x4a, 0x40, 0x41, 0x44, 0x4f, 0x4c, 0x43, 0x45, 0x4b, 0x43,
+ 0x3e, 0x3e, 0x4c, 0x44, 0x48, 0x48, 0x42, 0x42, 0x4d, 0x43, 0x50, 0x4d,
+ 0x49, 0x3c, 0x45, 0x4f, 0x4c, 0x46, 0x4b, 0x48, 0x4d, 0x4d, 0x49, 0x55,
+ 0x49, 0x3b, 0x40, 0x44, 0x4a, 0x4b, 0x4e, 0x5e, 0x43, 0x47, 0x45, 0x43,
+ 0x4d, 0x4d, 0x49, 0x46, 0x4a, 0x44, 0x4e, 0x3e, 0x52, 0x41, 0x47, 0x47,
+ 0x4a, 0x50, 0x48, 0x43, 0x5d, 0x4f, 0x49, 0x48, 0x43, 0x4f, 0x45, 0x3e,
+ 0x5a, 0x69, 0x4d, 0x5a, 0x3a, 0x5d, 0x3a, 0x48, 0x42, 0x55, 0x3e, 0x48,
+ 0x48, 0x7b, 0x37, 0x40, 0x57, 0x45, 0x48, 0x24, 0x50, 0x61, 0x4c, 0x4a,
+ 0x44, 0x41, 0x34, 0x38, 0x65, 0x5b, 0x4f, 0x3c, 0x4d, 0x3a, 0x4a, 0x4c,
+ 0x66, 0x55, 0x50, 0x47, 0x4d, 0x46, 0x47, 0x58, 0x4c, 0x48, 0x48, 0x48,
+ 0x4e, 0x59, 0x4f, 0x4b, 0x45, 0x45, 0x4b, 0x54, 0x46, 0x51, 0x4f, 0x44,
+ 0x42, 0x55, 0x48, 0x44, 0x48, 0x41, 0x53, 0x2e, 0x4d, 0x45, 0x44, 0x54,
+ 0x4a, 0x44, 0x53, 0x34, 0x4c, 0x46, 0x47, 0x3f, 0x4c, 0x4b, 0x47, 0x36,
+ 0x47, 0x41, 0x43, 0x40, 0x51, 0x46, 0x45, 0x33, 0x46, 0x3e, 0x47, 0x50,
+ 0x3f, 0x48, 0x48, 0x37, 0x41, 0x41, 0x42, 0x3e, 0x45, 0x3d, 0x49, 0x3e,
+ 0x4f, 0x42, 0x49, 0x4a, 0x46, 0x46, 0x48, 0x44, 0x49, 0x45, 0x46, 0x4a,
+ 0x4a, 0x47, 0x48, 0x43, 0x44, 0x45, 0x3f, 0x4c, 0x4c, 0x49, 0x4d, 0x51,
+ 0x4a, 0x4a, 0x49, 0x4c, 0x42, 0x4d, 0x4b, 0x4b, 0x4a, 0x42, 0x47, 0x4d,
+ 0x3e, 0x4b, 0x47, 0x5c, 0x49, 0x3d, 0x4e, 0x41, 0x44, 0x49, 0x3e, 0x3e,
+ 0x4b, 0x47, 0x4e, 0x45, 0x44, 0x4a, 0x4d, 0x4a, 0x4f, 0x46, 0x45, 0x52,
+ 0x60, 0x53, 0x49, 0x50, 0x3d, 0x4f, 0x43, 0x3d, 0x52, 0x64, 0x52, 0x58,
+ 0x39, 0x5f, 0x36, 0x4c, 0x45, 0x57, 0x42, 0x4b, 0x3f, 0x80, 0x34, 0x47,
+ 0x58, 0x41, 0x45, 0x1b, 0x4b, 0x5e, 0x4c, 0x40, 0x44, 0x42, 0x39, 0x3a,
+ 0x5e, 0x5b, 0x4b, 0x3a, 0x4b, 0x3f, 0x45, 0x3e, 0x69, 0x57, 0x4b, 0x45,
+ 0x4b, 0x3f, 0x45, 0x55, 0x49, 0x49, 0x48, 0x47, 0x41, 0x4f, 0x42, 0x53,
+ 0x49, 0x40, 0x42, 0x3e, 0x49, 0x47, 0x53, 0x47, 0x45, 0x51, 0x4a, 0x44,
+ 0x44, 0x45, 0x4e, 0x2a, 0x45, 0x42, 0x4a, 0x4b, 0x46, 0x4d, 0x41, 0x30,
+ 0x3d, 0x43, 0x3f, 0x48, 0x49, 0x44, 0x4d, 0x2e, 0x48, 0x4a, 0x4c, 0x51,
+ 0x50, 0x46, 0x3e, 0x2c, 0x4d, 0x3f, 0x47, 0x46, 0x3c, 0x40, 0x4c, 0x38,
+ 0x4f, 0x46, 0x47, 0x53, 0x3b, 0x3c, 0x4e, 0x3e, 0x49, 0x40, 0x43, 0x4c,
+ 0x4d, 0x48, 0x45, 0x3c, 0x4d, 0x4c, 0x4d, 0x45, 0x3f, 0x49, 0x4a, 0x43,
+ 0x4d, 0x41, 0x4b, 0x50, 0x4e, 0x46, 0x50, 0x44, 0x49, 0x44, 0x4e, 0x42,
+ 0x4a, 0x43, 0x4c, 0x4c, 0x49, 0x49, 0x44, 0x4e, 0x4b, 0x3f, 0x4b, 0x5d,
+ 0x41, 0x49, 0x4b, 0x46, 0x4e, 0x48, 0x45, 0x51, 0x4d, 0x45, 0x46, 0x45,
+ 0x4b, 0x4e, 0x3c, 0x4d, 0x3d, 0x41, 0x47, 0x47, 0x64, 0x54, 0x41, 0x55,
+ 0x47, 0x56, 0x44, 0x3b, 0x53, 0x66, 0x4f, 0x5e, 0x40, 0x5d, 0x38, 0x4a,
+ 0x41, 0x59, 0x42, 0x48, 0x47, 0xff, 0x36, 0x49, 0x59, 0x41, 0x43, 0x1d,
+ 0x4d, 0x5e, 0x44, 0x44, 0x50, 0x3f, 0x39, 0x40, 0x68, 0x5e, 0x4a, 0x41,
+ 0x52, 0x41, 0x43, 0x41, 0x68, 0x51, 0x45, 0x48, 0x4c, 0x46, 0x4a, 0x5e,
+ 0x4e, 0x40, 0x4d, 0x41, 0x41, 0x5c, 0x3f, 0x4e, 0x4c, 0x37, 0x48, 0x40,
+ 0x46, 0x47, 0x4f, 0x43, 0x53, 0x52, 0x3d, 0x44, 0x47, 0x44, 0x3d, 0x34,
+ 0x44, 0x42, 0x4a, 0x43, 0x4d, 0x3f, 0x53, 0x2e, 0x42, 0x47, 0x43, 0x4d,
+ 0x45, 0x45, 0x47, 0x31, 0x4d, 0x39, 0x41, 0x4a, 0x4a, 0x4d, 0x4b, 0x35,
+ 0x47, 0x4e, 0x4c, 0x40, 0x4a, 0x44, 0x44, 0x36, 0x3e, 0x49, 0x3f, 0x45,
+ 0x46, 0x43, 0x4e, 0x3c, 0x4d, 0x47, 0x4c, 0x48, 0x4a, 0x4b, 0x48, 0x39,
+ 0x46, 0x50, 0x4a, 0x4f, 0x46, 0x41, 0x44, 0x4a, 0x41, 0x4f, 0x4c, 0x4e,
+ 0x55, 0x46, 0x43, 0x46, 0x4a, 0x48, 0x4e, 0x46, 0x42, 0x40, 0x4f, 0x56,
+ 0x4c, 0x45, 0x4b, 0x46, 0x4a, 0x47, 0x42, 0x5e, 0x49, 0x4e, 0x46, 0x43,
+ 0x4e, 0x42, 0x45, 0x48, 0x47, 0x48, 0x4f, 0x45, 0x47, 0x51, 0x4b, 0x4c,
+ 0x51, 0x39, 0x4d, 0x48, 0x60, 0x57, 0x49, 0x52, 0x3d, 0x57, 0x46, 0x3d,
+ 0x53, 0x68, 0x4b, 0x60, 0x40, 0x5a, 0x41, 0x4b, 0x46, 0x56, 0x46, 0x4c,
+ 0x49, 0x7e, 0x2f, 0x48, 0x51, 0x42, 0x40, 0x20, 0x4b, 0x62, 0x4d, 0x41,
+ 0x4f, 0x43, 0x3d, 0x35, 0x63, 0x63, 0x46, 0x3e, 0x4e, 0x47, 0x40, 0x40,
+ 0x60, 0x52, 0x4c, 0x46, 0x49, 0x48, 0x4f, 0x56, 0x51, 0x47, 0x52, 0x4e,
+ 0x4b, 0x59, 0x55, 0x4f, 0x48, 0x3d, 0x48, 0x4a, 0x4d, 0x50, 0x47, 0x47,
+ 0x51, 0x52, 0x4d, 0x51, 0x45, 0x45, 0x47, 0x2d, 0x4d, 0x41, 0x43, 0x49,
+ 0x4d, 0x40, 0x4a, 0x2f, 0x4f, 0x43, 0x46, 0x4a, 0x3e, 0x4a, 0x4a, 0x2b,
+ 0x49, 0x4c, 0x4c, 0x3e, 0x41, 0x4c, 0x4a, 0x2b, 0x40, 0x44, 0x46, 0x4a,
+ 0x40, 0x44, 0x42, 0x38, 0x52, 0x42, 0x46, 0x51, 0x53, 0x4e, 0x45, 0x31,
+ 0x45, 0x47, 0x4f, 0x46, 0x49, 0x43, 0x45, 0x3b, 0x4b, 0x4b, 0x4b, 0x4c,
+ 0x43, 0x4a, 0x4c, 0x43, 0x4e, 0x40, 0x52, 0x44, 0x48, 0x49, 0x47, 0x4b,
+ 0x4e, 0x3d, 0x4e, 0x44, 0x48, 0x4d, 0x4f, 0x4f, 0x50, 0x36, 0x47, 0x41,
+ 0x4a, 0x44, 0x45, 0x56, 0x4f, 0x4c, 0x50, 0x4b, 0x45, 0x3e, 0x45, 0x4e,
+ 0x45, 0x45, 0x43, 0x40, 0x47, 0x4e, 0x45, 0x3e, 0x4a, 0x3f, 0x49, 0x50,
+ 0x62, 0x55, 0x48, 0x56, 0x3e, 0x57, 0x4f, 0x3b, 0x55, 0x6c, 0x50, 0x5c,
+ 0x3d, 0x54, 0x3d, 0x46, 0x43, 0x59, 0x3e, 0x51, 0x4d, 0x7b, 0x33, 0x47,
+ 0x52, 0x43, 0x3f, 0x25, 0x4a, 0x6f, 0x49, 0x3e, 0x50, 0x40, 0x41, 0x30,
+ 0x5e, 0x5c, 0x4a, 0x43, 0x4d, 0x42, 0x46, 0x3b, 0x63, 0x53, 0x4f, 0x43,
+ 0x58, 0x48, 0x4b, 0x59, 0x50, 0x4e, 0x4b, 0x51, 0x4a, 0x55, 0x44, 0x46,
+ 0x4c, 0x3d, 0x4c, 0x52, 0x44, 0x52, 0x4c, 0x41, 0x4f, 0x44, 0x4a, 0x47,
+ 0x4e, 0x48, 0x49, 0x2e, 0x3e, 0x45, 0x4c, 0x48, 0x41, 0x47, 0x4d, 0x2e,
+ 0x40, 0x4b, 0x4c, 0x42, 0x4d, 0x40, 0x4e, 0x2e, 0x43, 0x45, 0x4b, 0x43,
+ 0x3e, 0x49, 0x55, 0x35, 0x43, 0x42, 0x42, 0x40, 0x4e, 0x46, 0x44, 0x37,
+ 0x49, 0x41, 0x3f, 0x52, 0x47, 0x4b, 0x43, 0x33, 0x4b, 0x47, 0x4b, 0x4c,
+ 0x4d, 0x4b, 0x3f, 0x42, 0x44, 0x40, 0x49, 0x41, 0x42, 0x49, 0x4b, 0x46,
+ 0x4e, 0x4e, 0x47, 0x4e, 0x48, 0x48, 0x4b, 0x46, 0x51, 0x4b, 0x46, 0x4d,
+ 0x47, 0x4f, 0x3e, 0x51, 0x46, 0x4e, 0x46, 0x4b, 0x47, 0x48, 0x4e, 0x55,
+ 0x4c, 0x3d, 0x47, 0x51, 0x42, 0x45, 0x4f, 0x42, 0x52, 0x50, 0x44, 0x4c,
+ 0x44, 0x44, 0x43, 0x4d, 0x40, 0x42, 0x4d, 0x4b, 0x5d, 0x4e, 0x47, 0x54,
+ 0x47, 0x51, 0x43, 0x39, 0x58, 0x66, 0x4e, 0x5a, 0x41, 0x52, 0x36, 0x47,
+ 0x45, 0x5f, 0x34, 0x50, 0x46, 0x79, 0x30, 0x48, 0x50, 0x45, 0x32, 0x22,
+ 0x54, 0x64, 0x49, 0x46, 0x45, 0x3c, 0x42, 0x36, 0x65, 0x5c, 0x48, 0x3a,
+ 0x4d, 0x4b, 0x47, 0x3e, 0x63, 0x56, 0x4a, 0x48, 0x51, 0x42, 0x4f, 0x5e,
+ 0x4c, 0x44, 0x4b, 0x4c, 0x3d, 0x5a, 0x43, 0x4d, 0x42, 0x40, 0x4f, 0x4d,
+ 0x3f, 0x3e, 0x46, 0x40, 0x49, 0x42, 0x49, 0x40, 0x49, 0x4c, 0x4a, 0x2e,
+ 0x4b, 0x3f, 0x53, 0x4b, 0x48, 0x49, 0x3e, 0x34, 0x47, 0x4a, 0x4b, 0x46,
+ 0x3b, 0x49, 0x46, 0x34, 0x4b, 0x48, 0x4c, 0x49, 0x49, 0x43, 0x4f, 0x2e,
+ 0x44, 0x46, 0x48, 0x50, 0x46, 0x4e, 0x4a, 0x37, 0x4b, 0x4c, 0x4a, 0x50,
+ 0x45, 0x4a, 0x48, 0x3b, 0x48, 0x44, 0x48, 0x4a, 0x41, 0x44, 0x52, 0x3f,
+ 0x4c, 0x46, 0x4a, 0x45, 0x46, 0x49, 0x49, 0x36, 0x53, 0x3e, 0x48, 0x47,
+ 0x3f, 0x42, 0x41, 0x4c, 0x42, 0x4a, 0x52, 0x46, 0x49, 0x3f, 0x48, 0x5a,
+ 0x43, 0x42, 0x3d, 0x43, 0x4f, 0x44, 0x43, 0x65, 0x41, 0x41, 0x44, 0x4b,
+ 0x50, 0x44, 0x53, 0x49, 0x41, 0x45, 0x4a, 0x4d, 0x40, 0x45, 0x4a, 0x4e,
+ 0x50, 0x40, 0x51, 0x40, 0x5e, 0x50, 0x43, 0x5c, 0x47, 0x5a, 0x44, 0x4c,
+ 0x54, 0x64, 0x4f, 0x63, 0x39, 0x58, 0x3c, 0x4a, 0x42, 0x5e, 0x3c, 0x4a,
+ 0x48, 0x7b, 0x34, 0x4c, 0x4f, 0x44, 0x30, 0x24, 0x50, 0x65, 0x47, 0x39,
+ 0x46, 0x3e, 0x3f, 0x33, 0x65, 0x5a, 0x44, 0x38, 0x50, 0x47, 0x4b, 0x3e,
+ 0x5b, 0x53, 0x4a, 0x4d, 0x51, 0x40, 0x47, 0x59, 0x51, 0x42, 0x4f, 0x50,
+ 0x45, 0x57, 0x46, 0x50, 0x3f, 0x3c, 0x4c, 0x4f, 0x46, 0x41, 0x4a, 0x3e,
+ 0x4d, 0x45, 0x51, 0x48, 0x4e, 0x44, 0x4e, 0x35, 0x44, 0x3f, 0x44, 0x48,
+ 0x3c, 0x4c, 0x49, 0x2c, 0x4a, 0x46, 0x48, 0x44, 0x4b, 0x42, 0x4b, 0x2f,
+ 0x4e, 0x50, 0x4c, 0x4d, 0x44, 0x46, 0x3f, 0x39, 0x4d, 0x47, 0x45, 0x41,
+ 0x42, 0x47, 0x4a, 0x3a, 0x40, 0x3e, 0x4a, 0x51, 0x3f, 0x47, 0x44, 0x37,
+ 0x47, 0x4e, 0x47, 0x52, 0x45, 0x42, 0x4a, 0x3d, 0x43, 0x4d, 0x4d, 0x47,
+ 0x48, 0x43, 0x44, 0x44, 0x47, 0x4e, 0x52, 0x4b, 0x4e, 0x50, 0x42, 0x47,
+ 0x4b, 0x4b, 0x4e, 0x4c, 0x4e, 0x47, 0x50, 0x56, 0x46, 0x47, 0x4d, 0x49,
+ 0x4d, 0x46, 0x49, 0x5f, 0x49, 0x42, 0x4d, 0x44, 0x40, 0x4b, 0x52, 0x45,
+ 0x46, 0x4a, 0x4b, 0x49, 0x47, 0x4b, 0x42, 0x45, 0x42, 0x44, 0x46, 0x4c,
+ 0x62, 0x4a, 0x44, 0x53, 0x43, 0x5a, 0x48, 0x49, 0x59, 0x68, 0x46, 0x61,
+ 0x40, 0x5a, 0x3a, 0x4d, 0x45, 0x5e, 0x33, 0x4f, 0x4e, 0x74, 0x3e, 0x3e,
+ 0x5a, 0x4b, 0x34, 0x31, 0x52, 0x6c, 0x44, 0x39, 0x4c, 0x3b, 0x39, 0x3a,
+ 0x63, 0x65, 0x4b, 0x40, 0x50, 0x4d, 0x53, 0x4a, 0x69, 0x56, 0x54, 0x45,
+ 0x4c, 0x4c, 0x50, 0x5b, 0x4d, 0x4f, 0x3d, 0x4b, 0x44, 0x47, 0x43, 0x47,
+ 0x49, 0x3c, 0x49, 0x41, 0x41, 0x3f, 0x47, 0x43, 0x48, 0x47, 0x4c, 0x43,
+ 0x4a, 0x40, 0x4d, 0x32, 0x4b, 0x4d, 0x44, 0x48, 0x46, 0x44, 0x50, 0x2f,
+ 0x4e, 0x49, 0x53, 0x4b, 0x52, 0x47, 0x4b, 0x2b, 0x48, 0x4b, 0x4a, 0x4c,
+ 0x4d, 0x4c, 0x43, 0x37, 0x48, 0x3c, 0x4b, 0x42, 0x51, 0x3f, 0x45, 0x3c,
+ 0x49, 0x40, 0x42, 0x43, 0x4d, 0x4c, 0x3f, 0x3f, 0x4d, 0x43, 0x45, 0x42,
+ 0x48, 0x42, 0x48, 0x39, 0x51, 0x4e, 0x46, 0x4f, 0x3e, 0x4c, 0x45, 0x3e,
+ 0x3f, 0x3f, 0x43, 0x41, 0x4b, 0x4b, 0x43, 0x4d, 0x44, 0x3b, 0x48, 0x45,
+ 0x3c, 0x4a, 0x48, 0x5b, 0x3c, 0x4b, 0x4c, 0x44, 0x46, 0x3e, 0x45, 0x57,
+ 0x43, 0x42, 0x51, 0x4a, 0x46, 0x47, 0x43, 0x49, 0x42, 0x43, 0x50, 0x4e,
+ 0x4e, 0x44, 0x41, 0x4e, 0x4e, 0x41, 0x48, 0x47, 0x5c, 0x53, 0x44, 0x54,
+ 0x44, 0x5b, 0x45, 0x46, 0x55, 0x67, 0x4d, 0x5d, 0x40, 0x5a, 0x43, 0x4b,
+ 0x43, 0x60, 0x3c, 0x4b, 0x41, 0x79, 0x41, 0x41, 0x58, 0x48, 0x40, 0x3b,
+ 0x4f, 0x6c, 0x46, 0x3f, 0x53, 0x3a, 0x3d, 0x36, 0x5a, 0x57, 0x44, 0x41,
+ 0x4c, 0x47, 0x4e, 0x48, 0x62, 0x60, 0x4a, 0x46, 0x51, 0x3e, 0x52, 0x5f,
+ 0x4b, 0x46, 0x48, 0x4c, 0x4c, 0x55, 0x43, 0x46, 0x49, 0x3e, 0x41, 0x40,
+ 0x4d, 0x47, 0x46, 0x3b, 0x51, 0x3a, 0x4a, 0x45, 0x50, 0x47, 0x51, 0x38,
+ 0x44, 0x41, 0x40, 0x4b, 0x4d, 0x44, 0x4d, 0x28, 0x47, 0x3e, 0x44, 0x40,
+ 0x49, 0x49, 0x40, 0x3c, 0x44, 0x4c, 0x48, 0x51, 0x46, 0x3e, 0x47, 0x2a,
+ 0x41, 0x44, 0x49, 0x4c, 0x4e, 0x4e, 0x42, 0x3c, 0x49, 0x42, 0x43, 0x45,
+ 0x4e, 0x4d, 0x50, 0x39, 0x42, 0x43, 0x48, 0x41, 0x3f, 0x40, 0x4e, 0x3a,
+ 0x44, 0x3d, 0x49, 0x4d, 0x47, 0x45, 0x4b, 0x42, 0x4c, 0x4d, 0x3f, 0x3f,
+ 0x4e, 0x4d, 0x4d, 0x4d, 0x4d, 0x45, 0x47, 0x43, 0x4c, 0x46, 0x47, 0x57,
+ 0x4b, 0x42, 0x4d, 0x46, 0x4b, 0x4b, 0x43, 0x58, 0x48, 0x49, 0x4d, 0x47,
+ 0x43, 0x49, 0x4b, 0x48, 0x46, 0x4f, 0x4f, 0x42, 0x4a, 0x43, 0x49, 0x4e,
+ 0x4a, 0x47, 0x4c, 0x48, 0x5a, 0x57, 0x4a, 0x58, 0x49, 0x4f, 0x45, 0x47,
+ 0x63, 0x66, 0x4d, 0x5e, 0x4b, 0x51, 0x45, 0x4a, 0x43, 0x5d, 0x33, 0x4b,
+ 0x4e, 0x70, 0x42, 0x39, 0x57, 0x4a, 0x40, 0x3a, 0x51, 0x68, 0x45, 0x45,
+ 0x4c, 0x44, 0x3a, 0x3a, 0x4f, 0x62, 0x49, 0x45, 0x53, 0x4c, 0x4e, 0x41,
+ 0x63, 0x5e, 0x44, 0x44, 0x47, 0x43, 0x47, 0x59, 0x4c, 0x4b, 0x4c, 0x49,
+ 0x3e, 0x43, 0x4c, 0x46, 0x4c, 0x38, 0x47, 0x46, 0x46, 0x47, 0x40, 0x44,
+ 0x51, 0x3e, 0x40, 0x47, 0x3f, 0x45, 0x48, 0x2a, 0x42, 0x3e, 0x43, 0x46,
+ 0x50, 0x4c, 0x4a, 0x2c, 0x49, 0x4b, 0x48, 0x48, 0x40, 0x4a, 0x4a, 0x37,
+ 0x4e, 0x42, 0x4f, 0x4c, 0x41, 0x43, 0x45, 0x38, 0x4e, 0x3d, 0x41, 0x47,
+ 0x42, 0x42, 0x43, 0x3b, 0x4a, 0x40, 0x48, 0x4a, 0x53, 0x44, 0x4d, 0x35,
+ 0x51, 0x3c, 0x4e, 0x4e, 0x3e, 0x3f, 0x4b, 0x3c, 0x3e, 0x47, 0x41, 0x48,
+ 0x40, 0x46, 0x4e, 0x44, 0x49, 0x42, 0x49, 0x44, 0x4b, 0x46, 0x46, 0x43,
+ 0x4c, 0x4b, 0x49, 0x4d, 0x3d, 0x47, 0x43, 0x5c, 0x4a, 0x42, 0x47, 0x4e,
+ 0x47, 0x40, 0x4c, 0x55, 0x3f, 0x45, 0x46, 0x49, 0x46, 0x48, 0x49, 0x4d,
+ 0x4c, 0x41, 0x49, 0x40, 0x4a, 0x44, 0x42, 0x49, 0x52, 0x41, 0x49, 0x4a,
+ 0x5c, 0x53, 0x47, 0x58, 0x49, 0x55, 0x4a, 0x4a, 0x62, 0x61, 0x4b, 0x57,
+ 0x3c, 0x50, 0x42, 0x4c, 0x49, 0x5f, 0x3f, 0x4a, 0x42, 0x70, 0x40, 0x40,
+ 0x4f, 0x46, 0x43, 0x43, 0x4d, 0x6c, 0x41, 0x3e, 0x4e, 0x49, 0x43, 0x38,
+ 0x50, 0x57, 0x43, 0x39, 0x4a, 0x4f, 0x51, 0x3e, 0x5c, 0x57, 0x46, 0x49,
+ 0x41, 0x40, 0x42, 0x4f, 0x4c, 0x45, 0x46, 0x4a, 0x4c, 0x4b, 0x43, 0x42,
+ 0x4c, 0x3c, 0x47, 0x47, 0x4f, 0x44, 0x45, 0x3a, 0x4d, 0x3d, 0x4d, 0x3f,
+ 0x46, 0x4f, 0x41, 0x37, 0x46, 0x45, 0x54, 0x47, 0x4e, 0x46, 0x47, 0x23,
+ 0x48, 0x4e, 0x4a, 0x47, 0x45, 0x45, 0x4e, 0x33, 0x49, 0x4a, 0x4d, 0x4e,
+ 0x49, 0x46, 0x49, 0x36, 0x48, 0x44, 0x53, 0x44, 0x4a, 0x45, 0x4a, 0x37,
+ 0x45, 0x36, 0x4b, 0x4e, 0x50, 0x3f, 0x49, 0x38, 0x40, 0x43, 0x46, 0x4c,
+ 0x43, 0x46, 0x4a, 0x3f, 0x45, 0x3d, 0x44, 0x47, 0x44, 0x42, 0x4a, 0x45,
+ 0x47, 0x43, 0x4d, 0x4d, 0x44, 0x44, 0x4f, 0x4a, 0x4a, 0x41, 0x50, 0x50,
+ 0x4b, 0x44, 0x54, 0x5c, 0x4b, 0x3a, 0x46, 0x4a, 0x4a, 0x43, 0x48, 0x5c,
+ 0x4b, 0x43, 0x47, 0x3d, 0x3e, 0x54, 0x42, 0x47, 0x42, 0x4f, 0x4b, 0x4b,
+ 0x46, 0x46, 0x46, 0x42, 0x42, 0x4b, 0x48, 0x45, 0x51, 0x4e, 0x49, 0x4d,
+ 0x43, 0x56, 0x45, 0x40, 0x5a, 0x58, 0x4c, 0x55, 0x40, 0x4b, 0x4c, 0x51,
+ 0x42, 0x59, 0x43, 0x46, 0x46, 0x69, 0x43, 0x3c, 0x54, 0x47, 0x3d, 0x41,
+ 0x52, 0x64, 0x44, 0x38, 0x4f, 0x49, 0x3a, 0x3a, 0x55, 0x54, 0x45, 0x3e,
+ 0x49, 0x44, 0x4e, 0x3f, 0x57, 0x50, 0x47, 0x43, 0x45, 0x48, 0x53, 0x5b,
+ 0x53, 0x4d, 0x48, 0x4e, 0x48, 0x3a, 0x3e, 0x46, 0x42, 0x36, 0x50, 0x4d,
+ 0x49, 0x4b, 0x4b, 0x45, 0x4c, 0x44, 0x50, 0x47, 0x3e, 0x49, 0x50, 0x37,
+ 0x4c, 0x4b, 0x4a, 0x54, 0x4e, 0x43, 0x40, 0x25, 0x46, 0x42, 0x52, 0x3d,
+ 0x44, 0x45, 0x51, 0x2e, 0x4a, 0x3d, 0x46, 0x46, 0x4c, 0x42, 0x48, 0x34,
+ 0x44, 0x44, 0x44, 0x4c, 0x4f, 0x4b, 0x42, 0x3d, 0x45, 0x40, 0x47, 0x49,
+ 0x43, 0x41, 0x3e, 0x39, 0x47, 0x4b, 0x50, 0x4a, 0x46, 0x47, 0x4e, 0x3b,
+ 0x4e, 0x3e, 0x49, 0x4a, 0x50, 0x40, 0x43, 0x49, 0x48, 0x3c, 0x4f, 0x45,
+ 0x4a, 0x41, 0x42, 0x48, 0x4b, 0x46, 0x4a, 0x50, 0x40, 0x49, 0x44, 0x54,
+ 0x45, 0x45, 0x4a, 0x4b, 0x51, 0x51, 0x48, 0x53, 0x50, 0x3f, 0x50, 0x46,
+ 0x44, 0x45, 0x51, 0x43, 0x4f, 0x3e, 0x41, 0x41, 0x46, 0x45, 0x45, 0x4c,
+ 0x54, 0x3c, 0x4a, 0x4c, 0x5a, 0x4f, 0x46, 0x4b, 0x47, 0x4a, 0x43, 0x4c,
+ 0x56, 0x5a, 0x4a, 0x53, 0x4c, 0x49, 0x46, 0x4c, 0x45, 0x59, 0x40, 0x4b,
+ 0x48, 0x60, 0x3d, 0x42, 0x52, 0x3f, 0x42, 0x3d, 0x52, 0x5f, 0x46, 0x42,
+ 0x4b, 0x4e, 0x4a, 0x3d, 0x52, 0x55, 0x53, 0x37, 0x47, 0x3e, 0x4a, 0x42,
+ 0x51, 0x54, 0x48, 0x48, 0x4b, 0x48, 0x3e, 0x52, 0x41, 0x4e, 0x4c, 0x4f,
+ 0x43, 0x3b, 0x4b, 0x4b, 0x4c, 0x40, 0x48, 0x49, 0x4d, 0x3a, 0x45, 0x3c,
+ 0x53, 0x44, 0x48, 0x4d, 0x4b, 0x49, 0x46, 0x3c, 0x4d, 0x40, 0x51, 0x3f,
+ 0x4c, 0x45, 0x44, 0x2f, 0x49, 0x51, 0x3f, 0x4d, 0x3e, 0x4e, 0x3c, 0x30,
+ 0x3d, 0x48, 0x4f, 0x3f, 0x45, 0x45, 0x46, 0x3b, 0x4c, 0x46, 0x4d, 0x50,
+ 0x4c, 0x3d, 0x41, 0x37, 0x3e, 0x3e, 0x4f, 0x4b, 0x4d, 0x4f, 0x45, 0x45,
+ 0x4a, 0x47, 0x4a, 0x44, 0x43, 0x46, 0x51, 0x41, 0x4e, 0x39, 0x44, 0x4a,
+ 0x4e, 0x49, 0x4a, 0x42, 0x49, 0x4b, 0x4e, 0x48, 0x49, 0x4a, 0x45, 0x4a,
+ 0x45, 0x41, 0x4a, 0x4b, 0x42, 0x41, 0x48, 0x4a, 0x44, 0x3a, 0x46, 0x49,
+ 0x54, 0x45, 0x44, 0x60, 0x4a, 0x4e, 0x45, 0x4a, 0x4a, 0x45, 0x4b, 0x49,
+ 0x42, 0x44, 0x46, 0x50, 0x4b, 0x4b, 0x4e, 0x45, 0x48, 0x3e, 0x55, 0x42,
+ 0x51, 0x49, 0x49, 0x44, 0x4e, 0x54, 0x53, 0x49, 0x4c, 0x63, 0x48, 0x5a,
+ 0x50, 0x4b, 0x45, 0x49, 0x43, 0x57, 0x4c, 0x3f, 0x4d, 0x67, 0x3f, 0x47,
+ 0x53, 0x49, 0x43, 0x44, 0x49, 0x61, 0x50, 0x47, 0x49, 0x49, 0x4a, 0x42,
+ 0x4a, 0x51, 0x46, 0x43, 0x3f, 0x34, 0x40, 0x3a, 0x45, 0x54, 0x4c, 0x55,
+ 0x40, 0x3c, 0x4a, 0x4d, 0x3e, 0x4d, 0x48, 0x51, 0x4c, 0x3e, 0x4c, 0x4f,
+ 0x50, 0x47, 0x4d, 0x49, 0x4d, 0x4e, 0x45, 0x43, 0x41, 0x41, 0x40, 0x47,
+ 0x43, 0x4a, 0x4a, 0x3c, 0x4c, 0x3d, 0x4e, 0x43, 0x41, 0x42, 0x4a, 0x30,
+ 0x45, 0x4c, 0x45, 0x55, 0x46, 0x39, 0x43, 0x39, 0x45, 0x47, 0x48, 0x53,
+ 0x4a, 0x48, 0x43, 0x38, 0x4f, 0x51, 0x4d, 0x4c, 0x41, 0x46, 0x40, 0x3d,
+ 0x43, 0x4b, 0x40, 0x46, 0x47, 0x50, 0x4a, 0x43, 0x50, 0x4e, 0x45, 0x4f,
+ 0x4d, 0x44, 0x4d, 0x3f, 0x4e, 0x48, 0x4a, 0x49, 0x44, 0x3d, 0x4a, 0x44,
+ 0x40, 0x45, 0x49, 0x40, 0x4a, 0x44, 0x4f, 0x4a, 0x43, 0x4a, 0x4e, 0x52,
+ 0x4d, 0x50, 0x48, 0x4c, 0x43, 0x45, 0x4d, 0x54, 0x4a, 0x49, 0x4c, 0x58,
+ 0x4c, 0x48, 0x4c, 0x44, 0x4b, 0x4e, 0x52, 0x44, 0x49, 0x44, 0x47, 0x4e,
+ 0x4b, 0x45, 0x49, 0x3e, 0x4c, 0x3b, 0x53, 0x3f, 0x51, 0x41, 0x3f, 0x44,
+ 0x43, 0x4a, 0x4b, 0x43, 0x53, 0x57, 0x50, 0x53, 0x4f, 0x4b, 0x48, 0x51,
+ 0x47, 0x49, 0x46, 0x4d, 0x4d, 0x5e, 0x44, 0x46, 0x56, 0x3d, 0x3c, 0x3e,
+ 0x47, 0x55, 0x54, 0x46, 0x42, 0x49, 0x4f, 0x43, 0x48, 0x54, 0x51, 0x40,
+ 0x44, 0x44, 0x47, 0x45, 0x4b, 0x59, 0x4d, 0x47, 0x40, 0x39, 0x48, 0x54,
+ 0x43, 0x45, 0x44, 0x42, 0x4c, 0x3c, 0x4d, 0x42, 0x4b, 0x45, 0x42, 0x48,
+ 0x51, 0x44, 0x45, 0x3f, 0x3d, 0x49, 0x4b, 0x4a, 0x41, 0x43, 0x4f, 0x3f,
+ 0x51, 0x4b, 0x44, 0x46, 0x46, 0x44, 0x53, 0x3d, 0x47, 0x47, 0x43, 0x4b,
+ 0x41, 0x43, 0x3c, 0x3b, 0x49, 0x47, 0x47, 0x49, 0x4b, 0x3d, 0x43, 0x43,
+ 0x4b, 0x47, 0x45, 0x4e, 0x42, 0x4a, 0x4c, 0x3e, 0x51, 0x3e, 0x46, 0x44,
+ 0x46, 0x43, 0x42, 0x42, 0x47, 0x4d, 0x51, 0x4b, 0x49, 0x44, 0x4d, 0x40,
+ 0x50, 0x43, 0x41, 0x4c, 0x42, 0x49, 0x49, 0x4c, 0x42, 0x50, 0x48, 0x3f,
+ 0x46, 0x42, 0x48, 0x57, 0x49, 0x4d, 0x47, 0x4e, 0x48, 0x4b, 0x46, 0x50,
+ 0x47, 0x45, 0x52, 0x45, 0x4b, 0x48, 0x40, 0x5b, 0x4e, 0x43, 0x51, 0x48,
+ 0x48, 0x4a, 0x4a, 0x4a, 0x52, 0x51, 0x4c, 0x4b, 0x42, 0x55, 0x4d, 0x46,
+ 0x50, 0x40, 0x4a, 0x50, 0x51, 0x3e, 0x42, 0x4c, 0x43, 0x46, 0x4d, 0x46,
+ 0x46, 0x4d, 0x4d, 0x52, 0x4e, 0x44, 0x45, 0x47, 0x49, 0x4c, 0x41, 0x44,
+ 0x4d, 0x54, 0x4c, 0x4a, 0x54, 0x3e, 0x44, 0x43, 0x53, 0x55, 0x4b, 0x4a,
+ 0x47, 0x47, 0x4f, 0x46, 0x4f, 0x4b, 0x51, 0x3f, 0x41, 0x4c, 0x43, 0x46,
+ 0x55, 0x51, 0x40, 0x4b, 0x4f, 0x40, 0x47, 0x50, 0x4e, 0x4a, 0x46, 0x4e,
+ 0x42, 0x4d, 0x48, 0x49, 0x48, 0x4a, 0x4a, 0x43, 0x49, 0x48, 0x44, 0x3b,
+ 0x51, 0x46, 0x3d, 0x43, 0x47, 0x4a, 0x4f, 0x42, 0x4a, 0x50, 0x4f, 0x41,
+ 0x45, 0x45, 0x43, 0x3c, 0x4c, 0x4c, 0x46, 0x4b, 0x3e, 0x44, 0x4b, 0x3a,
+ 0x45, 0x50, 0x42, 0x48, 0x46, 0x47, 0x44, 0x3a, 0x53, 0x46, 0x4e, 0x4f,
+ 0x43, 0x40, 0x46, 0x48, 0x4e, 0x45, 0x3f, 0x47, 0x48, 0x3f, 0x44, 0x4f,
+ 0x44, 0x47, 0x4e, 0x47, 0x47, 0x49, 0x42, 0x43, 0x3f, 0x49, 0x4a, 0x53,
+ 0x53, 0x4a, 0x4e, 0x4a, 0x49, 0x4d, 0x49, 0x41, 0x48, 0x4d, 0x4d, 0x4e,
+ 0x4b, 0x45, 0x4d, 0x4a, 0x46, 0x4a, 0x46, 0x51, 0x4b, 0x47, 0x49, 0x45,
+ 0x49, 0x49, 0x4b, 0x5c, 0x48, 0x42, 0x51, 0x4c, 0x41, 0x3f, 0x4c, 0x42,
+ 0x4f, 0x45, 0x4b, 0x4a, 0x52, 0x48, 0x53, 0x4f, 0x40, 0x47, 0x41, 0x47,
+ 0x68, 0xfb, 0xff, 0xff, 0x4c, 0xfc, 0xff, 0xff, 0x20, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0xe8, 0x03, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x03, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x90, 0x00, 0x00, 0x00,
+ 0x58, 0x01, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00,
+ 0x38, 0x02, 0x00, 0x00, 0x9c, 0x02, 0x00, 0x00, 0xa0, 0x01, 0x00, 0x00,
+ 0x14, 0x03, 0x00, 0x00, 0xfe, 0xfc, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03,
+ 0x10, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+ 0x24, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x19, 0x00, 0x00, 0x00, 0x16, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x52, 0x65, 0x6c, 0x75, 0x00, 0x00, 0x00, 0x00,
+ 0xcc, 0xfc, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x17, 0xbf, 0xd2, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x58, 0xec, 0xd1, 0x43,
+ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x6e, 0xfd, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
+ 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x08, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x43, 0x6f, 0x6e, 0x76,
+ 0x32, 0x44, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x34, 0xff, 0xff, 0xff,
+ 0x18, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0xf5, 0xf7, 0x84, 0x3a, 0xc2, 0xfd, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00,
+ 0x1c, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00, 0x2b, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00, 0x52, 0x65, 0x73, 0x68,
+ 0x61, 0x70, 0x65, 0x5f, 0x31, 0x00, 0x00, 0x00, 0x94, 0xfd, 0xff, 0xff,
+ 0x30, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x18, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x80, 0x3f, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x43,
+ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x3a, 0xfe, 0xff, 0xff,
+ 0x00, 0x00, 0x00, 0x02, 0x10, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
+ 0x10, 0x00, 0x00, 0x00, 0x28, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x00, 0x00, 0x4d, 0x61, 0x74, 0x4d,
+ 0x75, 0x6c, 0x5f, 0x62, 0x69, 0x61, 0x73, 0x00, 0x0c, 0x00, 0x0c, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0xc5, 0x01, 0x2a, 0x3b, 0x96, 0xfe, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03,
+ 0x10, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+ 0x44, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x0a, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+ 0x25, 0x00, 0x00, 0x00, 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f,
+ 0x71, 0x75, 0x61, 0x6e, 0x74, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75,
+ 0x61, 0x6e, 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61,
+ 0x78, 0x56, 0x61, 0x72, 0x73, 0x00, 0x00, 0x00, 0x84, 0xfe, 0xff, 0xff,
+ 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xab, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xf5, 0xf7, 0x84, 0x3a,
+ 0x01, 0x00, 0x00, 0x00, 0x6e, 0x88, 0xae, 0x3d, 0x01, 0x00, 0x00, 0x00,
+ 0xd4, 0x97, 0x30, 0xbe, 0x26, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03,
+ 0x10, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x1c, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, 0x61, 0x64, 0x64, 0x5f,
+ 0x31, 0x00, 0x00, 0x00, 0xec, 0xfe, 0xff, 0xff, 0x2c, 0x00, 0x00, 0x00,
+ 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x77, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x2f, 0xad, 0x18, 0x40, 0x01, 0x00, 0x00, 0x00,
+ 0x02, 0x38, 0xa2, 0x43, 0x01, 0x00, 0x00, 0x00, 0x02, 0xf1, 0x8d, 0xc3,
+ 0x8e, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x0e, 0x00, 0x00, 0x00, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x73, 0x5f, 0x73,
+ 0x6f, 0x66, 0x74, 0x6d, 0x61, 0x78, 0x00, 0x00, 0x5c, 0xff, 0xff, 0xff,
+ 0x2c, 0x00, 0x00, 0x00, 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3b,
+ 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x7f, 0x3f, 0x01, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x18, 0x00, 0x08, 0x00,
+ 0x07, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x14, 0x00, 0x0e, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x03, 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x14, 0x00, 0x00, 0x00, 0x54, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x30, 0x11, 0x00, 0x00, 0x31, 0x00, 0x00, 0x00,
+ 0x77, 0x65, 0x69, 0x67, 0x68, 0x74, 0x73, 0x5f, 0x71, 0x75, 0x61, 0x6e,
+ 0x74, 0x5f, 0x31, 0x2f, 0x46, 0x61, 0x6b, 0x65, 0x51, 0x75, 0x61, 0x6e,
+ 0x74, 0x57, 0x69, 0x74, 0x68, 0x4d, 0x69, 0x6e, 0x4d, 0x61, 0x78, 0x56,
+ 0x61, 0x72, 0x73, 0x2f, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x73,
+ 0x65, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x14, 0x00, 0x04, 0x00, 0x08, 0x00,
+ 0x0c, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00, 0x2c, 0x00, 0x00, 0x00,
+ 0x20, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x31, 0x83, 0xce, 0x3a, 0x01, 0x00, 0x00, 0x00,
+ 0x4d, 0x97, 0x92, 0x3e, 0x01, 0x00, 0x00, 0x00, 0x84, 0x75, 0xec, 0xbd,
+ 0x03, 0x00, 0x00, 0x00, 0xb4, 0x00, 0x00, 0x00, 0x5c, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0xc0, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x09,
+ 0x02, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x08, 0x00, 0x04, 0x00,
+ 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x14, 0x00, 0x1c, 0x00,
+ 0x08, 0x00, 0x0c, 0x00, 0x10, 0x00, 0x07, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x18, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08,
+ 0x01, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
+ 0x28, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x04, 0x00, 0x04, 0x00, 0x00, 0x00, 0x14, 0x00, 0x18, 0x00,
+ 0x00, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x07, 0x00, 0x10, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x14, 0x00, 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02,
+ 0x10, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00, 0x30, 0x00, 0x00, 0x00,
+ 0x1c, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
+ 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x10, 0x00,
+ 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x0c, 0x00, 0x0c, 0x00, 0x00, 0x00,
+ 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x08, 0x00, 0x00, 0x00,
+ 0x03, 0x00, 0x00, 0x00, 0x24, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00,
+ 0x04, 0x00, 0x00, 0x00, 0xfa, 0xff, 0xff, 0xff, 0x00, 0x19, 0x06, 0x00,
+ 0x06, 0x00, 0x05, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x09, 0x06, 0x00,
+ 0x08, 0x00, 0x07, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04};
+const int g_tiny_conv_model_data_len = 19800;
diff --git a/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h
new file mode 100644
index 0000000000..2953cc852d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.h
@@ -0,0 +1,27 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is a standard TensorFlow Lite model file that has been converted into a
+// C data array, so it can be easily compiled into a binary for devices that
+// don't have a file system. It was created using the command:
+// xxd -i tiny_conv.tflite > tiny_conv_model_data.cc
+
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_
+
+extern const unsigned char g_tiny_conv_model_data[];
+extern const int g_tiny_conv_model_data_len;
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_EXAMPLES_MICRO_SPEECH_TINY_CONV_MODEL_DATA_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/BUILD b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD
new file mode 100644
index 0000000000..a012f950e6
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/BUILD
@@ -0,0 +1,107 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
+load(
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test.bzl",
+ "tflite_micro_cc_test",
+)
+
+cc_library(
+ name = "micro_ops",
+ srcs = [
+ "depthwise_conv.cc",
+ "fully_connected.cc",
+ "softmax.cc",
+ ],
+ hdrs = [
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "//tensorflow/contrib/lite/kernels:padding",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
+ "//tensorflow/contrib/lite/kernels/internal:reference_base",
+ "//tensorflow/contrib/lite/kernels/internal:tensor",
+ ],
+)
+
+cc_library(
+ name = "all_ops_resolver",
+ srcs = [
+ "all_ops_resolver.cc",
+ ],
+ hdrs = [
+ "all_ops_resolver.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ ":micro_ops",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ ],
+)
+
+cc_library(
+ name = "test_utils",
+ srcs = [
+ ],
+ hdrs = [
+ "test_utils.h",
+ ],
+ copts = tflite_copts(),
+ deps = [
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/core/api",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "depthwise_conv_test",
+ srcs = [
+ "depthwise_conv_test.cc",
+ ],
+ deps = [
+ ":all_ops_resolver",
+ ":test_utils",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "fully_connected_test",
+ srcs = [
+ "fully_connected_test.cc",
+ ],
+ deps = [
+ ":all_ops_resolver",
+ ":test_utils",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
+
+tflite_micro_cc_test(
+ name = "softmax_test",
+ srcs = [
+ "softmax_test.cc",
+ ],
+ deps = [
+ ":all_ops_resolver",
+ ":test_utils",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ "//tensorflow/contrib/lite/experimental/micro/testing:micro_test",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc
new file mode 100644
index 0000000000..bd0a37badb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D();
+TfLiteRegistration* Micro_Register_DEPTHWISE_CONV_2D() {
+ return Register_DEPTHWISE_CONV_2D();
+}
+
+TfLiteRegistration* Register_FULLY_CONNECTED();
+TfLiteRegistration* Micro_Register_FULLY_CONNECTED() {
+ return Register_FULLY_CONNECTED();
+}
+
+TfLiteRegistration* Register_SOFTMAX();
+TfLiteRegistration* Micro_Register_SOFTMAX() { return Register_SOFTMAX(); }
+
+AllOpsResolver::AllOpsResolver() {
+ AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D,
+ Micro_Register_DEPTHWISE_CONV_2D());
+ AddBuiltin(BuiltinOperator_FULLY_CONNECTED, Micro_Register_FULLY_CONNECTED(),
+ /* min_version */ 1,
+ /* max_version */ 2);
+ AddBuiltin(BuiltinOperator_SOFTMAX, Micro_Register_SOFTMAX());
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h
new file mode 100644
index 0000000000..f836064a3f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+
+class AllOpsResolver : public MicroMutableOpResolver {
+ public:
+ AllOpsResolver();
+
+ private:
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_ALL_OPS_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc
new file mode 100644
index 0000000000..4f17263181
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv.cc
@@ -0,0 +1,208 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/padding.h"
+
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace depthwise_conv {
+namespace {
+
+constexpr int kInputTensor = 0;
+constexpr int kFilterTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+struct OpData {
+ TfLitePaddingValues padding;
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+};
+
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, int width,
+ int height, int filter_width, int filter_height,
+ int out_width, int out_height,
+ const TfLiteType data_type, OpData* data) {
+ data->padding.height = ComputePadding(params->stride_height, 1, height,
+ filter_height, out_height);
+ data->padding.width =
+ ComputePadding(params->stride_width, 1, width, filter_width, out_width);
+
+ // Note that quantized inference requires that all tensors have their
+ // parameters set. This is usually done during quantized training.
+ if (data_type != kTfLiteFloat32) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ const TfLiteTensor* bias =
+ GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+void EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+
+ tflite::reference_ops::DepthwiseConv(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output));
+}
+
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ const int32_t input_offset = -input->params.zero_point;
+ const int32_t filter_offset = -filter->params.zero_point;
+ const int32_t output_offset = output->params.zero_point;
+
+ tflite::DepthwiseParams op_params;
+ // Padding type is ignored, but still set.
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = 1;
+ op_params.dilation_height_factor = 1;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = -data->output_shift;
+
+ tflite::reference_ops::DepthwiseConv(
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter),
+ GetTensorShape(bias), GetTensorData<int32_t>(bias),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteDepthwiseConvParams*>(node->builtin_data);
+
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
+ const TfLiteTensor* bias =
+ (NumInputs(node) == 3) ? GetInput(context, node, kBiasTensor) : nullptr;
+
+ const TfLiteType data_type = input->type;
+ int width = SizeOfDimension(input, 2);
+ int height = SizeOfDimension(input, 1);
+ int filter_width = SizeOfDimension(filter, 2);
+ int filter_height = SizeOfDimension(filter, 1);
+ int out_width = ComputeOutSize(params->padding, width, filter_width,
+ params->stride_width);
+ int out_height = ComputeOutSize(params->padding, height, filter_height,
+ params->stride_height);
+ OpData local_data_object;
+ OpData* data = &local_data_object;
+ TF_LITE_ENSURE_STATUS(CalculateOpData(context, node, params, width, height,
+ filter_width, filter_height, out_width,
+ out_height, data_type, data));
+
+ // TODO(aselle): Consider whether float conv and quantized conv should be
+ // separate ops to avoid dispatch overhead here.
+ switch (input->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ EvalFloat(context, node, params, data, input, filter, bias, output);
+ break;
+ case kTfLiteUInt8:
+ EvalQuantized(context, node, params, data, input, filter, bias, output);
+ break;
+ default:
+ context->ReportError(context, "Type %d not currently supported.",
+ input->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace depthwise_conv
+
+TfLiteRegistration* Register_DEPTHWISE_CONV_2D() {
+ static TfLiteRegistration r = {depthwise_conv::Init, depthwise_conv::Free,
+ depthwise_conv::Prepare, depthwise_conv::Eval};
+ return &r;
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc
new file mode 100644
index 0000000000..169899c471
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/depthwise_conv_test.cc
@@ -0,0 +1,406 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestDepthwiseConvFloat(std::initializer_list<int> input_dims_data,
+ std::initializer_list<float> input_data,
+ std::initializer_list<int> filter_dims_data,
+ std::initializer_list<float> filter_data,
+ std::initializer_list<int> bias_dims_data,
+ std::initializer_list<float> bias_data,
+ std::initializer_list<float> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ TfLiteFusedActivation activation,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateFloatTensor(input_data, input_dims, "input_tensor"),
+ CreateFloatTensor(filter_data, filter_dims, "filter_tensor"),
+ CreateFloatTensor(bias_data, bias_dims, "bias_tensor"),
+ CreateFloatTensor(output_data, output_dims, "output_tensor"),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ int input_depth = input_dims->data[3];
+ int output_depth = filter_dims->data[3];
+ int depth_mul = output_depth / input_depth;
+ TfLiteDepthwiseConvParams builtin_data = {
+ kTfLitePaddingValid, 1, 1, depth_mul, activation,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+ 1e-5f);
+ }
+}
+
+void TestDepthwiseConvQuantized(
+ std::initializer_list<int> input_dims_data,
+ std::initializer_list<uint8_t> input_data, float input_min, float input_max,
+ std::initializer_list<int> filter_dims_data,
+ std::initializer_list<uint8_t> filter_data, float filter_min,
+ float filter_max, std::initializer_list<int> bias_dims_data,
+ std::initializer_list<int32_t> bias_data, float bias_min, float bias_max,
+ std::initializer_list<uint8_t> expected_output_data,
+ std::initializer_list<int> output_dims_data, float output_min,
+ float output_max, TfLiteFusedActivation activation, uint8_t* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* filter_dims = IntArrayFromInitializer(filter_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+ input_max),
+ CreateQuantizedTensor(filter_data, filter_dims, "filter_tensor",
+ filter_min, filter_max),
+ CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min,
+ bias_max),
+ CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+ output_min, output_max),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_DEPTHWISE_CONV_2D, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ int input_depth = input_dims->data[3];
+ int output_depth = filter_dims->data[3];
+ int depth_mul = output_depth / input_depth;
+ TfLiteDepthwiseConvParams builtin_data = {
+ kTfLitePaddingValid, 1, 1, depth_mul, activation,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+ const int output_dims_count = 8;
+ float output_data[output_dims_count];
+ tflite::testing::TestDepthwiseConvFloat( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ 1, 2, 7, 8, // Input values.
+ 3, 4, 9, 10, //
+ 5, 6, 11, 12, //
+ },
+ {4, 1, 2, 2, 4}, // Filters shape.
+ {
+ 1, 2, 3, 4, // Filters values.
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ },
+ {1, 4}, // Bias shape.
+ {
+ 1, 2, 3, 4, // Bias values.
+ },
+ {
+ 71, -34, 99, -20, // Expected results.
+ 91, -26, 127, -4, //
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float filter_min = -63.5f;
+ const float filter_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 8;
+ uint8_t output_data[output_dims_count];
+
+ tflite::testing::TestDepthwiseConvQuantized( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max),
+ F2Q(2, input_min, input_max),
+ F2Q(7, input_min, input_max),
+ F2Q(8, input_min, input_max),
+ F2Q(3, input_min, input_max),
+ F2Q(4, input_min, input_max),
+ F2Q(9, input_min, input_max),
+ F2Q(10, input_min, input_max),
+ F2Q(5, input_min, input_max),
+ F2Q(6, input_min, input_max),
+ F2Q(11, input_min, input_max),
+ F2Q(12, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {4, 1, 2, 2, 4}, // Filter shape.
+ {
+ // Filter values.
+ F2Q(1, filter_min, filter_max),
+ F2Q(2, filter_min, filter_max),
+ F2Q(3, filter_min, filter_max),
+ F2Q(4, filter_min, filter_max),
+ F2Q(-9, filter_min, filter_max),
+ F2Q(10, filter_min, filter_max),
+ F2Q(-11, filter_min, filter_max),
+ F2Q(12, filter_min, filter_max),
+ F2Q(5, filter_min, filter_max),
+ F2Q(6, filter_min, filter_max),
+ F2Q(7, filter_min, filter_max),
+ F2Q(8, filter_min, filter_max),
+ F2Q(13, filter_min, filter_max),
+ F2Q(-14, filter_min, filter_max),
+ F2Q(15, filter_min, filter_max),
+ F2Q(-16, filter_min, filter_max),
+ },
+ filter_min, filter_max, // Filter quantization range.
+ {1, 4}, // Bias shape.
+ {
+ // Bias values.
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ F2Q32(4, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(71, output_min, output_max),
+ F2Q(-34, output_min, output_max),
+ F2Q(99, output_min, output_max),
+ F2Q(-20, output_min, output_max),
+ F2Q(91, output_min, output_max),
+ F2Q(-26, output_min, output_max),
+ F2Q(127, output_min, output_max),
+ F2Q(-4, output_min, output_max),
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestRelu) {
+ const int output_dims_count = 8;
+ float output_data[output_dims_count];
+ tflite::testing::TestDepthwiseConvFloat( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ 1, 2, 7, 8, // Input values.
+ 3, 4, 9, 10, //
+ 5, 6, 11, 12, //
+ },
+ {4, 1, 2, 2, 4}, // Filters shape.
+ {
+ 1, 2, 3, 4, // Filters values.
+ -9, 10, -11, 12, //
+ 5, 6, 7, 8, //
+ 13, -14, 15, -16, //
+ },
+ {1, 4}, // Bias shape.
+ {
+ 1, 2, 3, 4, // Bias values.
+ },
+ {
+ 71, 0, 99, 0, // Expected results.
+ 91, 0, 127, 0, //
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestReluQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float filter_min = -63.5f;
+ const float filter_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 8;
+ uint8_t output_data[output_dims_count];
+
+ tflite::testing::TestDepthwiseConvQuantized( //
+ {4, 1, 3, 2, 2}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max),
+ F2Q(2, input_min, input_max),
+ F2Q(7, input_min, input_max),
+ F2Q(8, input_min, input_max),
+ F2Q(3, input_min, input_max),
+ F2Q(4, input_min, input_max),
+ F2Q(9, input_min, input_max),
+ F2Q(10, input_min, input_max),
+ F2Q(5, input_min, input_max),
+ F2Q(6, input_min, input_max),
+ F2Q(11, input_min, input_max),
+ F2Q(12, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {4, 1, 2, 2, 4}, // Filter shape.
+ {
+ // Filter values.
+ F2Q(1, filter_min, filter_max),
+ F2Q(2, filter_min, filter_max),
+ F2Q(3, filter_min, filter_max),
+ F2Q(4, filter_min, filter_max),
+ F2Q(-9, filter_min, filter_max),
+ F2Q(10, filter_min, filter_max),
+ F2Q(-11, filter_min, filter_max),
+ F2Q(12, filter_min, filter_max),
+ F2Q(5, filter_min, filter_max),
+ F2Q(6, filter_min, filter_max),
+ F2Q(7, filter_min, filter_max),
+ F2Q(8, filter_min, filter_max),
+ F2Q(13, filter_min, filter_max),
+ F2Q(-14, filter_min, filter_max),
+ F2Q(15, filter_min, filter_max),
+ F2Q(-16, filter_min, filter_max),
+ },
+ filter_min, filter_max, // Filter quantization range.
+ {1, 4}, // Bias shape.
+ {
+ // Bias values.
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ F2Q32(4, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(71, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(99, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(91, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(127, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ },
+ {4, 1, 2, 1, 4}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc
new file mode 100644
index 0000000000..1e9e54cafb
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected.cc
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/internal/reference/fully_connected.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace fully_connected {
+namespace {
+
+struct OpData {
+ // The scaling factor from input to output (aka the 'real multiplier') can
+ // be represented as a fixed point multiplier plus a left shift.
+ int32_t output_multiplier;
+ int output_shift;
+ // The range of the fused activation layer. For example for kNone and
+ // uint8_t these would be 0 and 255.
+ int32_t output_activation_min;
+ int32_t output_activation_max;
+ // The index of the temporary tensor where the quantized inputs are cached.
+ int input_quantized_index;
+};
+
+constexpr int kInputTensor = 0;
+constexpr int kWeightsTensor = 1;
+constexpr int kBiasTensor = 2;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus CalculateOpData(TfLiteContext* context,
+ TfLiteFullyConnectedParams* params,
+ TfLiteType data_type, const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output,
+ OpData* data) {
+ TfLiteStatus status = kTfLiteOk;
+ if (data_type != kTfLiteFloat32) {
+ double real_multiplier = 0.0;
+ TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
+ context, input, filter, bias, output, &real_multiplier));
+ int exponent;
+ QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
+ data->output_shift = -exponent;
+ TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
+ context, params->activation, output, &data->output_activation_min,
+ &data->output_activation_max));
+ }
+ return status;
+}
+
+} // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter, const TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ const int32_t input_offset = -input->params.zero_point;
+ const int32_t filter_offset = -filter->params.zero_point;
+ const int32_t output_offset = output->params.zero_point;
+
+ tflite::FullyConnectedParams op_params;
+ op_params.input_offset = input_offset;
+ op_params.weights_offset = filter_offset;
+ op_params.output_offset = output_offset;
+ op_params.output_multiplier = data->output_multiplier;
+ // Legacy ops used mixed left and right shifts. Now all are +ve-means-left.
+ op_params.output_shift = -data->output_shift;
+ op_params.quantized_activation_min = data->output_activation_min;
+ op_params.quantized_activation_max = data->output_activation_max;
+
+#define TF_LITE_FULLY_CONNECTED(output_data_type) \
+ reference_ops::FullyConnected( \
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
+ GetTensorShape(filter), GetTensorData<uint8_t>(filter), \
+ GetTensorShape(bias), GetTensorData<int32_t>(bias), \
+ GetTensorShape(output), GetTensorData<output_data_type>(output), \
+ nullptr)
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
+ TfLiteFullyConnectedParams* params, OpData* data,
+ const TfLiteTensor* input, const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ float output_activation_min, output_activation_max;
+ CalculateActivationRange(params->activation, &output_activation_min,
+ &output_activation_max);
+ tflite::FullyConnectedParams op_params;
+ op_params.float_activation_min = output_activation_min;
+ op_params.float_activation_max = output_activation_max;
+ tflite::reference_ops::FullyConnected(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(filter), GetTensorData<float>(filter),
+ GetTensorShape(bias), GetTensorData<float>(bias), GetTensorShape(output),
+ GetTensorData<float>(output));
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ TfLiteType data_type = input->type;
+ OpData local_data_object;
+ OpData* data = &local_data_object;
+ TF_LITE_ENSURE_STATUS(CalculateOpData(context, params, data_type, input,
+ filter, bias, output, data));
+
+ switch (filter->type) { // Already know in/out types are same.
+ case kTfLiteFloat32:
+ return EvalFloat(context, node, params, data, input, filter, bias,
+ output);
+ case kTfLiteUInt8:
+ return EvalQuantized(context, node, params, data, input, filter, bias,
+ output);
+
+ default:
+ context->ReportError(context, "Type %d not currently supported.",
+ filter->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace fully_connected
+
+TfLiteRegistration* Register_FULLY_CONNECTED() {
+ static TfLiteRegistration r = {fully_connected::Init, fully_connected::Free,
+ fully_connected::Prepare,
+ fully_connected::Eval};
+ return &r;
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc
new file mode 100644
index 0000000000..b42bf4c3bc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/fully_connected_test.cc
@@ -0,0 +1,643 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestFullyConnectedFloat(std::initializer_list<int> input_dims_data,
+ std::initializer_list<float> input_data,
+ std::initializer_list<int> weights_dims_data,
+ std::initializer_list<float> weights_data,
+ std::initializer_list<int> bias_dims_data,
+ std::initializer_list<float> bias_data,
+ std::initializer_list<float> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ TfLiteFusedActivation activation,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateFloatTensor(input_data, input_dims, "input_tensor"),
+ CreateFloatTensor(weights_data, weights_dims, "weights_tensor"),
+ CreateFloatTensor(bias_data, bias_dims, "bias_tensor"),
+ CreateFloatTensor(output_data, output_dims, "output_tensor"),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteFullyConnectedParams builtin_data = {
+ activation,
+ kTfLiteFullyConnectedWeightsFormatDefault,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+ 1e-5f);
+ }
+}
+
+void TestFullyConnectedQuantized(
+ std::initializer_list<int> input_dims_data,
+ std::initializer_list<uint8_t> input_data, float input_min, float input_max,
+ std::initializer_list<int> weights_dims_data,
+ std::initializer_list<uint8_t> weights_data, float weights_min,
+ float weights_max, std::initializer_list<int> bias_dims_data,
+ std::initializer_list<int32_t> bias_data, float bias_min, float bias_max,
+ std::initializer_list<uint8_t> expected_output_data,
+ std::initializer_list<int> output_dims_data, float output_min,
+ float output_max, TfLiteFusedActivation activation, uint8_t* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* weights_dims = IntArrayFromInitializer(weights_dims_data);
+ TfLiteIntArray* bias_dims = IntArrayFromInitializer(bias_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 3;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+ input_max),
+ CreateQuantizedTensor(weights_data, weights_dims, "weights_tensor",
+ weights_min, weights_max),
+ CreateQuantized32Tensor(bias_data, bias_dims, "bias_tensor", bias_min,
+ bias_max),
+ CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+ output_min, output_max),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_FULLY_CONNECTED, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteFullyConnectedParams builtin_data = {
+ activation,
+ kTfLiteFullyConnectedWeightsFormatDefault,
+ };
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+
+ int inputs_array_data[] = {3, 0, 1, 2};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 3};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {2, 2, 10}, // Input shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ },
+ {2, 3, 10}, // Weights shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ },
+ {1, 3}, // Bias shape.
+ {
+ 1, 2, 3, // Bias values.
+ },
+ {
+ 24, 25, 26, 58, 59, 60, // Expected results.
+ },
+ {2, 2, 3}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest2) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {2, 2, 2}, // Input shape.
+ {
+ 1, 2, // b = 0
+ 2, 1, // b = 1
+ },
+ {2, 1, 2}, // Weights shape.
+ {
+ 2, 4, // u = 0
+ },
+ {1, 1}, // Bias shape.
+ {
+ 1, // Bias values.
+ },
+ {
+ 11, 9, // Expected results.
+ },
+ {2, 2, 1}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestRelu) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {2, 2, 10}, // Input shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ },
+ {2, 3, 10}, // Weights shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ },
+ {1, 3}, // Bias shape.
+ {
+ 1, -2, 3, // Bias values.
+ },
+ {
+ 24, 0, 26, 58, 0, 60, // Expected results.
+ },
+ {2, 2, 3}, // Output shape.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float weights_min = -63.5f;
+ const float weights_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {2, 2, 10}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedRelu) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float weights_min = -63.5f;
+ const float weights_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {2, 2, 10}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(-1, weights_min, weights_max), F2Q(-2, weights_min, weights_max),
+ F2Q(-3, weights_min, weights_max), F2Q(-4, weights_min, weights_max),
+ F2Q(-5, weights_min, weights_max), F2Q(-6, weights_min, weights_max),
+ F2Q(-7, weights_min, weights_max), F2Q(-8, weights_min, weights_max),
+ F2Q(-9, weights_min, weights_max), F2Q(-10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(0, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(0, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActRelu, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantizedOutputMultiplierGreaterThan1) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -127.0f;
+ const float input_max = 128.0f;
+ const float weights_min = -127.0f;
+ const float weights_max = 128.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 256.0f * (1 << 24);
+ const float output_min = -63.5f;
+ const float output_max = 64.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {2, 2, 10}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest4DInput) {
+ const int output_dims_count = 6;
+ float output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedFloat( //
+ {4, 1, 1, 5, 1}, // Input shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ },
+ {2, 3, 10}, // Weights shape.
+ {
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 2
+ },
+ {1, 3}, // Bias shape.
+ {
+ 1, 2, 3, // Bias values.
+ },
+ {
+ 24, 25, 26, 58, 59, 60, // Expected results.
+ },
+ {2, 2, 3}, // Output shape.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest4DInputQuantized) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float weights_min = -63.5f;
+ const float weights_max = 64.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 64.0f * (1 << 24);
+ const float output_min = -127.0f;
+ const float output_max = 128.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {4, 1, 1, 5, 1}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTest4DInputQuantizedOutputMultiplierGreaterThan1) {
+ using tflite::testing::F2Q;
+ using tflite::testing::F2Q32;
+
+ const float input_min = -127.0f;
+ const float input_max = 128.0f;
+ const float weights_min = -127.0f;
+ const float weights_max = 128.0f;
+ const float bias_min = 0.0f;
+ const float bias_max = 256.0f * (1 << 24);
+ const float output_min = -63.5f;
+ const float output_max = 64.0f;
+ const int output_dims_count = 6;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestFullyConnectedQuantized( //
+ {4, 1, 1, 5, 1}, // Input shape.
+ {
+ // Input values.
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(8, input_min, input_max),
+ F2Q(-9, input_min, input_max), F2Q(-10, input_min, input_max),
+ F2Q(1, input_min, input_max), F2Q(2, input_min, input_max),
+ F2Q(3, input_min, input_max), F2Q(4, input_min, input_max),
+ F2Q(5, input_min, input_max), F2Q(6, input_min, input_max),
+ F2Q(7, input_min, input_max), F2Q(-8, input_min, input_max),
+ F2Q(9, input_min, input_max), F2Q(-10, input_min, input_max),
+ },
+ input_min, input_max, // Input quantization range.
+ {2, 3, 10}, // Weights shape.
+ {
+ // Weight values.
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ F2Q(1, weights_min, weights_max), F2Q(2, weights_min, weights_max),
+ F2Q(3, weights_min, weights_max), F2Q(4, weights_min, weights_max),
+ F2Q(5, weights_min, weights_max), F2Q(6, weights_min, weights_max),
+ F2Q(7, weights_min, weights_max), F2Q(8, weights_min, weights_max),
+ F2Q(9, weights_min, weights_max), F2Q(10, weights_min, weights_max),
+ },
+ weights_min, weights_max, // Weights quantization range.
+ {1, 3}, // Bias shape.
+ {
+ F2Q32(1, bias_min, bias_max),
+ F2Q32(2, bias_min, bias_max),
+ F2Q32(3, bias_min, bias_max),
+ },
+ bias_min, bias_max, // Bias quantization range.
+ {
+ // Expected results.
+ F2Q(24, output_min, output_max),
+ F2Q(25, output_min, output_max),
+ F2Q(26, output_min, output_max),
+ F2Q(58, output_min, output_max),
+ F2Q(59, output_min, output_max),
+ F2Q(60, output_min, output_max),
+ },
+ {2, 2, 3}, // Output shape.
+ output_min, output_max, // Output quantization range.
+ kTfLiteActNone, output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc
new file mode 100644
index 0000000000..a4019a067c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax.cc
@@ -0,0 +1,213 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/kernels/internal/reference/softmax.h"
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace micro {
+namespace activations {
+namespace {
+
+struct OpData {
+ int32_t input_multiplier = 0;
+ int input_left_shift = 0;
+ int32_t input_range_radius = 0;
+ int diff_min = 0;
+};
+
+TfLiteStatus CalculateSoftmaxOpData(TfLiteContext* context,
+ const TfLiteTensor* input,
+ TfLiteTensor* output,
+ const TfLiteSoftmaxParams* params,
+ OpData* data) {
+ if (input->type == kTfLiteUInt8) {
+ TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
+
+ static const int kScaledDiffIntegerBits = 5;
+
+ tflite::PreprocessSoftmaxScaling(
+ params->beta, input->params.scale, kScaledDiffIntegerBits,
+ &data->input_multiplier, &data->input_left_shift);
+ data->diff_min = -1.0 * tflite::CalculateInputRadius(
+ kScaledDiffIntegerBits, data->input_left_shift);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ return nullptr;
+}
+
+void Free(TfLiteContext* context, void* buffer) {}
+
+TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+// Takes a 1D tensor and performs softmax along it.
+void Softmax1DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int input_size = input->dims->data[0];
+ tflite::reference_ops::Softmax(input->data.f, input_size, 1, params->beta,
+ output->data.f);
+}
+
+// Takes a 2D tensor and perform softmax along the last dimension.
+void Softmax2DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ tflite::reference_ops::Softmax(input->data.f, input_size, batch_size,
+ params->beta, output->data.f);
+}
+
+void Softmax1DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 1D
+ // tensor is 4D in a special way. We will convert a (Y) shape into a (1,
+ // 1, 1, Y) shape.
+ const int input_size = input->dims->data[0];
+ const int32_t shape_data[4] = {1, 1, 1, input_size};
+ RuntimeShape shape(4, shape_data);
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ tflite::reference_ops::Softmax(op_params, shape,
+ GetTensorData<uint8_t>(input), shape,
+ GetTensorData<uint8_t>(output));
+}
+
+void Softmax2DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ // TODO(ahentz): this is arguably a dirty trick. Since the implementation
+ // always traverses the last dimension of a 4D tensor, we will pretend our 2D
+ // tensor is 4D in a special way. We will convert a (X, Y) shape into a (X,
+ // 1, 1, Y) shape.
+ const int batch_size = input->dims->data[0];
+ const int input_size = input->dims->data[1];
+ const int32_t shape_data[4] = {batch_size, 1, 1, input_size};
+ RuntimeShape shape(4, shape_data);
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ tflite::reference_ops::Softmax(op_params, shape,
+ GetTensorData<uint8_t>(input), shape,
+ GetTensorData<uint8_t>(output));
+}
+
+// Takes a 4D tensor and perform softmax along the forth dimension.
+void Softmax4DFloat(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params) {
+ SoftmaxParams op_params;
+ op_params.beta = params->beta;
+ tflite::reference_ops::Softmax(
+ op_params, GetTensorShape(input), GetTensorData<float>(input),
+ GetTensorShape(output), GetTensorData<float>(output));
+}
+
+void Softmax4DQuantized(const TfLiteTensor* input, TfLiteTensor* output,
+ TfLiteSoftmaxParams* params, OpData* data) {
+ SoftmaxParams op_params;
+ op_params.input_multiplier = data->input_multiplier;
+ op_params.input_left_shift = data->input_left_shift;
+ op_params.diff_min = data->diff_min;
+ tflite::reference_ops::Softmax(
+ op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
+ GetTensorShape(output), GetTensorData<uint8_t>(output));
+}
+
+TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
+
+ const TfLiteTensor* input = GetInput(context, node, 0);
+ TfLiteTensor* output = GetOutput(context, node, 0);
+
+ OpData local_data_object;
+ OpData* data = &local_data_object;
+ TF_LITE_ENSURE_STATUS(
+ CalculateSoftmaxOpData(context, input, output, params, data));
+
+ // TODO(ahentz): consider an implementation that works for many (all?)
+ // dimensions.
+ switch (input->type) {
+ case kTfLiteFloat32: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 2) {
+ Softmax2DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DFloat(input, output, params);
+ return kTfLiteOk;
+ }
+ context->ReportError(
+ context, "Only 1D, 2D and 4D tensors supported currently, got %dD.",
+ NumDimensions(input));
+ return kTfLiteError;
+ }
+ case kTfLiteUInt8: {
+ if (NumDimensions(input) == 1) {
+ Softmax1DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 2) {
+ Softmax2DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ if (NumDimensions(input) == 4) {
+ Softmax4DQuantized(input, output, params, data);
+ return kTfLiteOk;
+ }
+ context->ReportError(
+ context, "Only 2D and 4D tensors supported currently, got %dD.",
+ NumDimensions(input));
+ return kTfLiteError;
+ }
+ default:
+ context->ReportError(
+ context, "Only float32 and uint8_t supported currently, got %d.",
+ input->type);
+ return kTfLiteError;
+ }
+}
+} // namespace activations
+
+TfLiteRegistration* Register_SOFTMAX() {
+ static TfLiteRegistration r = {activations::Init, activations::Free,
+ activations::SoftmaxPrepare,
+ activations::SoftmaxEval};
+ return &r;
+}
+
+} // namespace micro
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc
new file mode 100644
index 0000000000..694456d8ac
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/softmax_test.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/all_ops_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+namespace {
+
+void TestSoftmaxFloat(std::initializer_list<int> input_dims_data,
+ std::initializer_list<float> input_data,
+ std::initializer_list<float> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ float* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 2;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateFloatTensor(input_data, input_dims, "input_tensor"),
+ CreateFloatTensor(output_data, output_dims, "output_tensor"),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteSoftmaxParams builtin_data = {1.0f};
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+ int inputs_array_data[] = {1, 0};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 1};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_NEAR(expected_output_data.begin()[i], output_data[i],
+ 1e-5f);
+ }
+}
+
+void TestSoftmaxQuantized(std::initializer_list<int> input_dims_data,
+ std::initializer_list<uint8_t> input_data,
+ float input_min, float input_max,
+ std::initializer_list<uint8_t> expected_output_data,
+ std::initializer_list<int> output_dims_data,
+ float output_min, float output_max,
+ uint8_t* output_data) {
+ TfLiteIntArray* input_dims = IntArrayFromInitializer(input_dims_data);
+ TfLiteIntArray* output_dims = IntArrayFromInitializer(output_dims_data);
+ const int output_dims_count = ElementCount(*output_dims);
+
+ constexpr int inputs_size = 1;
+ constexpr int outputs_size = 1;
+ constexpr int tensors_size = inputs_size + outputs_size;
+ TfLiteTensor tensors[tensors_size] = {
+ CreateQuantizedTensor(input_data, input_dims, "input_tensor", input_min,
+ input_max),
+ CreateQuantizedTensor(output_data, output_dims, "output_tensor",
+ output_min, output_max),
+ };
+
+ TfLiteContext context;
+ PopulateContext(tensors, tensors_size, &context);
+
+ ::tflite::ops::micro::AllOpsResolver resolver;
+ const TfLiteRegistration* registration =
+ resolver.FindOp(tflite::BuiltinOperator_SOFTMAX, 1);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+
+ TfLiteSoftmaxParams builtin_data = {1.0f};
+ const char* init_data = reinterpret_cast<const char*>(&builtin_data);
+ size_t init_data_size = 0;
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context, init_data, init_data_size);
+ }
+
+ int inputs_array_data[] = {1, 0};
+ TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
+ int outputs_array_data[] = {1, 1};
+ TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
+ int temporaries_array_data[] = {0};
+ TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(&builtin_data);
+ node.custom_initial_data = nullptr;
+ node.custom_initial_data_size = 0;
+ node.delegate = nullptr;
+
+ if (registration->prepare) {
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
+ }
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
+ if (registration->free) {
+ registration->free(&context, user_data);
+ }
+ for (int i = 0; i < output_dims_count; ++i) {
+ TF_LITE_MICRO_EXPECT_EQ(expected_output_data.begin()[i], output_data[i]);
+ }
+}
+
+} // namespace
+} // namespace testing
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(SimpleTest) {
+ const int output_dims_count = 10;
+ float output_data[output_dims_count];
+ tflite::testing::TestSoftmaxFloat( //
+ {2, 2, 5}, // Input shape.
+ {
+ 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0
+ -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0
+ },
+ {
+ // Expected results.
+ 0.011656231,
+ 0.031684921,
+ 0.086128544,
+ 0.234121657,
+ 0.636408647,
+ 0.636408647,
+ 0.234121657,
+ 0.086128544,
+ 0.031684921,
+ 0.011656231,
+ },
+ {2, 2, 5}, // Output shape.
+ output_data);
+}
+
+TF_LITE_MICRO_TEST(SimpleTestQuantized) {
+ using tflite::testing::F2Q;
+
+ const float input_min = -63.5f;
+ const float input_max = 64.0f;
+ const float output_min = 0.0f;
+ const float output_max = (255.0f / 256.0f);
+ const int output_dims_count = 5;
+ uint8_t output_data[output_dims_count];
+ tflite::testing::TestSoftmaxQuantized( //
+ {2, 1, 5}, // Input shape.
+ {
+ F2Q(1.0, input_min, input_max),
+ F2Q(2.0, input_min, input_max),
+ F2Q(3.0, input_min, input_max),
+ F2Q(4.0, input_min, input_max),
+ F2Q(5.0, input_min, input_max),
+ },
+ input_min, input_max, // Input quantized range.
+ {
+ // Expected results.
+ F2Q(0.011656231, output_min, output_max),
+ F2Q(0.031684921, output_min, output_max),
+ F2Q(0.086128544, output_min, output_max),
+ F2Q(0.234121657, output_min, output_max),
+ F2Q(0.636408647, output_min, output_max),
+ },
+ {2, 1, 5}, // Output shape.
+ output_min, output_max, // Output quantized range.
+ output_data);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h
new file mode 100644
index 0000000000..789a48ece8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h
@@ -0,0 +1,170 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
+
+#include <cstdarg>
+#include <initializer_list>
+#include <limits>
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/kernels/test_utils.h"
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace testing {
+
+// How many elements are in the array with this shape.
+inline int ElementCount(const TfLiteIntArray& dims) {
+ int result = 1;
+ for (int i = 0; i < dims.size; ++i) {
+ result *= dims.data[i];
+ }
+ return result;
+}
+
+// Wrapper to forward kernel errors to the interpreter's error reporter.
+inline void ReportOpError(struct TfLiteContext* context, const char* format,
+ ...) {
+ ErrorReporter* error_reporter = static_cast<ErrorReporter*>(context->impl_);
+ va_list args;
+ va_start(args, format);
+ error_reporter->Report(format, args);
+ va_end(args);
+}
+
+// Derives the quantization scaling factor from a min and max range.
+template <typename T>
+inline float ScaleFromMinMax(const float min, const float max) {
+ return (max - min) / ((std::numeric_limits<T>::max() * 1.0) -
+ std::numeric_limits<T>::min());
+}
+
+// Derives the quantization zero point from a min and max range.
+template <typename T>
+inline int ZeroPointFromMinMax(const float min, const float max) {
+ return static_cast<int>((-min / ScaleFromMinMax<T>(min, max)) + 0.5f);
+}
+
+// Converts a float value into an unsigned eight-bit quantized value.
+inline uint8_t F2Q(const float value, const float min, const float max) {
+ int32_t result = ZeroPointFromMinMax<uint8_t>(min, max) +
+ (value / ScaleFromMinMax<uint8_t>(min, max)) + 0.5f;
+ if (result < 0) {
+ result = 0;
+ }
+ if (result > 256) {
+ result = 256;
+ }
+ return result;
+}
+
+// Converts a float value into a signed thirty-two-bit quantized value.
+inline uint8_t F2Q32(const float value, const float min, const float max) {
+ return static_cast<int32_t>((value - ZeroPointFromMinMax<int32_t>(min, max)) /
+ ScaleFromMinMax<int32_t>(min, max));
+}
+
+inline void PopulateContext(TfLiteTensor* tensors, int tensors_size,
+ TfLiteContext* context) {
+ context->tensors_size = tensors_size;
+ context->tensors = tensors;
+ context->impl_ = static_cast<void*>(micro_test::reporter);
+ context->GetExecutionPlan = nullptr;
+ context->ResizeTensor = nullptr;
+ context->ReportError = ReportOpError;
+ context->AddTensors = nullptr;
+ context->GetNodeAndRegistration = nullptr;
+ context->ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context->recommended_num_threads = 1;
+ context->GetExternalContext = nullptr;
+ context->SetExternalContext = nullptr;
+}
+
+inline TfLiteIntArray* IntArrayFromInts(const int* int_array) {
+ return const_cast<TfLiteIntArray*>(
+ reinterpret_cast<const TfLiteIntArray*>(int_array));
+}
+
+inline TfLiteIntArray* IntArrayFromInitializer(
+ std::initializer_list<int> int_initializer) {
+ return IntArrayFromInts(int_initializer.begin());
+}
+
+inline TfLiteTensor CreateFloatTensor(const float* data, TfLiteIntArray* dims,
+ const char* name) {
+ const size_t bytes = ElementCount(*dims) * sizeof(float);
+ return {
+ kTfLiteFloat32, {const_cast<int*>(reinterpret_cast<const int*>(data))},
+ dims, {},
+ kTfLiteMemNone, bytes,
+ nullptr, name};
+}
+
+inline TfLiteTensor CreateFloatTensor(std::initializer_list<float> data,
+ TfLiteIntArray* dims, const char* name) {
+ return CreateFloatTensor(data.begin(), dims, name);
+}
+
+inline TfLiteTensor CreateQuantizedTensor(const uint8_t* data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ const size_t bytes = ElementCount(*dims) * sizeof(uint8_t);
+ const TfLiteQuantizationParams q_params = {
+ ScaleFromMinMax<uint8_t>(min, max),
+ ZeroPointFromMinMax<uint8_t>(min, max)};
+ return {
+ kTfLiteUInt8, {const_cast<int*>(reinterpret_cast<const int*>(data))},
+ dims, q_params,
+ kTfLiteMemNone, bytes,
+ nullptr, name};
+}
+
+inline TfLiteTensor CreateQuantizedTensor(std::initializer_list<uint8_t> data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ return CreateQuantizedTensor(data.begin(), dims, name, min, max);
+}
+
+inline TfLiteTensor CreateQuantized32Tensor(const int32_t* data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ const size_t bytes = ElementCount(*dims) * sizeof(int32_t);
+ const TfLiteQuantizationParams q_params = {
+ ScaleFromMinMax<int32_t>(min, max),
+ ZeroPointFromMinMax<int32_t>(min, max)};
+ return {
+ kTfLiteUInt8, {const_cast<int*>(reinterpret_cast<const int*>(data))},
+ dims, q_params,
+ kTfLiteMemNone, bytes,
+ nullptr, name};
+}
+
+inline TfLiteTensor CreateQuantized32Tensor(std::initializer_list<int32_t> data,
+ TfLiteIntArray* dims,
+ const char* name, float min,
+ float max) {
+ return CreateQuantized32Tensor(data.begin(), dims, name, min, max);
+}
+
+} // namespace testing
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_KERNELS_TEST_UTILS_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc
new file mode 100644
index 0000000000..99dd883661
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+
+#ifdef TF_LITE_MCU_DEBUG_LOG
+#include <debug_log.h>
+#else // TF_LITE_MCU_DEBUG_LOG
+#include <cstdint>
+#include <cstdio>
+void DebugLog(const char* s) { fprintf(stderr, "%s", s); }
+void DebugLogInt32(int32_t i) { fprintf(stderr, "%d", i); }
+void DebugLogUInt32(uint32_t i) { fprintf(stderr, "%d", i); }
+void DebugLogHex(uint32_t i) { fprintf(stderr, "0x%8x", i); }
+void DebugLogFloat(float i) { fprintf(stderr, "%f", i); }
+#endif // TF_LITE_MCU_DEBUG_LOG
+
+namespace tflite {
+namespace {
+void DebugLogPrintf(const char* format, va_list args) {
+ const int output_cache_size = 64;
+ char output_cache[output_cache_size + 1];
+ int output_cache_index = 0;
+ const char* current = format;
+ while (*current != 0) {
+ if (*current == '%') {
+ const char next = *(current + 1);
+ if ((next == 'd') || (next == 's')) {
+ current += 1;
+ if (output_cache_index > 0) {
+ output_cache[output_cache_index] = 0;
+ DebugLog(output_cache);
+ output_cache_index = 0;
+ }
+ if (next == 'd') {
+ DebugLogInt32(va_arg(args, int));
+ } else if (next == 's') {
+ DebugLog(va_arg(args, char*));
+ }
+ }
+ } else {
+ output_cache[output_cache_index] = *current;
+ output_cache_index += 1;
+ }
+ if (output_cache_index >= output_cache_size) {
+ output_cache[output_cache_index] = 0;
+ DebugLog(output_cache);
+ output_cache_index = 0;
+ }
+ current += 1;
+ }
+ if (output_cache_index > 0) {
+ output_cache[output_cache_index] = 0;
+ DebugLog(output_cache);
+ output_cache_index = 0;
+ }
+ DebugLog("\n");
+}
+} // namespace
+
+int MicroErrorReporter::Report(const char* format, va_list args) {
+ DebugLogPrintf(format, args);
+ return 0;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h
new file mode 100644
index 0000000000..33e54f7990
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
+
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+namespace tflite {
+
+class MicroErrorReporter : public ErrorReporter {
+ public:
+ ~MicroErrorReporter() {}
+ int Report(const char* format, va_list args) override;
+
+ private:
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_ERROR_REPORTER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.h b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc
index 498d4a9495..ef3c32050c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_options.h
+++ b/tensorflow/contrib/lite/experimental/micro/micro_error_reporter_test.cc
@@ -13,21 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
-
-#include "tensorflow/compiler/xla/service/hlo_module_config.h"
-
-// Helper functions for querying options that are specific to the GPU backend.
-
-namespace xla {
-namespace gpu {
-
-// Returns true if we should use heuristics to assign convolution layouts, as
-// opposed to always assigning NCHW.
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config);
-
-} // namespace gpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_OPTIONS_H_
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+
+int main(int argc, char** argv) {
+ tflite::MicroErrorReporter micro_error_reporter;
+ tflite::ErrorReporter* error_reporter = &micro_error_reporter;
+ error_reporter->Report("Number: %d", 42);
+ error_reporter->Report("Badly-formed format string %");
+ error_reporter->Report("Another % badly-formed %% format string");
+ error_reporter->Report("~~~%s~~~", "ALL TESTS PASSED");
+}
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc
new file mode 100644
index 0000000000..0f38991bb0
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.cc
@@ -0,0 +1,310 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+namespace tflite {
+namespace {
+const int kStackDataAllocatorSize = 128;
+class StackDataAllocator : public BuiltinDataAllocator {
+ public:
+ void* Allocate(size_t size) override {
+ if (size > kStackDataAllocatorSize) {
+ return nullptr;
+ } else {
+ return data_;
+ }
+ }
+ void Deallocate(void* data) override {
+ // Do nothing.
+ }
+
+ private:
+ uint8_t data_[kStackDataAllocatorSize];
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
+ if (registration->builtin_code == BuiltinOperator_CUSTOM) {
+ return registration->custom_name;
+ } else {
+ return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
+ }
+}
+
+void ReportOpError(struct TfLiteContext* context, const char* format, ...) {
+ MicroInterpreter* interpreter =
+ static_cast<MicroInterpreter*>(context->impl_);
+ va_list args;
+ va_start(args, format);
+ interpreter->error_reporter()->Report(format, args);
+ va_end(args);
+}
+
+} // namespace
+
+MicroInterpreter::MicroInterpreter(const Model* model,
+ const OpResolver& op_resolver,
+ SimpleTensorAllocator* tensor_allocator,
+ ErrorReporter* error_reporter)
+ : model_(model),
+ op_resolver_(op_resolver),
+ tensor_allocator_(tensor_allocator),
+ error_reporter_(error_reporter),
+ initialization_status_(kTfLiteOk) {
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
+ model->buffers();
+ auto* subgraphs = model->subgraphs();
+ if (subgraphs->size() != 1) {
+ error_reporter->Report("Only 1 subgraph is currently supported.\n");
+ initialization_status_ = kTfLiteError;
+ return;
+ }
+ subgraph_ = (*subgraphs)[0];
+ tensors_ = subgraph_->tensors();
+ operators_ = subgraph_->operators();
+
+ context_.tensors_size = tensors_->Length();
+ context_.tensors =
+ reinterpret_cast<TfLiteTensor*>(tensor_allocator_->AllocateMemory(
+ sizeof(TfLiteTensor) * context_.tensors_size));
+ for (int i = 0; i < subgraph_->inputs()->Length(); ++i) {
+ const int tensor_index = subgraph_->inputs()->Get(i);
+ const auto* tensor = tensors_->Get(tensor_index);
+ initialization_status_ = tensor_allocator_->AllocateTensor(
+ *tensor, 0, operators_->Length(), buffers, error_reporter,
+ &context_.tensors[tensor_index]);
+ if (initialization_status_ != kTfLiteOk) {
+ return;
+ }
+ }
+
+ int* first_created = reinterpret_cast<int*>(
+ tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length()));
+ int* last_used = reinterpret_cast<int*>(
+ tensor_allocator_->AllocateMemory(sizeof(int) * tensors_->Length()));
+ for (int i = 0; i < tensors_->Length(); ++i) {
+ first_created[i] = -1;
+ last_used[i] = -1;
+ }
+
+ for (int i = (operators_->Length() - 1); i >= 0; --i) {
+ const auto* op = operators_->Get(i);
+ for (int n = 0; n < op->inputs()->Length(); ++n) {
+ const int tensor_index = op->inputs()->Get(n);
+ if ((last_used[tensor_index] == -1) || (last_used[tensor_index] < i)) {
+ last_used[tensor_index] = i;
+ }
+ }
+ for (int n = 0; n < op->outputs()->Length(); ++n) {
+ const int tensor_index = op->outputs()->Get(n);
+ const int create_before = i;
+ int destroy_after = last_used[tensor_index];
+ if (destroy_after == -1) {
+ destroy_after = operators_->Length();
+ }
+ const auto* tensor = tensors_->Get(tensor_index);
+ if (!tensor->is_variable()) {
+ initialization_status_ = tensor_allocator_->AllocateTensor(
+ *tensor, create_before, destroy_after, buffers, error_reporter,
+ &context_.tensors[tensor_index]);
+ if (initialization_status_ != kTfLiteOk) {
+ return;
+ }
+ first_created[tensor_index] = i;
+ }
+ }
+ }
+
+ for (int i = 0; i < tensors_->Length(); ++i) {
+ const auto* tensor = tensors_->Get(i);
+ const bool is_read_only = (first_created[i] == -1) && (last_used[i] != -1);
+ if (tensor->is_variable() || is_read_only) {
+ initialization_status_ = tensor_allocator_->AllocateTensor(
+ *tensor, 0, operators_->Length(), buffers, error_reporter,
+ &context_.tensors[i]);
+ if (initialization_status_ != kTfLiteOk) {
+ return;
+ }
+ }
+ }
+ context_.impl_ = static_cast<void*>(this);
+ context_.GetExecutionPlan = nullptr;
+ context_.ResizeTensor = nullptr;
+ context_.ReportError = ReportOpError;
+ context_.AddTensors = nullptr;
+ context_.GetNodeAndRegistration = nullptr;
+ context_.ReplaceSubgraphsWithDelegateKernels = nullptr;
+ context_.recommended_num_threads = 1;
+ context_.GetExternalContext = nullptr;
+ context_.SetExternalContext = nullptr;
+}
+
+TfLiteStatus MicroInterpreter::Invoke() {
+ if (initialization_status_ != kTfLiteOk) {
+ error_reporter_->Report("Invoke() called after initialization failed\n");
+ return kTfLiteError;
+ }
+ TfLiteStatus status = kTfLiteOk;
+ auto opcodes = model_->operator_codes();
+ for (int i = 0; i < operators_->Length(); ++i) {
+ const auto* op = operators_->Get(i);
+ int index = op->opcode_index();
+ if (index < 0 || index >= opcodes->size()) {
+ error_reporter_->Report("Missing registration for opcode_index %d\n",
+ index);
+ return kTfLiteError;
+ }
+ auto opcode = (*opcodes)[index];
+ const TfLiteRegistration* registration = nullptr;
+ status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
+ &registration);
+ if (status != kTfLiteOk) {
+ return status;
+ }
+ if (registration == nullptr) {
+ error_reporter_->Report("Skipping op for opcode_index %d\n", index);
+ return kTfLiteError;
+ }
+ BuiltinOperator op_type =
+ static_cast<BuiltinOperator>(registration->builtin_code);
+
+ if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
+ error_reporter_->Report(
+ "Found builtin operator %s with custom options.\n",
+ EnumNameBuiltinOperator(op_type));
+ }
+ StackDataAllocator stack_data_allocator;
+ const char* custom_data = nullptr;
+ size_t custom_data_size = 0;
+ unsigned char* builtin_data = nullptr;
+ if (op->custom_options()) {
+ custom_data = reinterpret_cast<const char*>(op->custom_options()->data());
+ custom_data_size = op->custom_options()->size();
+ } else {
+ TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
+ &stack_data_allocator,
+ (void**)(&builtin_data)));
+ }
+
+ const char* init_data;
+ size_t init_data_size;
+ if (registration->builtin_code == BuiltinOperator_CUSTOM) {
+ init_data = custom_data;
+ init_data_size = custom_data_size;
+ } else {
+ init_data = reinterpret_cast<const char*>(builtin_data);
+ init_data_size = 0;
+ }
+ void* user_data = nullptr;
+ if (registration->init) {
+ user_data = registration->init(&context_, init_data, init_data_size);
+ }
+
+ const int kMaxInputs = 16;
+ int inputs_data[kMaxInputs + 1];
+ TfLiteIntArray* inputs_array =
+ reinterpret_cast<TfLiteIntArray*>(inputs_data);
+ if (op->inputs()->Length() >= kMaxInputs) {
+ error_reporter_->Report("Too many inputs (%d)\n", op->inputs()->Length());
+ return kTfLiteError;
+ }
+ inputs_array->size = op->inputs()->Length();
+ for (int n = 0; n < op->inputs()->Length(); ++n) {
+ inputs_array->data[n] = op->inputs()->Get(n);
+ }
+
+ const int kMaxOutputs = 16;
+ int outputs_data[kMaxOutputs + 1];
+ TfLiteIntArray* outputs_array =
+ reinterpret_cast<TfLiteIntArray*>(outputs_data);
+ if (op->outputs()->Length() >= kMaxOutputs) {
+ error_reporter_->Report("Too many outputs (%d)\n",
+ op->outputs()->Length());
+ return kTfLiteError;
+ }
+ outputs_array->size = op->outputs()->Length();
+ for (int n = 0; n < op->outputs()->Length(); ++n) {
+ outputs_array->data[n] = op->outputs()->Get(n);
+ }
+
+ const int kMaxTemporaries = 16;
+ int temporaries_data[kMaxTemporaries + 1];
+ TfLiteIntArray* temporaries_array =
+ reinterpret_cast<TfLiteIntArray*>(temporaries_data);
+ temporaries_array->size = 0;
+
+ TfLiteNode node;
+ node.inputs = inputs_array;
+ node.outputs = outputs_array;
+ node.temporaries = temporaries_array;
+ node.user_data = user_data;
+ node.builtin_data = reinterpret_cast<void*>(builtin_data);
+ node.custom_initial_data = custom_data;
+ node.custom_initial_data_size = custom_data_size;
+ node.delegate = nullptr;
+ if (registration->prepare) {
+ TfLiteStatus prepare_status = registration->prepare(&context_, &node);
+ if (prepare_status != kTfLiteOk) {
+ error_reporter_->Report(
+ "Node %s (number %d) failed to prepare with status %d",
+ OpNameFromRegistration(registration), i, prepare_status);
+ return kTfLiteError;
+ }
+ }
+
+ if (registration->invoke) {
+ TfLiteStatus invoke_status = registration->invoke(&context_, &node);
+ if (invoke_status != kTfLiteOk) {
+ error_reporter_->Report(
+ "Node %s (number %d) failed to invoke with status %d",
+ OpNameFromRegistration(registration), i, invoke_status);
+ return kTfLiteError;
+ }
+ }
+
+ if (registration->free) {
+ registration->free(&context_, user_data);
+ }
+ }
+ return status;
+}
+
+TfLiteTensor* MicroInterpreter::input(int index) {
+ const flatbuffers::Vector<int32_t>* inputs = subgraph_->inputs();
+ const size_t length = inputs->Length();
+ if ((index < 0) || (index >= length)) {
+ error_reporter_->Report("Input index %d out of range (length is %d)", index,
+ length);
+ return nullptr;
+ }
+ return &(context_.tensors[inputs->Get(index)]);
+}
+
+TfLiteTensor* MicroInterpreter::output(int index) {
+ const flatbuffers::Vector<int32_t>* outputs = subgraph_->outputs();
+ const size_t length = outputs->Length();
+ if ((index < 0) || (index >= outputs->Length())) {
+ error_reporter_->Report("Output index %d out of range (length is %d)",
+ index, length);
+ return nullptr;
+ }
+ return &(context_.tensors[outputs->Get(index)]);
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h
new file mode 100644
index 0000000000..a88514cde8
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter.h
@@ -0,0 +1,71 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+class MicroInterpreter {
+ public:
+ // The lifetime of the model, op resolver, allocator, and error reporter must
+ // be at least as long as that of the interpreter object, since the
+ // interpreter may need to access them at any time. This means that you should
+ // usually create them with the same scope as each other, for example having
+ // them all allocated on the stack as local variables through a top-level
+ // function.
+ // The interpreter doesn't do any deallocation of any of the pointed-to
+ // objects, ownership remains with the caller.
+ MicroInterpreter(const Model* model, const OpResolver& op_resolver,
+ SimpleTensorAllocator* tensor_allocator,
+ ErrorReporter* error_reporter);
+
+ TfLiteStatus Invoke();
+
+ size_t tensors_size() const { return context_.tensors_size; }
+ TfLiteTensor* tensor(int tensor_index);
+
+ TfLiteTensor* input(int index);
+ size_t inputs_size() const { return subgraph_->inputs()->Length(); }
+
+ TfLiteTensor* output(int index);
+ size_t outputs_size() const { return subgraph_->outputs()->Length(); }
+
+ TfLiteStatus initialization_status() const { return initialization_status_; }
+
+ ErrorReporter* error_reporter() { return error_reporter_; }
+
+ private:
+ const Model* model_;
+ const OpResolver& op_resolver_;
+ SimpleTensorAllocator* tensor_allocator_;
+ ErrorReporter* error_reporter_;
+
+ TfLiteStatus initialization_status_;
+ const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors_;
+ const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators_;
+ TfLiteContext context_;
+
+ const SubGraph* subgraph_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_INTERPRETER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc
new file mode 100644
index 0000000000..251e5f7203
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
+ const int32_t* input_data = input->data.i32;
+ const TfLiteTensor* weight = &context->tensors[node->inputs->data[1]];
+ const uint8_t* weight_data = weight->data.uint8;
+ TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
+ int32_t* output_data = output->data.i32;
+ output_data[0] = input_data[0] + weight_data[0];
+ return kTfLiteOk;
+}
+
+class MockOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(BuiltinOperator op,
+ int version) const override {
+ return nullptr;
+ }
+ const TfLiteRegistration* FindOp(const char* op, int version) const override {
+ if (strcmp(op, "mock_custom") == 0) {
+ static TfLiteRegistration r = {MockInit, MockFree, MockPrepare,
+ MockInvoke};
+ return &r;
+ } else {
+ return nullptr;
+ }
+ }
+};
+
+class StackAllocator : public flatbuffers::Allocator {
+ public:
+ StackAllocator() : data_(data_backing_), data_size_(0) {}
+
+ uint8_t* allocate(size_t size) override {
+ if ((data_size_ + size) > kStackAllocatorSize) {
+ // TODO(petewarden): Add error reporting beyond returning null!
+ return nullptr;
+ }
+ uint8_t* result = data_;
+ data_ += size;
+ data_size_ += size;
+ return result;
+ }
+
+ void deallocate(uint8_t* p, size_t) override {}
+
+ static StackAllocator& instance() {
+ // Avoid using true dynamic memory allocation to be portable to bare metal.
+ static char inst_memory[sizeof(StackAllocator)];
+ static StackAllocator* inst = new (inst_memory) StackAllocator;
+ return *inst;
+ }
+
+ static constexpr int kStackAllocatorSize = 4096;
+
+ private:
+ uint8_t data_backing_[kStackAllocatorSize];
+ uint8_t* data_;
+ int data_size_;
+};
+
+const Model* BuildMockModel() {
+ using flatbuffers::Offset;
+ flatbuffers::FlatBufferBuilder builder(StackAllocator::kStackAllocatorSize,
+ &StackAllocator::instance());
+ constexpr size_t buffer_data_size = 1;
+ const uint8_t buffer_data[buffer_data_size] = {21};
+ constexpr size_t buffers_size = 2;
+ const Offset<Buffer> buffers[buffers_size] = {
+ CreateBuffer(builder),
+ CreateBuffer(builder,
+ builder.CreateVector(buffer_data, buffer_data_size))};
+ constexpr size_t tensor_shape_size = 1;
+ const int32_t tensor_shape[tensor_shape_size] = {1};
+ constexpr size_t tensors_size = 3;
+ const Offset<Tensor> tensors[tensors_size] = {
+ CreateTensor(builder,
+ builder.CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_INT32, 0,
+ builder.CreateString("test_input_tensor"), 0, false),
+ CreateTensor(builder,
+ builder.CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_UINT8, 1,
+ builder.CreateString("test_weight_tensor"), 0, false),
+ CreateTensor(builder,
+ builder.CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_INT32, 0,
+ builder.CreateString("test_output_tensor"), 0, false),
+ };
+ constexpr size_t inputs_size = 1;
+ const int32_t inputs[inputs_size] = {0};
+ constexpr size_t outputs_size = 1;
+ const int32_t outputs[outputs_size] = {2};
+ constexpr size_t operator_inputs_size = 2;
+ const int32_t operator_inputs[operator_inputs_size] = {0, 1};
+ constexpr size_t operator_outputs_size = 1;
+ const int32_t operator_outputs[operator_outputs_size] = {2};
+ constexpr size_t operators_size = 1;
+ const Offset<Operator> operators[operators_size] = {CreateOperator(
+ builder, 0, builder.CreateVector(operator_inputs, operator_inputs_size),
+ builder.CreateVector(operator_outputs, operator_outputs_size),
+ BuiltinOptions_NONE)};
+ constexpr size_t subgraphs_size = 1;
+ const Offset<SubGraph> subgraphs[subgraphs_size] = {
+ CreateSubGraph(builder, builder.CreateVector(tensors, tensors_size),
+ builder.CreateVector(inputs, inputs_size),
+ builder.CreateVector(outputs, outputs_size),
+ builder.CreateVector(operators, operators_size),
+ builder.CreateString("test_subgraph"))};
+ constexpr size_t operator_codes_size = 1;
+ const Offset<OperatorCode> operator_codes[operator_codes_size] = {
+ CreateOperatorCodeDirect(builder, BuiltinOperator_CUSTOM, "mock_custom",
+ 0)};
+ const Offset<Model> model_offset = CreateModel(
+ builder, 0, builder.CreateVector(operator_codes, operator_codes_size),
+ builder.CreateVector(subgraphs, subgraphs_size),
+ builder.CreateString("test_model"),
+ builder.CreateVector(buffers, buffers_size));
+ FinishModelBuffer(builder, model_offset);
+ void* model_pointer = builder.GetBufferPointer();
+ const Model* model = flatbuffers::GetRoot<Model>(model_pointer);
+ return model;
+}
+
+} // namespace
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestInterpreter) {
+ const tflite::Model* model = tflite::BuildMockModel();
+ TF_LITE_MICRO_EXPECT_NE(nullptr, model);
+ tflite::MockOpResolver mock_resolver;
+ constexpr size_t allocator_buffer_size = 1024;
+ uint8_t allocator_buffer[allocator_buffer_size];
+ tflite::SimpleTensorAllocator simple_tensor_allocator(allocator_buffer,
+ allocator_buffer_size);
+ tflite::MicroInterpreter interpreter(
+ model, mock_resolver, &simple_tensor_allocator, micro_test::reporter);
+ TF_LITE_MICRO_EXPECT_EQ(1, interpreter.inputs_size());
+ TF_LITE_MICRO_EXPECT_EQ(1, interpreter.outputs_size());
+
+ TfLiteTensor* input = interpreter.input(0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, input);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, input->type);
+ TF_LITE_MICRO_EXPECT_EQ(1, input->dims->size);
+ TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
+ TF_LITE_MICRO_EXPECT_EQ(4, input->bytes);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, input->data.i32);
+ input->data.i32[0] = 21;
+
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.Invoke());
+
+ TfLiteTensor* output = interpreter.output(0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, output);
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, output->type);
+ TF_LITE_MICRO_EXPECT_EQ(1, output->dims->size);
+ TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
+ TF_LITE_MICRO_EXPECT_EQ(4, output->bytes);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, output->data.i32);
+ TF_LITE_MICRO_EXPECT_EQ(42, output->data.i32[0]);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc
new file mode 100644
index 0000000000..40c21c6448
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.cc
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h"
+
+namespace tflite {
+
+const TfLiteRegistration* MicroMutableOpResolver::FindOp(
+ tflite::BuiltinOperator op, int version) const {
+ for (int i = 0; i < registrations_len_; ++i) {
+ const TfLiteRegistration& registration = registrations_[i];
+ if ((registration.builtin_code == op) &&
+ (registration.version == version)) {
+ return &registration;
+ }
+ }
+ return nullptr;
+}
+
+const TfLiteRegistration* MicroMutableOpResolver::FindOp(const char* op,
+ int version) const {
+ for (int i = 0; i < registrations_len_; ++i) {
+ const TfLiteRegistration& registration = registrations_[i];
+ if ((registration.builtin_code == -1) &&
+ (strcmp(registration.custom_name, op) == 0) &&
+ (registration.version == version)) {
+ return &registration;
+ }
+ }
+ return nullptr;
+}
+
+void MicroMutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
+ TfLiteRegistration* registration,
+ int min_version, int max_version) {
+ for (int version = min_version; version <= max_version; ++version) {
+ if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
+ // TODO(petewarden) - Add error reporting hooks so we can report this!
+ return;
+ }
+ TfLiteRegistration* new_registration = &registrations_[registrations_len_];
+ registrations_len_ += 1;
+
+ *new_registration = *registration;
+ new_registration->builtin_code = op;
+ new_registration->version = version;
+ }
+}
+
+void MicroMutableOpResolver::AddCustom(const char* name,
+ TfLiteRegistration* registration,
+ int min_version, int max_version) {
+ for (int version = min_version; version <= max_version; ++version) {
+ if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
+ // TODO(petewarden) - Add error reporting hooks so we can report this!
+ return;
+ }
+ TfLiteRegistration* new_registration = &registrations_[registrations_len_];
+ registrations_len_ += 1;
+
+ *new_registration = *registration;
+ new_registration->builtin_code = -1;
+ new_registration->custom_name = name;
+ new_registration->version = version;
+ }
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h
new file mode 100644
index 0000000000..f3750a2484
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
+
+#include "tensorflow/contrib/lite/core/api/op_resolver.h"
+#include "tensorflow/contrib/lite/experimental/micro/compatibility.h"
+
+#ifndef TFLITE_REGISTRATIONS_MAX
+#define TFLITE_REGISTRATIONS_MAX (128)
+#endif
+
+namespace tflite {
+
+class MicroMutableOpResolver : public OpResolver {
+ public:
+ const TfLiteRegistration* FindOp(tflite::BuiltinOperator op,
+ int version) const override;
+ const TfLiteRegistration* FindOp(const char* op, int version) const override;
+ void AddBuiltin(tflite::BuiltinOperator op, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+ void AddCustom(const char* name, TfLiteRegistration* registration,
+ int min_version = 1, int max_version = 1);
+
+ private:
+ TfLiteRegistration registrations_[TFLITE_REGISTRATIONS_MAX];
+ int registrations_len_ = 0;
+
+ TF_LITE_REMOVE_VIRTUAL_DELETE
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_MICRO_MUTABLE_OP_RESOLVER_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc
new file mode 100644
index 0000000000..5420a33e87
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver_test.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_mutable_op_resolver.h"
+
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace {
+void* MockInit(TfLiteContext* context, const char* buffer, size_t length) {
+ // Do nothing.
+ return nullptr;
+}
+
+void MockFree(TfLiteContext* context, void* buffer) {
+ // Do nothing.
+}
+
+TfLiteStatus MockPrepare(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+
+TfLiteStatus MockInvoke(TfLiteContext* context, TfLiteNode* node) {
+ return kTfLiteOk;
+}
+} // namespace
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestOperations) {
+ using tflite::BuiltinOperator_CONV_2D;
+ using tflite::BuiltinOperator_RELU;
+ using tflite::MicroMutableOpResolver;
+ using tflite::OpResolver;
+
+ static TfLiteRegistration r = {tflite::MockInit, tflite::MockFree,
+ tflite::MockPrepare, tflite::MockInvoke};
+
+ MicroMutableOpResolver micro_mutable_op_resolver;
+ micro_mutable_op_resolver.AddBuiltin(BuiltinOperator_CONV_2D, &r, 0, 2);
+ micro_mutable_op_resolver.AddCustom("mock_custom", &r, 0, 3);
+ OpResolver* resolver = &micro_mutable_op_resolver;
+
+ const TfLiteRegistration* registration =
+ resolver->FindOp(BuiltinOperator_CONV_2D, 0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp(BuiltinOperator_CONV_2D, 10);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp(BuiltinOperator_RELU, 0);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("mock_custom", 0);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration->init(nullptr, nullptr, 0));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(nullptr, nullptr));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(nullptr, nullptr));
+
+ registration = resolver->FindOp("mock_custom", 10);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+
+ registration = resolver->FindOp("nonexistent_custom", 0);
+ TF_LITE_MICRO_EXPECT_EQ(nullptr, registration);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc
new file mode 100644
index 0000000000..8c090a20a5
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.cc
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h"
+
+#include "tensorflow/contrib/lite/core/api/flatbuffer_conversions.h"
+
+namespace tflite {
+namespace {
+
+TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size,
+ ErrorReporter* reporter) {
+ switch (type) {
+ case kTfLiteFloat32:
+ *size = sizeof(float);
+ break;
+ case kTfLiteInt16:
+ *size = sizeof(int16_t);
+ break;
+ case kTfLiteInt32:
+ *size = sizeof(int32_t);
+ break;
+ case kTfLiteUInt8:
+ *size = sizeof(uint8_t);
+ break;
+ case kTfLiteInt64:
+ *size = sizeof(int64_t);
+ break;
+ case kTfLiteBool:
+ *size = sizeof(bool);
+ break;
+ case kTfLiteComplex64:
+ *size = sizeof(float) * 2;
+ break;
+ default:
+ reporter->Report(
+ "Only float32, int16, int32, int64, uint8, bool, complex64 "
+ "supported currently.");
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus BytesRequired(const tflite::Tensor& flatbuffer_tensor,
+ size_t dims_size, size_t* bytes,
+ ErrorReporter* error_reporter) {
+ TfLiteType tf_lite_type;
+ TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(),
+ &tf_lite_type, error_reporter));
+ size_t type_size;
+ TF_LITE_ENSURE_STATUS(
+ TfLiteTypeSizeOf(tf_lite_type, &type_size, error_reporter));
+ *bytes = dims_size * type_size;
+ return kTfLiteOk;
+}
+
+} // namespace
+
+TfLiteStatus SimpleTensorAllocator::AllocateTensor(
+ const tflite::Tensor& flatbuffer_tensor, int create_before,
+ int destroy_after,
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ ErrorReporter* error_reporter, TfLiteTensor* result) {
+ TF_LITE_ENSURE_STATUS(ConvertTensorType(flatbuffer_tensor.type(),
+ &result->type, error_reporter));
+ result->is_variable = flatbuffer_tensor.is_variable();
+
+ result->data.raw = nullptr;
+ result->bytes = 0;
+ if (auto* buffer = (*buffers)[flatbuffer_tensor.buffer()]) {
+ if (auto* array = buffer->data()) {
+ if (size_t array_size = array->size()) {
+ result->data.raw =
+ const_cast<char*>(reinterpret_cast<const char*>(array->data()));
+ TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, array_size,
+ &result->bytes, error_reporter));
+ }
+ }
+ }
+ if (result->data.raw) {
+ result->allocation_type = kTfLiteMmapRo;
+ } else {
+ int data_size = 1;
+ for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
+ data_size *= flatbuffer_tensor.shape()->Get(n);
+ }
+ TF_LITE_ENSURE_STATUS(BytesRequired(flatbuffer_tensor, data_size,
+ &result->bytes, error_reporter));
+ result->data.raw = reinterpret_cast<char*>(AllocateMemory(result->bytes));
+ if (result->data.raw == nullptr) {
+ const char* tensor_name = flatbuffer_tensor.name()->c_str();
+ if (tensor_name == nullptr) {
+ tensor_name = "<None>";
+ }
+ error_reporter->Report(
+ "Couldn't allocate memory for tensor '%s', wanted %d bytes but only "
+ "%d were available",
+ tensor_name, result->bytes, (data_size_max_ - data_size_));
+ return kTfLiteError;
+ }
+ result->allocation_type = kTfLiteArenaRw;
+ }
+ result->dims = reinterpret_cast<TfLiteIntArray*>(
+ AllocateMemory(sizeof(int) * (flatbuffer_tensor.shape()->Length() + 1)));
+ result->dims->size = flatbuffer_tensor.shape()->Length();
+ for (int n = 0; n < flatbuffer_tensor.shape()->Length(); ++n) {
+ result->dims->data[n] = flatbuffer_tensor.shape()->Get(n);
+ }
+ if (flatbuffer_tensor.quantization()) {
+ result->params.scale = flatbuffer_tensor.quantization()->scale()->Get(0);
+ result->params.zero_point =
+ flatbuffer_tensor.quantization()->zero_point()->Get(0);
+ }
+ result->allocation = nullptr;
+ if (flatbuffer_tensor.name()) {
+ result->name = flatbuffer_tensor.name()->c_str();
+ } else {
+ result->name = "<No name>";
+ }
+ result->delegate = nullptr;
+ result->buffer_handle = 0;
+ result->data_is_stale = false;
+ return kTfLiteOk;
+}
+
+uint8_t* SimpleTensorAllocator::AllocateMemory(size_t size) {
+ if ((data_size_ + size) > data_size_max_) {
+ // TODO(petewarden): Add error reporting beyond returning null!
+ return nullptr;
+ }
+ uint8_t* result = data_;
+ data_ += size;
+ data_size_ += size;
+ return result;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h
new file mode 100644
index 0000000000..4f16a9d0e5
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_
+
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/schema/schema_generated.h"
+
+namespace tflite {
+
+// TODO(petewarden): This allocator never frees up or reuses any memory, even
+// though we have enough information about lifetimes of the tensors to do so.
+// This makes it pretty wasteful, so we should use a more intelligent method.
+class SimpleTensorAllocator {
+ public:
+ SimpleTensorAllocator(uint8_t* buffer, int buffer_size)
+ : data_size_(0), data_size_max_(buffer_size), data_(buffer) {}
+
+ TfLiteStatus AllocateTensor(
+ const tflite::Tensor& flatbuffer_tensor, int create_before,
+ int destroy_after,
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
+ ErrorReporter* error_reporter, TfLiteTensor* result);
+
+ uint8_t* AllocateMemory(size_t size);
+
+ int GetDataSize() const { return data_size_; }
+
+ private:
+ int data_size_;
+ int data_size_max_;
+ uint8_t* data_;
+};
+
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_SIMPLE_TENSOR_ALLOCATOR_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc
new file mode 100644
index 0000000000..c835427243
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc
@@ -0,0 +1,144 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_interpreter.h"
+
+#include "tensorflow/contrib/lite/experimental/micro/testing/micro_test.h"
+
+namespace tflite {
+namespace {
+class StackAllocator : public flatbuffers::Allocator {
+ public:
+ StackAllocator() : data_(data_backing_), data_size_(0) {}
+
+ uint8_t* allocate(size_t size) override {
+ if ((data_size_ + size) > kStackAllocatorSize) {
+ // TODO(petewarden): Add error reporting beyond returning null!
+ return nullptr;
+ }
+ uint8_t* result = data_;
+ data_ += size;
+ data_size_ += size;
+ return result;
+ }
+
+ void deallocate(uint8_t* p, size_t) override {}
+
+ static StackAllocator& instance() {
+ // Avoid using true dynamic memory allocation to be portable to bare metal.
+ static char inst_memory[sizeof(StackAllocator)];
+ static StackAllocator* inst = new (inst_memory) StackAllocator;
+ return *inst;
+ }
+
+ static constexpr int kStackAllocatorSize = 4096;
+
+ private:
+ uint8_t data_backing_[kStackAllocatorSize];
+ uint8_t* data_;
+ int data_size_;
+};
+
+flatbuffers::FlatBufferBuilder* BuilderInstance() {
+ static char inst_memory[sizeof(flatbuffers::FlatBufferBuilder)];
+ static flatbuffers::FlatBufferBuilder* inst =
+ new (inst_memory) flatbuffers::FlatBufferBuilder(
+ StackAllocator::kStackAllocatorSize, &StackAllocator::instance());
+ return inst;
+}
+
+const Tensor* Create1dTensor(int size) {
+ using flatbuffers::Offset;
+ flatbuffers::FlatBufferBuilder* builder = BuilderInstance();
+ constexpr size_t tensor_shape_size = 1;
+ const int32_t tensor_shape[tensor_shape_size] = {size};
+ const Offset<Tensor> tensor_offset = CreateTensor(
+ *builder, builder->CreateVector(tensor_shape, tensor_shape_size),
+ TensorType_INT32, 0, builder->CreateString("test_tensor"), 0, false);
+ builder->Finish(tensor_offset);
+ void* tensor_pointer = builder->GetBufferPointer();
+ const Tensor* tensor = flatbuffers::GetRoot<Tensor>(tensor_pointer);
+ return tensor;
+}
+
+const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* CreateBuffers() {
+ using flatbuffers::Offset;
+ flatbuffers::FlatBufferBuilder* builder = BuilderInstance();
+ constexpr size_t buffers_size = 1;
+ const Offset<Buffer> buffers[buffers_size] = {
+ CreateBuffer(*builder),
+ };
+ const flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>
+ buffers_offset = builder->CreateVector(buffers, buffers_size);
+ builder->Finish(buffers_offset);
+ void* buffers_pointer = builder->GetBufferPointer();
+ const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* result =
+ flatbuffers::GetRoot<flatbuffers::Vector<flatbuffers::Offset<Buffer>>>(
+ buffers_pointer);
+ return result;
+}
+
+} // namespace
+} // namespace tflite
+
+TF_LITE_MICRO_TESTS_BEGIN
+
+TF_LITE_MICRO_TEST(TestAllocateTensor) {
+ constexpr size_t arena_size = 1024;
+ uint8_t arena[arena_size];
+ tflite::SimpleTensorAllocator allocator(arena, arena_size);
+
+ const tflite::Tensor* tensor = tflite::Create1dTensor(100);
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>* buffers =
+ tflite::CreateBuffers();
+
+ TfLiteTensor allocated_tensor;
+ TF_LITE_MICRO_EXPECT_EQ(
+ kTfLiteOk,
+ allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter,
+ &allocated_tensor));
+ TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt32, allocated_tensor.type);
+ TF_LITE_MICRO_EXPECT_EQ(1, allocated_tensor.dims->size);
+ TF_LITE_MICRO_EXPECT_EQ(100, allocated_tensor.dims->data[0]);
+ TF_LITE_MICRO_EXPECT_EQ(400, allocated_tensor.bytes);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, allocated_tensor.data.i32);
+}
+
+TF_LITE_MICRO_TEST(TestTooLarge) {
+ constexpr size_t arena_size = 1024;
+ uint8_t arena[arena_size];
+ tflite::SimpleTensorAllocator allocator(arena, arena_size);
+
+ const tflite::Tensor* tensor = tflite::Create1dTensor(10000);
+ const flatbuffers::Vector<flatbuffers::Offset<tflite::Buffer>>* buffers =
+ tflite::CreateBuffers();
+
+ TfLiteTensor allocated_tensor;
+ TF_LITE_MICRO_EXPECT_NE(
+ kTfLiteOk,
+ allocator.AllocateTensor(*tensor, 0, 1, buffers, micro_test::reporter,
+ &allocated_tensor));
+}
+
+TF_LITE_MICRO_TEST(TestJustFits) {
+ constexpr size_t arena_size = 1024;
+ uint8_t arena[arena_size];
+ tflite::SimpleTensorAllocator allocator(arena, arena_size);
+
+ uint8_t* result = allocator.AllocateMemory(arena_size);
+ TF_LITE_MICRO_EXPECT_NE(nullptr, result);
+}
+
+TF_LITE_MICRO_TESTS_END
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/BUILD b/tensorflow/contrib/lite/experimental/micro/testing/BUILD
new file mode 100644
index 0000000000..0d23be5712
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/BUILD
@@ -0,0 +1,17 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["test_linux_binary.sh"])
+
+cc_library(
+ name = "micro_test",
+ hdrs = [
+ "micro_test.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/experimental/micro:micro_framework",
+ ],
+)
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill
new file mode 100644
index 0000000000..7d6d81af0f
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill
@@ -0,0 +1,21 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This docker configuration file lets you emulate a Blue Pill board
+# on an x86 desktop or laptop, which can be useful for debugging and
+# automated testing.
+FROM antmicro/renode:latest
+
+LABEL maintainer="Pete Warden <petewarden@google.com>" \ No newline at end of file
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc
new file mode 100644
index 0000000000..9333dc42bf
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc
@@ -0,0 +1,36 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+using sysbus
+
+mach create
+machine LoadPlatformDescription @platforms/cpus/stm32f103.repl
+
+# These lines are needed to show the results of DebugLog calls in the output.
+machine LoadPlatformDescriptionFromString "uartSemihosting: UART.SemihostingUart @ cpu"
+showAnalyzer cpu.uartSemihosting Antmicro.Renode.Analyzers.LoggingUartAnalyzer
+
+logFile @/tmp/renode_bluepill_log.txt
+
+macro reset
+"""
+ sysbus LoadELF $bin
+"""
+
+runMacro $reset
+
+emulation RunFor @1
+
+quit \ No newline at end of file
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl
new file mode 100644
index 0000000000..916e3eeac3
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.bzl
@@ -0,0 +1,67 @@
+"""Rules for simple testing without dependencies by parsing output logs."""
+
+def tflite_micro_cc_test(
+ name,
+ expected_in_logs = "~~~ALL TESTS PASSED~~~",
+ srcs = [],
+ includes = [],
+ defines = [],
+ copts = [],
+ nocopts = "",
+ linkopts = [],
+ deps = [],
+ tags = [],
+ visibility = None):
+ """Tests a C/C++ binary without testing framework dependencies`.
+
+ Runs a C++ binary, and tests that the output logs contain the
+ expected value. This is a deliberately spartan way of testing, to match
+ what's available when testing microcontroller binaries.
+
+ Args:
+ name: a unique name for this rule.
+ expected_in_logs: A regular expression that is required to be
+ present in the binary's logs for the test to pass.
+ srcs: sources to compile (C, C++, ld scripts).
+ includes: include paths to add to this rule and its dependents.
+ defines: list of `VAR` or `VAR=VAL` to pass to CPP for this rule and
+ its dependents.
+ copts: gcc compilation flags for this rule only.
+ nocopts: list of gcc compilation flags to remove for this rule
+ only. No regexp like for `cc_library`.
+ linkopts: `gcc` flags to add to the linking phase. For "pure" ld flags,
+ prefix them with the `-Wl,` prefix here.
+ deps: dependencies. only `tflite_bare_metal_cc_library()` dependencies
+ allowed.
+ visibility: visibility.
+ """
+ native.cc_binary(
+ name = name + "_binary",
+ srcs = srcs,
+ includes = includes,
+ defines = defines,
+ copts = copts,
+ nocopts = nocopts,
+ linkopts = linkopts,
+ deps = deps,
+ tags = tags,
+ visibility = visibility,
+ )
+ native.sh_test(
+ name = name,
+ size = "medium",
+ srcs = [
+ "//tensorflow/contrib/lite/experimental/micro/testing:test_linux_binary.sh",
+ ],
+ args = [
+ native.package_name() + "/" + name + "_binary",
+ "'" + expected_in_logs + "'",
+ ],
+ data = [
+ name + "_binary",
+ # Internal test dependency placeholder
+ ],
+ deps = [
+ ],
+ tags = tags,
+ )
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h
new file mode 100644
index 0000000000..104509c9dc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/micro_test.h
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An ultra-lightweight testing framework designed for use with microcontroller
+// applications. Its only dependency is on TensorFlow Lite's ErrorReporter
+// interface, where log messages are output. This is designed to be usable even
+// when no standard C or C++ libraries are available, and without any dynamic
+// memory allocation or reliance on global constructors.
+//
+// To build a test, you use syntax similar to gunit, but with some extra
+// decoration to create a hidden 'main' function containing each of the tests to
+// be run. Your code should look something like:
+// ----------------------------------------------------------------------------
+// #include "path/to/this/header"
+//
+// TF_LITE_MICRO_TESTS_BEGIN
+//
+// TF_LITE_MICRO_TEST(SomeTest) {
+// TF_LITE_LOG_EXPECT_EQ(true, true);
+// }
+//
+// TF_LITE_MICRO_TESTS_END
+// ----------------------------------------------------------------------------
+// If you compile this for your platform, you'll get a normal binary that you
+// should be able to run. Executing it will output logging information like this
+// to stderr (or whatever equivalent is available and written to by
+// ErrorReporter):
+// ----------------------------------------------------------------------------
+// Testing SomeTest
+// 1/1 tests passed
+// ~~~ALL TESTS PASSED~~~
+// ----------------------------------------------------------------------------
+// This is designed to be human-readable, so you can just run tests manually,
+// but the string "~~~ALL TESTS PASSED~~~" should only appear if all of the
+// tests do pass. This makes it possible to integrate with automated test
+// systems by scanning the output logs and looking for that magic value.
+//
+// This framework is intended to be a rudimentary alternative to no testing at
+// all on systems that struggle to run more conventional approaches, so use with
+// caution!
+
+#ifndef TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
+#define TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
+
+#include "tensorflow/contrib/lite/experimental/micro/micro_error_reporter.h"
+
+namespace micro_test {
+extern int tests_passed;
+extern int tests_failed;
+extern bool is_test_complete;
+extern bool did_test_fail;
+extern tflite::ErrorReporter* reporter;
+} // namespace micro_test
+
+#define TF_LITE_MICRO_TESTS_BEGIN \
+ namespace micro_test { \
+ int tests_passed; \
+ int tests_failed; \
+ bool is_test_complete; \
+ bool did_test_fail; \
+ tflite::ErrorReporter* reporter; \
+ } \
+ \
+ int main(int argc, char** argv) { \
+ micro_test::tests_passed = 0; \
+ micro_test::tests_failed = 0; \
+ tflite::MicroErrorReporter error_reporter; \
+ micro_test::reporter = &error_reporter;
+
+#define TF_LITE_MICRO_TESTS_END \
+ micro_test::reporter->Report( \
+ "%d/%d tests passed", micro_test::tests_passed, \
+ (micro_test::tests_failed + micro_test::tests_passed)); \
+ if (micro_test::tests_failed == 0) { \
+ micro_test::reporter->Report("~~~ALL TESTS PASSED~~~\n"); \
+ } else { \
+ micro_test::reporter->Report("~~~SOME TESTS FAILED~~~\n"); \
+ } \
+ }
+
+// TODO(petewarden): I'm going to hell for what I'm doing to this poor for loop.
+#define TF_LITE_MICRO_TEST(name) \
+ micro_test::reporter->Report("Testing %s", #name); \
+ for (micro_test::is_test_complete = false, \
+ micro_test::did_test_fail = false; \
+ !micro_test::is_test_complete; micro_test::is_test_complete = true, \
+ micro_test::tests_passed += (micro_test::did_test_fail) ? 0 : 1, \
+ micro_test::tests_failed += (micro_test::did_test_fail) ? 1 : 0)
+
+#define TF_LITE_MICRO_EXPECT(x) \
+ do { \
+ if (!(x)) { \
+ micro_test::reporter->Report(#x " failed at %s:%d", __FILE__, __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_EQ(x, y) \
+ do { \
+ if ((x) != (y)) { \
+ micro_test::reporter->Report(#x " == " #y " failed at %s:%d", __FILE__, \
+ __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_NE(x, y) \
+ do { \
+ if ((x) == (y)) { \
+ micro_test::reporter->Report(#x " != " #y " failed at %s:%d", __FILE__, \
+ __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#define TF_LITE_MICRO_EXPECT_NEAR(x, y, epsilon) \
+ do { \
+ auto delta = ((x) > (y)) ? ((x) - (y)) : ((y) - (x)); \
+ if (delta > epsilon) { \
+ micro_test::reporter->Report(#x " near " #y " failed at %s:%d", \
+ __FILE__, __LINE__); \
+ micro_test::did_test_fail = true; \
+ } \
+ } while (false)
+
+#endif // TENSORFLOW_CONTRIB_LITE_EXPERIMENTAL_MICRO_TESTING_MICRO_TEST_H_
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh
new file mode 100755
index 0000000000..07742a8262
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh
@@ -0,0 +1,54 @@
+#!/bin/bash -e
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests a 'bluepill' STM32F103 ELF by parsing the log output of Renode emulation.
+#
+# First argument is the ELF location.
+# Second argument is a regular expression that's required to be in the output logs
+# for the test to pass.
+
+declare -r ROOT_DIR=`pwd`
+declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/
+declare -r MICRO_LOG_PATH=${TEST_TMPDIR}
+declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
+mkdir -p ${MICRO_LOG_PATH}
+
+docker build -t renode_bluepill \
+ -f ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/Dockerfile.bluepill \
+ ${ROOT_DIR}/tensorflow/contrib/lite/experimental/micro/testing/
+
+docker run \
+ --log-driver=none -a stdout -a stderr \
+ -v ${ROOT_DIR}:/workspace \
+ -v /tmp:/tmp \
+ -it renode_bluepill \
+ /bin/bash -c "renode -P 5000 --disable-xwt -e '
+\$bin?=@/workspace/$1
+s @/workspace/tensorflow/contrib/lite/experimental/micro/testing/bluepill.resc
+' 2>&1 >${MICRO_LOG_FILENAME}"
+
+echo "LOGS:"
+cat ${MICRO_LOG_FILENAME}
+
+if grep -q "$2" ${MICRO_LOG_FILENAME}
+then
+ echo "$1: PASS"
+ exit 0
+else
+ echo "$1: FAIL - '$2' not found in logs."
+ exit 1
+fi
+
diff --git a/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh b/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh
new file mode 100755
index 0000000000..24131a6d2d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh
@@ -0,0 +1,39 @@
+#!/bin/bash -e
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests a Linux binary by parsing the log output.
+#
+# First argument is the binary location.
+# Second argument is a regular expression that's required to be in the output logs
+# for the test to pass.
+
+declare -r ROOT_DIR=`pwd`
+declare -r TEST_TMPDIR=/tmp/test_bluepill_binary/
+declare -r MICRO_LOG_PATH=${TEST_TMPDIR}/$1
+declare -r MICRO_LOG_FILENAME=${MICRO_LOG_PATH}/logs.txt
+mkdir -p ${MICRO_LOG_PATH}
+
+$1 2>&1 | tee ${MICRO_LOG_FILENAME}
+
+if grep -q "$2" ${MICRO_LOG_FILENAME}
+then
+ echo "$1: PASS"
+ exit 0
+else
+ echo "$1: FAIL - '$2' not found in logs."
+ exit 1
+fi
+
diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile b/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile
new file mode 100644
index 0000000000..880bb4763c
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/tools/make/Makefile
@@ -0,0 +1,166 @@
+MAKEFILE_DIR := tensorflow/contrib/lite/experimental/micro/tools/make
+
+# Try to figure out the host system
+HOST_OS :=
+ifeq ($(OS),Windows_NT)
+ HOST_OS = windows
+else
+ UNAME_S := $(shell uname -s)
+ ifeq ($(UNAME_S),Linux)
+ HOST_OS := linux
+ endif
+ ifeq ($(UNAME_S),Darwin)
+ HOST_OS := osx
+ endif
+endif
+
+HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
+
+# Override these on the make command line to target a specific architecture. For example:
+# make -f tensorflow/contrib/lite/Makefile TARGET=rpi TARGET_ARCH=armv7l
+TARGET := $(HOST_OS)
+TARGET_ARCH := $(HOST_ARCH)
+
+INCLUDES := \
+-I. \
+-I$(MAKEFILE_DIR)/../../../../../ \
+-I$(MAKEFILE_DIR)/../../../../../../ \
+-I$(MAKEFILE_DIR)/downloads/ \
+-I$(MAKEFILE_DIR)/downloads/gemmlowp \
+-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
+-I$(OBJDIR)
+# This is at the end so any globally-installed frameworks like protobuf don't
+# override local versions in the source tree.
+INCLUDES += -I/usr/local/include
+
+TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_linux_binary.sh
+
+MICROLITE_LIBS := -lm
+
+# There are no rules for compiling objects for the host system (since we don't
+# generate things like the protobuf compiler that require that), so all of
+# these settings are for the target compiler.
+CXXFLAGS := -O3 -DNDEBUG
+CXXFLAGS += --std=c++11 -g -DTF_LITE_STATIC_MEMORY
+CCFLAGS := -DNDEBUG -g -DTF_LITE_STATIC_MEMORY
+LDOPTS := -L/usr/local/lib
+ARFLAGS := -r
+TARGET_TOOLCHAIN_PREFIX :=
+CC_PREFIX :=
+
+# This library is the main target for this makefile. It will contain a minimal
+# runtime that can be linked in to other programs.
+MICROLITE_LIB_NAME := libtensorflow-microlite.a
+
+# Test binary for the microcontroller speech model.
+MICRO_SPEECH_TEST_SRCS := \
+tensorflow/contrib/lite/experimental/micro/examples/micro_speech/micro_speech_test.cc \
+tensorflow/contrib/lite/experimental/micro/examples/micro_speech/tiny_conv_model_data.cc
+
+MICROLITE_TEST_SRCS := \
+$(wildcard tensorflow/contrib/lite/experimental/micro/*test.cc) \
+$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*test.cc)
+
+MICROLITE_CC_BASE_SRCS := \
+$(wildcard tensorflow/contrib/lite/experimental/micro/*.cc) \
+$(wildcard tensorflow/contrib/lite/experimental/micro/kernels/*.cc) \
+tensorflow/contrib/lite/c/c_api_internal.c \
+tensorflow/contrib/lite/core/api/error_reporter.cc \
+tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc \
+tensorflow/contrib/lite/core/api/op_resolver.cc \
+tensorflow/contrib/lite/kernels/kernel_util.cc \
+tensorflow/contrib/lite/kernels/internal/quantization_util.cc
+MICROLITE_CC_SRCS := $(filter-out $(MICROLITE_TEST_SRCS), $(MICROLITE_CC_BASE_SRCS))
+
+# These target-specific makefiles should modify or replace options like
+# CXXFLAGS or LIBS to work for a specific targetted architecture. All logic
+# based on platforms or architectures should happen within these files, to
+# keep this main makefile focused on the sources and dependencies.
+include $(wildcard $(MAKEFILE_DIR)/targets/*_makefile.inc)
+
+ALL_SRCS := \
+ $(MICRO_SPEECH_TEST_SRCS) \
+ $(MICROLITE_CC_SRCS) \
+ $(MICROLITE_TEST_SRCS)
+
+# Where compiled objects are stored.
+GENDIR := $(MAKEFILE_DIR)/gen/$(TARGET)_$(TARGET_ARCH)/
+OBJDIR := $(GENDIR)obj/
+BINDIR := $(GENDIR)bin/
+LIBDIR := $(GENDIR)lib/
+
+MICROLITE_LIB_PATH := $(LIBDIR)$(MICROLITE_LIB_NAME)
+
+MICRO_SPEECH_TEST_BINARY := $(BINDIR)micro_speech_test
+
+CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++
+CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc
+AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar
+
+MICRO_SPEECH_TEST_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICRO_SPEECH_TEST_SRCS))))
+
+MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \
+$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS))))
+
+MICROLITE_TEST_TARGETS := $(addprefix $(BINDIR), \
+$(patsubst %_test.cc,%.test_target,$(MICROLITE_TEST_SRCS)))
+
+# For normal manually-created TensorFlow C++ source files.
+$(OBJDIR)%.o: %.cc
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
+
+# For normal manually-created TensorFlow C source files.
+$(OBJDIR)%.o: %.c
+ @mkdir -p $(dir $@)
+ $(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
+
+# The target that's compiled if there's no command-line arguments.
+all: $(MICROLITE_LIB_PATH) $(MICRO_SPEECH_TEST_BINARY)
+
+microlite: $(MICROLITE_LIB_PATH)
+
+# Hack for generating schema file bypassing flatbuffer parsing
+tensorflow/contrib/lite/schema/schema_generated.h:
+ @cp -u tensorflow/contrib/lite/schema/schema_generated.h.OPENSOURCE tensorflow/contrib/lite/schema/schema_generated.h
+
+# Gathers together all the objects we've compiled into a single '.a' archive.
+$(MICROLITE_LIB_PATH): tensorflow/contrib/lite/schema/schema_generated.h $(MICROLITE_LIB_OBJS)
+ @mkdir -p $(dir $@)
+ $(AR) $(ARFLAGS) $(MICROLITE_LIB_PATH) $(MICROLITE_LIB_OBJS)
+
+$(MICRO_SPEECH_TEST_BINARY): $(MICRO_SPEECH_TEST_OBJS) $(MICROLITE_LIB_PATH)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $(MICRO_SPEECH_TEST_BINARY) $(MICRO_SPEECH_TEST_OBJS) \
+ $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
+
+micro_speech_test: $(MICRO_SPEECH_TEST_BINARY)
+micro_speech_test_bin: $(MICRO_SPEECH_TEST_BINARY).bin
+
+test_micro_speech: $(MICRO_SPEECH_TEST_BINARY)
+ $(TEST_SCRIPT) $(MICRO_SPEECH_TEST_BINARY) '~~~ALL TESTS PASSED~~~'
+
+$(BINDIR)%_test : $(OBJDIR)%_test.o $(MICROLITE_LIB_PATH)
+ @mkdir -p $(dir $@)
+ $(CXX) $(CXXFLAGS) $(INCLUDES) \
+ -o $@ $< \
+ $(LIBFLAGS) $(MICROLITE_LIB_PATH) $(LDFLAGS) $(MICROLITE_LIBS)
+
+$(BINDIR)%.test_target: $(BINDIR)%_test
+ $(TEST_SCRIPT) $< '~~~ALL TESTS PASSED~~~'
+
+$(info $(MICROLITE_TEST_TARGETS))
+
+test: test_micro_speech $(MICROLITE_TEST_TARGETS)
+
+# Gets rid of all generated files.
+clean:
+ rm -rf $(MAKEFILE_DIR)/gen
+
+$(DEPDIR)/%.d: ;
+.PRECIOUS: $(DEPDIR)/%.d
+.PRECIOUS: $(BINDIR)%_test
+
+-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(ALL_SRCS)))
diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh
new file mode 100755
index 0000000000..4c2ff8545d
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/tools/make/download_dependencies.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+cd "$SCRIPT_DIR/../../../../../../.."
+
+DOWNLOADS_DIR=tensorflow/contrib/lite/experimental/micro/tools/make/downloads
+BZL_FILE_PATH=tensorflow/workspace.bzl
+
+# Ensure it is being run from repo root
+if [ ! -f $BZL_FILE_PATH ]; then
+ echo "Could not find ${BZL_FILE_PATH}":
+ echo "Likely you are not running this from the root directory of the repository.";
+ exit 1;
+fi
+
+GEMMLOWP_URL="https://github.com/google/gemmlowp/archive/719139ce755a0f31cbf1c37f7f98adcc7fc9f425.zip"
+FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/1f5eae5d6a135ff6811724f6c57f911d1f46bb15.tar.gz"
+CMSIS_URL="https://github.com/ARM-software/CMSIS_5/archive/5.4.0.zip"
+STM32_BARE_LIB_URL="https://github.com/google/stm32_bare_lib/archive/50e0da307a2821bb54af1f57b969e6b76cb89d32.zip"
+
+download_and_extract() {
+ local usage="Usage: download_and_extract URL DIR"
+ local url="${1:?${usage}}"
+ local dir="${2:?${usage}}"
+ echo "downloading ${url}" >&2
+ mkdir -p "${dir}"
+ if [[ "${url}" == *gz ]]; then
+ curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz
+ elif [[ "${url}" == *zip ]]; then
+ tempdir=$(mktemp -d)
+ tempdir2=$(mktemp -d)
+
+ curl -L ${url} > ${tempdir}/zipped.zip
+ unzip ${tempdir}/zipped.zip -d ${tempdir2}
+
+ # If the zip file contains nested directories, extract the files from the
+ # inner directory.
+ if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then
+ # unzip has no strip components, so unzip to a temp dir, and move the
+ # files we want from the tempdir to destination.
+ cp -R ${tempdir2}/*/* ${dir}/
+ else
+ cp -R ${tempdir2}/* ${dir}/
+ fi
+ rm -rf ${tempdir2} ${tempdir}
+ fi
+
+ # Delete any potential BUILD files, which would interfere with Bazel builds.
+ find "${dir}" -type f -name '*BUILD' -delete
+}
+
+download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp"
+download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers"
+download_and_extract "${CMSIS_URL}" "${DOWNLOADS_DIR}/cmsis"
+download_and_extract "${STM32_BARE_LIB_URL}" "${DOWNLOADS_DIR}/stm32_bare_lib"
+
+echo "download_dependencies.sh completed successfully." >&2
diff --git a/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc b/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc
new file mode 100644
index 0000000000..022a8422dc
--- /dev/null
+++ b/tensorflow/contrib/lite/experimental/micro/tools/make/targets/bluepill_makefile.inc
@@ -0,0 +1,65 @@
+# Settings for Blue Pill platforms.
+ifeq ($(TARGET), bluepill)
+ TARGET_ARCH := cortex-m3
+ TARGET_TOOLCHAIN_PREFIX := arm-none-eabi-
+
+ PLATFORM_FLAGS = \
+ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+ -DTF_LITE_STATIC_MEMORY \
+ -DTF_LITE_MCU_DEBUG_LOG \
+ -fno-rtti \
+ -fmessage-length=0 \
+ -fno-exceptions \
+ -fno-unwind-tables \
+ -fno-builtin \
+ -ffunction-sections \
+ -fdata-sections \
+ -funsigned-char \
+ -MMD \
+ -mcpu=cortex-m3 \
+ -mthumb \
+ -std=gnu++11 \
+ -Wvla \
+ -Wall \
+ -Wextra \
+ -Wno-unused-parameter \
+ -Wno-missing-field-initializers \
+ -Wno-write-strings \
+ -Wno-sign-compare \
+ -fno-delete-null-pointer-checks \
+ -fomit-frame-pointer \
+ -fpermissive \
+ -nostdlib \
+ -g \
+ -Os
+ CXXFLAGS += $(PLATFORM_FLAGS)
+ CCFLAGS += $(PLATFORM_FLAGS)
+ LDFLAGS += \
+ -T $(MAKEFILE_DIR)/downloads/stm32_bare_lib/stm32_linker_layout.lds \
+ -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref \
+ -Wl,--gc-sections
+ BUILD_TYPE := micro
+ MICROLITE_LIBS := \
+ -lm
+ INCLUDES += \
+ -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \
+ -I$(MAKEFILE_DIR)/downloads/stm32_bare_lib/include
+ MICROLITE_CC_SRCS += \
+ $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.c) \
+ $(wildcard $(MAKEFILE_DIR)/downloads/stm32_bare_lib/source/*.cc)
+ TEST_SCRIPT := tensorflow/contrib/lite/experimental/micro/testing/test_bluepill_binary.sh
+ # These are tests that don't currently work on the blue pill.
+ EXCLUDED_TESTS := \
+ tensorflow/contrib/lite/experimental/micro/micro_interpreter_test.cc \
+ tensorflow/contrib/lite/experimental/micro/simple_tensor_allocator_test.cc
+ MICROLITE_TEST_SRCS := $(filter-out $(EXCLUDED_TESTS), $(MICROLITE_TEST_SRCS))
+
+# These are microcontroller-specific rules for converting the ELF output
+# of the linker into a binary image that can be loaded directly.
+OBJCOPY := $(TARGET_TOOLCHAIN_PREFIX)objcopy
+
+$(BINDIR)/%.bin: $(BINDIR)/%
+ @mkdir -p $(dir $@)
+ $(OBJCOPY) $< $@ -O binary
+
+endif \ No newline at end of file
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
index b0f32a8d6c..2eb776d10c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/android_build.md
@@ -1,6 +1,22 @@
-
# Building TensorFlow on Android
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
To get you started working with TensorFlow on Android, we'll walk through two
ways to build our TensorFlow mobile demos and deploying them on an Android
device. The first is Android Studio, which lets you build and deploy in an
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/index.md b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
index 49ad35d4e6..15f0fd3961 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/index.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/index.md
@@ -1,6 +1,22 @@
-
# Overview
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
TensorFlow was designed to be a good deep learning solution for mobile
platforms. Currently we have two solutions for deploying machine learning
applications on mobile and embedded devices: TensorFlow for Mobile and
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
index be8b4100c8..d922907cdc 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/ios_build.md
@@ -1,6 +1,22 @@
-
# Building TensorFlow on iOS
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
## Using CocoaPods
The simplest way to get started with TensorFlow on iOS is using the CocoaPods
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
index 4d4bb3bc08..fd0e322c93 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/linking_libs.md
@@ -1,6 +1,22 @@
-
# Integrating TensorFlow libraries
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
Once you have made some progress on a model that addresses the problem you’re
trying to solve, it’s important to test it out inside your application
immediately. There are often unexpected differences between your training data
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
index 7436594fd8..59ff8e774c 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/optimizing.md
@@ -1,6 +1,22 @@
-
# Optimizing for mobile
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
There are some special issues that you have to deal with when you’re trying to
ship on mobile or embedded devices, and you’ll need to think about these as
you’re developing your model.
diff --git a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
index d1c67d4c61..1d373251dd 100644
--- a/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
+++ b/tensorflow/contrib/lite/g3doc/tfmobile/prepare_models.md
@@ -1,6 +1,22 @@
-
# Preparing models for mobile deployment
+Warning: We expect to deprecate TensorFlow Mobile in early 2019
+
+<div class="caution">
+ <p>
+ <a href="../">TensorFlow Lite</a> is our main mobile and embedded offering. We are
+ working hard to close the feature gap between TensorFlow Mobile and
+ TensorFlow Lite. We expect to deprecate TensorFlow Mobile in early 2019. We
+ will give ample notice to our users when we get to that point and will
+ provide help and support to ensure easy migrations.
+ </p>
+ <p>
+ In the meantime, please use TensorFlow Lite. If you have a feature request,
+ such as a missing op, please post to our <a
+ href="https://github.com/tensorflow/tensorflow/issues">GitHub</a>.
+ </p>
+</div>
+
The requirements for storing model information during training are very
different from when you want to release it as part of a mobile app. This section
covers the tools involved in converting from a training model to something
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index 7ef736d01b..651a97e9dc 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -349,6 +349,10 @@ class Interpreter {
return context_.allow_fp32_relax_to_fp16;
}
+ // Owning handle to a TfLiteDelegate instance.
+ using TfLiteDelegatePtr =
+ std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
+
// Allow a delegate to look at the graph and modify the graph to handle
// parts of the graph themselves. After this is called, the graph may
// contain new nodes that replace 1 more nodes.
@@ -574,19 +578,11 @@ class Interpreter {
TfLiteExternalContextType type,
TfLiteExternalContext* ctx);
- using TfLiteDelegatePtr =
- std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
-
// Variant of the public ModifyGraphWithDelegate method that additionally
// Assumes ownership of the provided delegate.
// WARNING: This is an experimental API and subject to change.
- template <typename Delegate>
- TfLiteStatus ModifyGraphWithDelegate(std::unique_ptr<Delegate> typed_delegate,
+ TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegatePtr delegate,
bool allow_dynamic_tensors = false) {
- TfLiteDelegatePtr delegate(typed_delegate.release(),
- [](TfLiteDelegate* delegate) {
- delete static_cast<Delegate*>(delegate);
- });
// Note that we retain ownership of the delegate even if graph modification
// fails, as delegate use will be in an indeterminate state at that point.
owned_delegates_.push_back(std::move(delegate));
@@ -676,6 +672,7 @@ class Interpreter {
// List of delegates that have been installed and are owned by this
// interpreter instance. Useful if client delegate ownership is burdensome.
// WARNING: This is an experimental API and subject to change.
+ // TODO(b/116667551): Use TfLiteExternalContext for storing state.
std::vector<TfLiteDelegatePtr> owned_delegates_;
std::unique_ptr<MemoryPlanner> memory_planner_;
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index cdede430e2..6c71d5a8d7 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -30,7 +30,11 @@ class InterpreterTest : public ::testing::Test {
template <typename Delegate>
static TfLiteStatus ModifyGraphWithDelegate(
Interpreter* interpreter, std::unique_ptr<Delegate> delegate) {
- return interpreter->ModifyGraphWithDelegate(std::move(delegate));
+ Interpreter::TfLiteDelegatePtr tflite_delegate(
+ delegate.release(), [](TfLiteDelegate* delegate) {
+ delete reinterpret_cast<Delegate*>(delegate);
+ });
+ return interpreter->ModifyGraphWithDelegate(std::move(tflite_delegate));
}
protected:
diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD
index 098ba7e773..e68cd26f81 100644
--- a/tensorflow/contrib/lite/java/BUILD
+++ b/tensorflow/contrib/lite/java/BUILD
@@ -11,6 +11,10 @@ load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary")
load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni")
+JAVA_SRCS = glob([
+ "src/main/java/org/tensorflow/lite/*.java",
+])
+
# Building tensorflow-lite.aar including 4 variants of .so
# To build an aar for release, run below command:
# bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
@@ -20,28 +24,38 @@ aar_with_jni(
android_library = ":tensorflowlite",
)
+# EXPERIMENTAL: AAR target that supports TensorFlow op execution with TFLite.
+aar_with_jni(
+ name = "tensorflow-lite-flex",
+ android_library = ":tensorflowlite_flex",
+)
+
android_library(
name = "tensorflowlite",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
+ manifest = "AndroidManifest.xml",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":tensorflowlite_native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Android target that supports TensorFlow op execution with TFLite.
+android_library(
+ name = "tensorflowlite_flex",
+ srcs = JAVA_SRCS,
manifest = "AndroidManifest.xml",
visibility = ["//visibility:public"],
deps = [
- ":tflite_runtime",
+ ":tensorflowlite_native_flex",
"@org_checkerframework_qual",
],
)
android_library(
name = "tensorflowlite_java",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
visibility = ["//visibility:public"],
deps = [
"@org_checkerframework_qual",
@@ -50,16 +64,23 @@ android_library(
java_library(
name = "tensorflowlitelib",
- srcs = glob(
- [
- "src/main/java/org/tensorflow/lite/*.java",
- ],
- ),
+ srcs = JAVA_SRCS,
javacopts = JAVACOPTS,
visibility = ["//visibility:public"],
deps = [
":libtensorflowlite_jni.so",
- "//tensorflow/contrib/lite/java/src/main/native",
+ "@org_checkerframework_qual",
+ ],
+)
+
+# EXPERIMENTAL: Java target that supports TensorFlow op execution with TFLite.
+java_library(
+ name = "tensorflowlitelib_flex",
+ srcs = JAVA_SRCS,
+ javacopts = JAVACOPTS,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":libtensorflowlite_flex_jni.so",
"@org_checkerframework_qual",
],
)
@@ -72,7 +93,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.TensorFlowLiteTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -87,7 +107,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.DataTypeTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -110,7 +129,6 @@ java_test(
tags = ["no_oss"],
test_class = "org.tensorflow.lite.NativeInterpreterWrapperTest",
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -125,13 +143,13 @@ java_test(
data = [
"src/testdata/add.bin",
"src/testdata/mobilenet.tflite.bin",
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
],
javacopts = JAVACOPTS,
tags = ["no_oss"],
test_class = "org.tensorflow.lite.InterpreterTest",
visibility = ["//visibility:private"],
deps = [
- ":libtensorflowlite_jni.so",
":tensorflowlitelib",
"@com_google_truth",
"@junit",
@@ -139,6 +157,24 @@ java_test(
)
java_test(
+ name = "InterpreterFlexTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/lite/InterpreterFlexTest.java"],
+ data = [
+ "//tensorflow/contrib/lite:testdata/multi_add_flex.bin",
+ ],
+ javacopts = JAVACOPTS,
+ tags = ["no_oss"],
+ test_class = "org.tensorflow.lite.InterpreterFlexTest",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":tensorflowlitelib_flex",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+java_test(
name = "TensorTest",
size = "small",
srcs = ["src/test/java/org/tensorflow/lite/TensorTest.java"],
@@ -164,14 +200,29 @@ filegroup(
)
cc_library(
- name = "tflite_runtime",
+ name = "tensorflowlite_native",
srcs = ["libtensorflowlite_jni.so"],
visibility = ["//visibility:public"],
)
+cc_library(
+ name = "tensorflowlite_native_flex",
+ srcs = ["libtensorflowlite_flex_jni.so"],
+ visibility = ["//visibility:public"],
+)
+
tflite_jni_binary(
name = "libtensorflowlite_jni.so",
deps = [
"//tensorflow/contrib/lite/java/src/main/native",
],
)
+
+# EXPERIMENTAL: Native target that supports TensorFlow op execution with TFLite.
+tflite_jni_binary(
+ name = "libtensorflowlite_flex_jni.so",
+ deps = [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index 9d2aead266..360d622b1b 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -30,7 +30,10 @@ EOF
# In some platforms we don't have an Android SDK/NDK and this target
# can't be built. We need to prevent the build system from trying to
# use the target in that case.
- tags = ["manual"],
+ tags = [
+ "manual",
+ "no_cuda_on_cpu_tap",
+ ],
)
native.genrule(
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
index 711638a9f9..d5447b3bf8 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/TensorFlowLite.java
@@ -18,7 +18,8 @@ package org.tensorflow.lite;
/** Static utility methods loading the TensorFlowLite runtime. */
public final class TensorFlowLite {
- private static final String LIBNAME = "tensorflowlite_jni";
+ private static final String PRIMARY_LIBNAME = "tensorflowlite_jni";
+ private static final String FALLBACK_LIBNAME = "tensorflowlite_flex_jni";
private TensorFlowLite() {}
@@ -29,13 +30,24 @@ public final class TensorFlowLite {
* Load the TensorFlowLite runtime C library.
*/
static boolean init() {
+ Throwable primaryLibException;
try {
- System.loadLibrary(LIBNAME);
+ System.loadLibrary(PRIMARY_LIBNAME);
return true;
} catch (UnsatisfiedLinkError e) {
- System.err.println("TensorFlowLite: failed to load native library: " + e.getMessage());
- return false;
+ primaryLibException = e;
}
+
+ try {
+ System.loadLibrary(FALLBACK_LIBNAME);
+ return true;
+ } catch (UnsatisfiedLinkError e) {
+ // If the fallback fails, log the error for the primary load instead.
+ System.err.println(
+ "TensorFlowLite: failed to load native library: " + primaryLibException.getMessage());
+ }
+
+ return false;
}
static {
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
new file mode 100644
index 0000000000..2791c3864b
--- /dev/null
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterFlexTest.java
@@ -0,0 +1,46 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.lite;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.io.File;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link org.tensorflow.lite.Interpreter} that validate execution with models that
+ * have TensorFlow ops.
+ */
+@RunWith(JUnit4.class)
+public final class InterpreterFlexTest {
+
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
+ /** Smoke test validating that flex model loading works when the flex delegate is linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try (Interpreter interpreter = new Interpreter(FLEX_MODEL_FILE)) {
+ assertThat(interpreter.getInputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getInputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ assertThat(interpreter.getOutputTensorCount()).isEqualTo(4);
+ assertThat(interpreter.getOutputTensor(0).dataType()).isEqualTo(DataType.FLOAT32);
+ interpreter.run(new float[1], new float[1]);
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
index a98fca0132..f8b73c7cf3 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/InterpreterTest.java
@@ -43,6 +43,9 @@ public final class InterpreterTest {
private static final File MOBILENET_MODEL_FILE =
new File("tensorflow/contrib/lite/java/src/testdata/mobilenet.tflite.bin");
+ private static final File FLEX_MODEL_FILE =
+ new File("tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+
@Test
public void testInterpreter() throws Exception {
Interpreter interpreter = new Interpreter(MODEL_FILE);
@@ -345,4 +348,15 @@ public final class InterpreterTest {
interpreter.close();
interpreter.close();
}
+
+ /** Smoke test validating that flex model loading fails when the flex delegate is not linked. */
+ @Test
+ public void testFlexModel() throws Exception {
+ try {
+ new Interpreter(FLEX_MODEL_FILE);
+ fail();
+ } catch (IllegalStateException e) {
+ // Expected failure.
+ }
+ }
}
diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD
index daaf6714cc..d2d8073abd 100644
--- a/tensorflow/contrib/lite/kernels/BUILD
+++ b/tensorflow/contrib/lite/kernels/BUILD
@@ -210,6 +210,7 @@ cc_library(
"slice.cc",
"space_to_batch_nd.cc",
"space_to_depth.cc",
+ "sparse_output_fully_connected.cc",
"sparse_to_dense.cc",
"split.cc",
"squeeze.cc",
@@ -233,11 +234,11 @@ cc_library(
":activation_functor",
":eigen_support",
":kernel_util",
+ ":lstm_eval",
":op_macros",
":padding",
"//tensorflow/contrib/lite:framework",
"//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite:util",
"//tensorflow/contrib/lite/c:c_api_internal",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/kernels/internal:audio_utils",
@@ -254,6 +255,18 @@ cc_library(
)
cc_library(
+ name = "lstm_eval",
+ srcs = ["lstm_eval.cc"],
+ hdrs = ["lstm_eval.h"],
+ deps = [
+ ":op_macros",
+ "//tensorflow/contrib/lite/c:c_api_internal",
+ "//tensorflow/contrib/lite/kernels/internal:kernel_utils",
+ "//tensorflow/contrib/lite/kernels/internal:tensor_utils",
+ ],
+)
+
+cc_library(
name = "builtin_ops",
srcs = ["register.cc"],
hdrs = ["register.h"],
@@ -334,6 +347,23 @@ tf_cc_test(
)
tf_cc_test(
+ name = "sparse_output_fully_connected_test",
+ size = "small",
+ srcs = ["sparse_output_fully_connected_test.cc"],
+ tags = [
+ "no_oss",
+ "tflite_not_portable_ios",
+ ],
+ deps = [
+ ":builtin_ops",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 0532528f52..a326827b1e 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -694,330 +695,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
- const TfLiteTensor* aux_input_to_input_weights,
- const TfLiteTensor* aux_input_to_forget_weights,
- const TfLiteTensor* aux_input_to_cell_weights,
- const TfLiteTensor* aux_input_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
- TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
-
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- // Index the scratch buffers pointers to the global scratch buffer.
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- float* aux_input_ptr = nullptr;
- float* aux_input_to_input_weights_ptr = nullptr;
- float* aux_input_to_forget_weights_ptr = nullptr;
- float* aux_input_to_cell_weights_ptr = nullptr;
- float* aux_input_to_output_weights_ptr = nullptr;
- if (aux_input_size > 0) {
- aux_input_ptr = aux_input->data.f;
- aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
- aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
- aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
- aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
- }
-
- // Loop through the sequence.
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * output->dims->data[2];
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr_time =
- output->data.f + t_rel * output_step + output_offset;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
- input_to_cell_weights->data.f, input_to_output_weights->data.f,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
- aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
- recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
- output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, aux_input_size, n_output,
- activation_state->data.f, cell_state->data.f, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- output_ptr_time);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
- const TfLiteTensor* aux_input_to_input_weights,
- const TfLiteTensor* aux_input_to_forget_weights,
- const TfLiteTensor* aux_input_to_cell_weights,
- const TfLiteTensor* aux_input_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
- TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
- TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
- TfLiteTensor* output_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* output_state_ptr = output_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_aux_input_ptr =
- (aux_input_quantized == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
- int8_t* quantized_output_state_ptr =
- reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Auxiliary input and weights.
- float* aux_input_ptr = nullptr;
- int8_t* aux_input_to_input_weights_ptr = nullptr;
- int8_t* aux_input_to_forget_weights_ptr = nullptr;
- int8_t* aux_input_to_cell_weights_ptr = nullptr;
- int8_t* aux_input_to_output_weights_ptr = nullptr;
- float aux_input_to_input_weights_scale = 0.0f;
- float aux_input_to_forget_weights_scale = 0.0f;
- float aux_input_to_cell_weights_scale = 0.0f;
- float aux_input_to_output_weights_scale = 0.0f;
- if (aux_input_size > 0) {
- aux_input_ptr = aux_input->data.f;
- aux_input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
- aux_input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
- aux_input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
- aux_input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
- aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
- aux_input_to_forget_weights_scale =
- aux_input_to_forget_weights->params.scale;
- aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
- aux_input_to_output_weights_scale =
- aux_input_to_output_weights->params.scale;
- }
-
- // Feed the sequence into the LSTM step-by-step.
- const int input_step = n_batch * n_input;
- const int output_step = n_batch * output->dims->data[2];
- for (int t = 0; t < max_time; t++) {
- // If this is the forward_sequence, step forward, otherwise step backwards.
- const int t_rel = forward_sequence ? t : max_time - t - 1;
- const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr = output->data.f + t_rel * output_step + output_offset;
-
- kernel_utils::LstmStepWithAuxInput(
- input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- aux_input_ptr, aux_input_to_input_weights_ptr,
- aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
- aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
- aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
- aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, aux_input_size, n_output, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch,
- scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_aux_input_ptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
- }
-
- return kTfLiteOk;
-}
-
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
@@ -1157,7 +834,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
switch (fw_input_to_output_weights->type) {
case kTfLiteFloat32: {
- TfLiteStatus fw_pass_status = EvalFloat(
+ TfLiteStatus fw_pass_status = lstm_eval::EvalFloat(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1172,7 +849,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
- TfLiteStatus bw_pass_status = EvalFloat(
+ TfLiteStatus bw_pass_status = lstm_eval::EvalFloat(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
@@ -1208,7 +885,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, kRecoveredCellWeights);
- TfLiteStatus fw_pass_status = EvalHybrid(
+ TfLiteStatus fw_pass_status = lstm_eval::EvalHybrid(
input, fw_input_to_input_weights, fw_input_to_forget_weights,
fw_input_to_cell_weights, fw_input_to_output_weights,
fw_recurrent_to_input_weights, fw_recurrent_to_forget_weights,
@@ -1226,7 +903,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
- TfLiteStatus bw_pass_status = EvalHybrid(
+ TfLiteStatus bw_pass_status = lstm_eval::EvalHybrid(
input, bw_input_to_input_weights, bw_input_to_forget_weights,
bw_input_to_cell_weights, bw_input_to_output_weights,
bw_recurrent_to_input_weights, bw_recurrent_to_forget_weights,
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 9f62ac3f2c..c22a457a71 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -113,6 +113,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// input configuration.
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index f765235e04..3926af5b97 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
if (input1->type == kTfLiteUInt8) { \
auto input1_offset = -input1->params.zero_point; \
auto input2_offset = -input2->params.zero_point; \
- const int left_shift = 20; \
- const double twice_max_input_scale = \
- 2 * std::max(input1->params.scale, input2->params.scale); \
- const double real_input1_multiplier = \
- input1->params.scale / twice_max_input_scale; \
- const double real_input2_multiplier = \
- input2->params.scale / twice_max_input_scale; \
+ const int left_shift = 8; \
\
int32 input1_multiplier; \
int input1_shift; \
- QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
+ QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
&input1_multiplier, &input1_shift); \
int32 input2_multiplier; \
int input2_shift; \
- QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
+ QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
&input2_multiplier, &input2_shift); \
\
ComparisonParams op_params; \
op_params.left_shift = left_shift; \
op_params.input1_offset = input1_offset; \
op_params.input1_multiplier = input1_multiplier; \
- op_params.input1_shift = -input1_shift; \
+ op_params.input1_shift = input1_shift; \
op_params.input2_offset = input2_offset; \
op_params.input2_multiplier = input2_multiplier; \
- op_params.input2_shift = -input2_shift; \
+ op_params.input2_shift = input2_shift; \
if (requires_broadcast) { \
reference_ops::Broadcast4DSlow##opname##WithScaling( \
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 67a91c17fd..04c8bf2e30 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
+TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
+ TensorType_UINT8, BuiltinOperator_GREATER);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+}
+
TEST(ComparisonsTest, GreaterEqualQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;
diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD
index afb5ec05df..5c9ca6e910 100644
--- a/tensorflow/contrib/lite/kernels/internal/BUILD
+++ b/tensorflow/contrib/lite/kernels/internal/BUILD
@@ -49,6 +49,20 @@ cc_library(
],
)
+cc_library(
+ name = "legacy_types",
+ srcs = [],
+ hdrs = [
+ "compatibility.h",
+ "legacy_types.h",
+ "types.h",
+ ],
+ deps = [
+ "//tensorflow/contrib/lite/kernels:op_macros",
+ "@com_google_absl//absl/base:core_headers",
+ ],
+)
+
config_setting(
name = "arm",
values = {
@@ -198,6 +212,7 @@ cc_library(
":strided_slice_logic",
":tensor_utils",
":types",
+ ":legacy_types",
":legacy_reference_base",
":round",
"//third_party/eigen3",
@@ -336,6 +351,7 @@ cc_library(
":quantization_util",
":round",
":strided_slice_logic",
+ ":legacy_types",
":types",
"@gemmlowp",
"//tensorflow/contrib/lite/c:c_api_internal",
diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h
index b87cf2b60d..7c176e0fa1 100644
--- a/tensorflow/contrib/lite/kernels/internal/compatibility.h
+++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h
@@ -84,4 +84,27 @@ using uint16 = std::uint16_t;
using int32 = std::int32_t;
using uint32 = std::uint32_t;
+// TFLITE_DEPRECATED()
+//
+// Duplicated from absl/base/macros.h to avoid pulling in that library.
+// Marks a deprecated class, struct, enum, function, method and variable
+// declarations. The macro argument is used as a custom diagnostic message (e.g.
+// suggestion of a better alternative).
+//
+// Example:
+//
+// class TFLITE_DEPRECATED("Use Bar instead") Foo {...};
+// TFLITE_DEPRECATED("Use Baz instead") void Bar() {...}
+//
+// Every usage of a deprecated entity will trigger a warning when compiled with
+// clang's `-Wdeprecated-declarations` option. This option is turned off by
+// default, but the warnings will be reported by clang-tidy.
+#if defined(__clang__) && __cplusplus >= 201103L
+#define TFLITE_DEPRECATED(message) __attribute__((deprecated(message)))
+#endif
+
+#ifndef TFLITE_DEPRECATED
+#define TFLITE_DEPRECATED(message)
+#endif
+
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
index 56e9367878..083e5839bd 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc
@@ -169,603 +169,5 @@ void RnnBatchStep(
hidden_state_ptr_batch);
}
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, /*n_aux_input=*/0,
- n_output, output_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-}
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // If auxiliary input is available then compute aux_input_weight * aux_input
- if (aux_input_ptr_batch != nullptr) {
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, input_gate_scratch,
- /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_aux_input,
- aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
- }
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_batch, /*result_stride=*/1);
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch) {
- LstmStepWithAuxInput(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- /*aux_input_ptr_batch=*/nullptr,
- /*aux_input_to_input_weights_ptr=*/nullptr,
- /*aux_input_to_input_weights_scale=*/0.0f,
- /*aux_input_to_forget_weights_ptr=*/nullptr,
- /*aux_input_to_forget_weights_scale=*/0.0f,
- /*aux_input_to_cell_weights_ptr=*/nullptr,
- /*aux_input_to_cell_weights_scale=*/0.0f,
- /*aux_input_to_output_weights_ptr=*/nullptr,
- /*aux_input_to_output_weights_scale=*/0.0f,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input,
- /*n_aux_input=*/0, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors,
- product_scaling_factors, recovered_cell_weights,
- quantized_input_ptr_batch,
- /*quantized_aux_input_ptr_batch=*/nullptr, quantized_output_state_ptr,
- quantized_cell_state_ptr, output_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
-
- void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr,
- float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr,
- float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch,
- int8_t* quantized_output_state_ptr, int8_t* quantized_cell_state_ptr,
- float* output_state_ptr, float* cell_state_ptr,
- float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we
- // can check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
- // Initialize scratch buffers with bias.
- if (!use_cifg) {
- tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell,
- n_batch, input_gate_scratch);
- }
- tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell,
- n_batch, output_gate_scratch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input,
- quantized_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (aux_input_ptr_batch != nullptr &&
- !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- aux_input_ptr_batch + offset, n_input,
- quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_input_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_forget_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_cell_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * aux_input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- aux_input_to_output_weights_ptr, n_cell, n_input,
- quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(
- output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Save quantization and matmul computation for all zero input.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
- cell_state_ptr, n_batch * n_cell,
- cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- params->activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell,
- cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (params->cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
- params->cell_clip, cell_state_ptr);
- }
-
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_cell_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- params->activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell,
- output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch,
- n_batch * n_cell)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell,
- quantized_cell_state_ptr, product_scaling_factors, n_batch,
- output_ptr_batch,
- /*result_stride=*/1);
- }
- if (params->proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
- params->proj_clip, output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
- }
-
} // namespace kernel_utils
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
index b5558cce55..74e0a4a53d 100644
--- a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
+++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h
@@ -76,190 +76,6 @@ void RnnBatchStep(
int8_t* quantized_hidden_state_ptr_batch, float* scaling_factors,
float* hidden_state_ptr_batch, float* output_ptr_batch);
-// Performs an LSTM batch inference step for input specified by input_ptr_batch.
-// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
-// biases (*_bias_ptr), and buffers (*_scratch), along with additional
-// parameters:
-// - params: various LSTM params including activation, clipping, etc.,
-// - n_batch: size of batch,
-// - n_cell: number of cells (or units),
-// - n_input: the input size,
-// - n_output: the output size.
-//
-// The pointers to the cell and output state and the output are updated.
-//
-// The pointers with the suffix "_batch" point to data aligned in batch_major
-// order, and each step processes batch_size many inputs from input_ptr_batch,
-// and updates batch_size many cell and output states.
-void LstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_output, float* output_state_ptr,
- float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but includes an auxiliary input with the corresponding weights.
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
- const float* aux_input_to_input_weights_ptr,
- const float* aux_input_to_forget_weights_ptr,
- const float* aux_input_to_cell_weights_ptr,
- const float* aux_input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, const TfLiteLSTMParams* params,
- int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
- float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* output_ptr_batch);
-
-// Same as above but with quantized weight matrices. In detail:
-// Input of size 'n_batch * n_input':
-// input_ptr_batch
-//
-// LSTM weights:
-// Quantized input weights of size 'n_cell * n_input':
-// input_to_input_weights - optional (can be nullptr)
-// input_to_forget_weights
-// input_to_cell_weights
-// input_to_input_weights
-// Quantized recurrent weights of size 'n_cell * n_output':
-// recurrent_to_input_weights - optional
-// recurrent_to_forget_weights
-// recurrent_to_cell_weights
-// recurrent_to_input_weights
-// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
-// cell_to_input_weights - optional
-// cell_to_cell_weights - optional
-// cell_to_output_weights - optional
-// Quantized projection weights of size 'n_output * n_cell'
-// projection_weights_ptr - optional
-// Weight scales (scalars) for each of the weights above.
-// input_to_input_weights_scale - optional
-// input_to_forget_weights_scale
-// input_to_cell_weights_scale
-// input_to_output_weights_scale
-// recurrent_to_input_weights_scale - optional
-// recurrent_to_forget_weights_scale
-// recurrent_to_cell_weights_scale
-// recurrent_to_output_weights_scale
-// cell_to_input_weights_scale,
-// cell_to_forget_weights_scale,
-// cell_to_output_weights_scale,
-// projection_weights_scale - optional
-// Gate biases of size 'n_cell':
-// input_gate_bias_ptr - optional
-// forget_gate_bias_ptr
-// cell_gate_bias_ptr
-// output_gate_bias_ptr
-//
-// Temporary pre-allocated storage for quantized values:
-// quantized_input_ptr_batch (same size as input_ptr_batch)
-// quantized_output_state_ptr (same size as output_state_ptr)
-// quantized_cell_state_ptr (same size as cell_state_ptr)
-// Temporary pre-allocated storage for recovered values:
-// recovered_cell_weights (same size as cell_to_*_weights)
-//
-// Outputs:
-// output_state_ptr - size 'n_batch * n_output'
-// cell_state_ptr - size 'n_batch * n_cell'
-// output_ptr_batch - size 'n_batch * n_output'
-void LstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_output, float* input_gate_scratch, float* forget_gate_scratch,
- float* cell_scratch, float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_cell_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
-void LstmStepWithAuxInput(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale, const float* aux_input_ptr_batch,
- const int8_t* aux_input_to_input_weights_ptr,
- float aux_input_to_input_weights_scale,
- const int8_t* aux_input_to_forget_weights_ptr,
- float aux_input_to_forget_weights_scale,
- const int8_t* aux_input_to_cell_weights_ptr,
- float aux_input_to_cell_weights_scale,
- const int8_t* aux_input_to_output_weights_ptr,
- float aux_input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
- int n_aux_input, int n_output, float* input_gate_scratch,
- float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
- float* scaling_factors, float* product_scaling_factors,
- float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
- int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch);
-
} // namespace kernel_utils
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_options.cc b/tensorflow/contrib/lite/kernels/internal/legacy_types.h
index 35b4b4e20b..2e4d3137f5 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_options.cc
+++ b/tensorflow/contrib/lite/kernels/internal/legacy_types.h
@@ -12,17 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_
-#include "tensorflow/compiler/xla/service/gpu/gpu_options.h"
-#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/contrib/lite/kernels/internal/types.h"
-namespace xla {
-namespace gpu {
+namespace tflite {
-bool ConvUseLayoutHeuristic(const HloModuleConfig& config) {
- return !config.debug_options().xla_backend_extra_options().count(
- "xla_gpu_experimental_conv_disable_layout_heuristic");
-}
+// TODO(b/116772710): Insert legacy Dims<> code in here.
-} // namespace gpu
-} // namespace xla
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_LEGACY_TYPES_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
index 14281f25c6..25ea72b886 100644
--- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
+++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc
@@ -259,7 +259,7 @@ TEST(QuantizationUtilTest, IntegerFrExpVersusDouble) {
EXPECT_EQ(double_shift, 1);
result = IntegerFrExp(123.45, &shift);
- EXPECT_NEAR(result, (0.964453 * (1L << 31)), 1000);
+ EXPECT_NEAR(result, (0.964453 * (1LL << 31)), 1000);
EXPECT_EQ(shift, 7);
double_result = std::frexp(123.45, &double_shift);
EXPECT_NEAR(double_result, 0.964453, 1e-5);
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
index be99240b1f..c8b64cfd96 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/legacy_reference_ops.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <sys/types.h>
#include "tensorflow/contrib/lite/kernels/internal/common.h"
+#include "tensorflow/contrib/lite/kernels/internal/legacy_types.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
-#include "tensorflow/contrib/lite/kernels/internal/types.h"
namespace tflite {
@@ -30,6 +30,11 @@ namespace reference_ops {
static constexpr int kDepthwiseReverseShift = -1;
+inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
+ shape->BuildFrom(
+ {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
+}
+
inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
const float* filter_data, const Dims<4>& filter_dims,
const float* bias_data, const Dims<4>& bias_dims,
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 59f17ae854..19d23fa80b 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -100,11 +100,6 @@ gemmlowp::FixedPoint<tRawType, tIntegerBits> SaturatingSub(
namespace reference_ops {
-inline void ShapeFromDims(const tflite::Dims<4>& dims, RuntimeShape* shape) {
- shape->BuildFrom(
- {dims.sizes[3], dims.sizes[2], dims.sizes[1], dims.sizes[0]});
-}
-
template <typename T>
int CountLeadingZeros(T integer_input) {
static_assert(std::is_unsigned<T>::value,
diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h
index b39347758a..c6bc6074d4 100644
--- a/tensorflow/contrib/lite/kernels/internal/types.h
+++ b/tensorflow/contrib/lite/kernels/internal/types.h
@@ -18,7 +18,6 @@ limitations under the License.
#include <algorithm>
#include <cstring>
-#include "absl/base/macros.h"
#include "tensorflow/contrib/lite/kernels/internal/compatibility.h"
namespace tflite {
@@ -269,8 +268,9 @@ class RuntimeShape {
// This creates a shape padded to the desired size with the specified value.
RuntimeShape(int new_shape_size, const RuntimeShape& shape, int pad_value)
: size_(0) {
+ // If the following check fails, it is likely because a 4D-only kernel is
+ // being used with an array of larger dimension count.
TFLITE_CHECK_GE(new_shape_size, shape.DimensionsCount());
- TFLITE_CHECK_LE(new_shape_size, kMaxSmallSize);
Resize(new_shape_size);
const int size_increase = new_shape_size - shape.DimensionsCount();
for (int i = 0; i < size_increase; ++i) {
@@ -441,7 +441,7 @@ inline int FlatSize(const Dims<N>& dims) {
return flat_size;
}
-ABSL_DEPRECATED("Prefer FlatSize.")
+TFLITE_DEPRECATED("Prefer FlatSize.")
inline int RequiredBufferSizeForDims(const Dims<4>& dims) {
return FlatSize(dims);
}
diff --git a/tensorflow/contrib/lite/kernels/lstm.cc b/tensorflow/contrib/lite/kernels/lstm.cc
index 5b996d00bc..16d67a1a93 100644
--- a/tensorflow/contrib/lite/kernels/lstm.cc
+++ b/tensorflow/contrib/lite/kernels/lstm.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -424,263 +425,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// The LSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- activation_state_ptr, cell_state_ptr, input_gate_scratch,
- forget_gate_scratch, cell_scratch, output_gate_scratch, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, params, n_batch, n_cell, n_input, n_output,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_cell_weights_ptr, quantized_input_ptr,
- quantized_activation_state_ptr, quantized_cell_state_ptr,
- activation_state_ptr, cell_state_ptr, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
@@ -738,15 +482,21 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
// TODO(mirkov): add a check that weights are all uint8s or all floats.
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -759,17 +509,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.cc b/tensorflow/contrib/lite/kernels/lstm_eval.cc
new file mode 100644
index 0000000000..20a4e30009
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.cc
@@ -0,0 +1,912 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
+
+#include <stdint.h>
+
+#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+namespace {
+
+// Performs an LSTM batch inference step for input specified by input_ptr_batch.
+// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and
+// biases (*_bias_ptr), and buffers (*_scratch), along with additional
+// parameters:
+// - params: various LSTM params including activation, clipping, etc.,
+// - n_batch: size of batch,
+// - n_cell: number of cells (or units),
+// - n_input: the input size,
+// - n_output: the output size.
+//
+// The pointers to the cell and output state and the output are updated.
+//
+// The pointers with the suffix "_batch" point to data aligned in batch_major
+// order, and each step processes batch_size many inputs from input_ptr_batch,
+// and updates batch_size many cell and output states.
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const float* input_to_input_weights_ptr,
+ const float* input_to_forget_weights_ptr,
+ const float* input_to_cell_weights_ptr,
+ const float* input_to_output_weights_ptr, const float* aux_input_ptr_batch,
+ const float* aux_input_to_input_weights_ptr,
+ const float* aux_input_to_forget_weights_ptr,
+ const float* aux_input_to_cell_weights_ptr,
+ const float* aux_input_to_output_weights_ptr,
+ const float* recurrent_to_input_weights_ptr,
+ const float* recurrent_to_forget_weights_ptr,
+ const float* recurrent_to_cell_weights_ptr,
+ const float* recurrent_to_output_weights_ptr,
+ const float* cell_to_input_weights_ptr,
+ const float* cell_to_forget_weights_ptr,
+ const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const float* projection_weights_ptr,
+ const float* projection_bias_ptr, const TfLiteLSTMParams* params,
+ int n_batch, int n_cell, int n_input, int n_aux_input, int n_output,
+ float* output_state_ptr, float* cell_state_ptr, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+
+ // If auxiliary input is available then compute aux_input_weight * aux_input
+ if (aux_input_ptr_batch != nullptr) {
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, input_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, forget_gate_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_aux_input, aux_input_ptr_batch,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_aux_input,
+ aux_input_ptr_batch, n_batch, output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, input_gate_scratch, /*result_stride=*/1);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, cell_scratch, /*result_stride=*/1);
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
+ n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ // For each batch and cell: update the output gate.
+ if (use_peephole) {
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
+ output_ptr_batch, /*result_stride=*/1);
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+
+// Same as above but with quantized weight matrices. In detail:
+// Input of size 'n_batch * n_input':
+// input_ptr_batch
+//
+// LSTM weights:
+// Quantized input weights of size 'n_cell * n_input':
+// input_to_input_weights - optional (can be nullptr)
+// input_to_forget_weights
+// input_to_cell_weights
+// input_to_input_weights
+// Quantized recurrent weights of size 'n_cell * n_output':
+// recurrent_to_input_weights - optional
+// recurrent_to_forget_weights
+// recurrent_to_cell_weights
+// recurrent_to_input_weights
+// Quantized peephole weights of size 'n_cell', representing diagonal matrices.
+// cell_to_input_weights - optional
+// cell_to_cell_weights - optional
+// cell_to_output_weights - optional
+// Quantized projection weights of size 'n_output * n_cell'
+// projection_weights_ptr - optional
+// Weight scales (scalars) for each of the weights above.
+// input_to_input_weights_scale - optional
+// input_to_forget_weights_scale
+// input_to_cell_weights_scale
+// input_to_output_weights_scale
+// recurrent_to_input_weights_scale - optional
+// recurrent_to_forget_weights_scale
+// recurrent_to_cell_weights_scale
+// recurrent_to_output_weights_scale
+// cell_to_input_weights_scale,
+// cell_to_forget_weights_scale,
+// cell_to_output_weights_scale,
+// projection_weights_scale - optional
+// Gate biases of size 'n_cell':
+// input_gate_bias_ptr - optional
+// forget_gate_bias_ptr
+// cell_gate_bias_ptr
+// output_gate_bias_ptr
+//
+// Temporary pre-allocated storage for quantized values:
+// quantized_input_ptr_batch (same size as input_ptr_batch)
+// quantized_output_state_ptr (same size as output_state_ptr)
+// quantized_cell_state_ptr (same size as cell_state_ptr)
+// Temporary pre-allocated storage for recovered values:
+// recovered_cell_weights (same size as cell_to_*_weights)
+//
+// Outputs:
+// output_state_ptr - size 'n_batch * n_output'
+// cell_state_ptr - size 'n_batch * n_cell'
+// output_ptr_batch - size 'n_batch * n_output'
+inline void LstmStepWithAuxInput(
+ const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
+ float input_to_input_weights_scale,
+ const int8_t* input_to_forget_weights_ptr,
+ float input_to_forget_weights_scale,
+ const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
+ const int8_t* input_to_output_weights_ptr,
+ float input_to_output_weights_scale, const float* aux_input_ptr_batch,
+ const int8_t* aux_input_to_input_weights_ptr,
+ float aux_input_to_input_weights_scale,
+ const int8_t* aux_input_to_forget_weights_ptr,
+ float aux_input_to_forget_weights_scale,
+ const int8_t* aux_input_to_cell_weights_ptr,
+ float aux_input_to_cell_weights_scale,
+ const int8_t* aux_input_to_output_weights_ptr,
+ float aux_input_to_output_weights_scale,
+ const int8_t* recurrent_to_input_weights_ptr,
+ float recurrent_to_input_weights_scale,
+ const int8_t* recurrent_to_forget_weights_ptr,
+ float recurrent_to_forget_weights_scale,
+ const int8_t* recurrent_to_cell_weights_ptr,
+ float recurrent_to_cell_weights_scale,
+ const int8_t* recurrent_to_output_weights_ptr,
+ float recurrent_to_output_weights_scale,
+ const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
+ const int8_t* cell_to_forget_weights_ptr,
+ float cell_to_forget_weights_scale,
+ const int8_t* cell_to_output_weights_ptr,
+ float cell_to_output_weights_scale, const float* input_gate_bias_ptr,
+ const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
+ const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
+ float projection_weights_scale, const float* projection_bias_ptr,
+ const TfLiteLSTMParams* params, int n_batch, int n_cell, int n_input,
+ int n_aux_input, int n_output, float* input_gate_scratch,
+ float* forget_gate_scratch, float* cell_scratch, float* output_gate_scratch,
+ float* scaling_factors, float* product_scaling_factors,
+ float* recovered_cell_weights, int8_t* quantized_input_ptr_batch,
+ int8_t* quantized_aux_input_ptr_batch, int8_t* quantized_output_state_ptr,
+ int8_t* quantized_cell_state_ptr, float* output_state_ptr,
+ float* cell_state_ptr, float* output_ptr_batch) {
+ // Since we have already checked that weights are all there or none, we
+ // can check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights_ptr == nullptr);
+ const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
+ // Initialize scratch buffers with bias.
+ if (!use_cifg) {
+ tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch,
+ forget_gate_scratch);
+ tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch,
+ cell_scratch);
+ tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch,
+ output_gate_scratch);
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_input_weights_ptr, n_cell, n_input,
+ quantized_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, forget_gate_scratch,
+ /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
+ product_scaling_factors, n_batch, output_gate_scratch,
+ /*result_stride=*/1);
+ }
+
+ if (aux_input_ptr_batch != nullptr &&
+ !tensor_utils::IsZeroVector(aux_input_ptr_batch, n_batch * n_input)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ aux_input_ptr_batch + offset, n_input,
+ quantized_aux_input_ptr_batch + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute input_weight * input.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_input_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_forget_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_cell_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * aux_input_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ aux_input_to_output_weights_ptr, n_cell, n_input,
+ quantized_aux_input_ptr_batch, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_output;
+ tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
+ quantized_output_state_ptr + offset,
+ &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ // For each batch and cell: compute recurrent_weight * output_state.
+ if (!use_cifg) {
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_input_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_input_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ input_gate_scratch, /*result_stride=*/1);
+ }
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_forget_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_forget_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ forget_gate_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_cell_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_cell_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ cell_scratch, /*result_stride=*/1);
+
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * recurrent_to_output_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ recurrent_to_output_weights_ptr, n_cell, n_output,
+ quantized_output_state_ptr, product_scaling_factors, n_batch,
+ output_gate_scratch, /*result_stride=*/1);
+ }
+
+ // Save quantization and matmul computation for all zero input.
+ bool is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+
+ // For each batch and cell: update input gate.
+ if (!use_cifg) {
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
+ cell_to_input_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ input_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+ input_gate_scratch);
+ }
+
+ // For each batch and cell: update forget gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
+ cell_to_forget_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ forget_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+ forget_gate_scratch);
+
+ // For each batch and cell: update the cell.
+ tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
+ n_batch * n_cell, cell_state_ptr);
+ tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+ params->activation, cell_scratch);
+ if (use_cifg) {
+ tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+ forget_gate_scratch);
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ } else {
+ tensor_utils::VectorVectorCwiseProductAccumulate(
+ cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
+ }
+ if (params->cell_clip > 0.0) {
+ tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell,
+ params->cell_clip, cell_state_ptr);
+ }
+
+ is_cell_state_all_zeros =
+ tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
+ // For each batch and cell: update the output gate.
+ if (use_peephole && !is_cell_state_all_zeros) {
+ tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
+ cell_to_output_weights_scale,
+ recovered_cell_weights);
+ tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+ recovered_cell_weights, n_cell, cell_state_ptr, n_batch,
+ output_gate_scratch);
+ }
+ tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+ output_gate_scratch);
+ tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
+ params->activation, cell_scratch);
+ tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+ n_batch * n_cell, output_gate_scratch);
+
+ // For each batch: update the projection and output_state.
+ const bool use_projection_weight = (projection_weights_ptr != nullptr);
+ const bool use_projection_bias = (projection_bias_ptr != nullptr);
+ if (use_projection_weight) {
+ if (use_projection_bias) {
+ tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
+ n_batch, output_ptr_batch);
+ } else {
+ tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
+ }
+ if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
+ // Save quantization and matmul computation for all zero input.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_cell;
+ tensor_utils::SymmetricQuantizeFloats(
+ output_gate_scratch + offset, n_cell,
+ quantized_cell_state_ptr + offset, &unused_min, &unused_max,
+ &scaling_factors[b]);
+ }
+ for (int b = 0; b < n_batch; ++b) {
+ product_scaling_factors[b] =
+ scaling_factors[b] * projection_weights_scale;
+ }
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
+ product_scaling_factors, n_batch, output_ptr_batch,
+ /*result_stride=*/1);
+ }
+ if (params->proj_clip > 0.0) {
+ tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output,
+ params->proj_clip, output_ptr_batch);
+ }
+ } else {
+ tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+ output_ptr_batch);
+ }
+ tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
+ output_state_ptr);
+}
+} // namespace
+
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
+ const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+ const int n_batch = input->dims->data[input->dims->size - 2];
+ const int n_input = input->dims->data[input->dims->size - 1];
+ const int aux_input_size =
+ (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existense of only one to the get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ // Index the scratch buffers pointers to the global scratch buffer.
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ const float* input_to_input_weights_ptr =
+ (use_cifg) ? nullptr : input_to_input_weights->data.f;
+ const float* recurrent_to_input_weights_ptr =
+ (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
+ const float* input_gate_bias_ptr =
+ (use_cifg) ? nullptr : input_gate_bias->data.f;
+ const float* cell_to_input_weights_ptr =
+ (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
+ const float* cell_to_forget_weights_ptr =
+ (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
+ const float* cell_to_output_weights_ptr =
+ (use_peephole) ? cell_to_output_weights->data.f : nullptr;
+ const float* projection_weights_ptr =
+ (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ float* aux_input_ptr = nullptr;
+ float* aux_input_to_input_weights_ptr = nullptr;
+ float* aux_input_to_forget_weights_ptr = nullptr;
+ float* aux_input_to_cell_weights_ptr = nullptr;
+ float* aux_input_to_output_weights_ptr = nullptr;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr = aux_input_to_input_weights->data.f;
+ aux_input_to_forget_weights_ptr = aux_input_to_forget_weights->data.f;
+ aux_input_to_cell_weights_ptr = aux_input_to_cell_weights->data.f;
+ aux_input_to_output_weights_ptr = aux_input_to_output_weights->data.f;
+ }
+
+ // Loop through the sequence.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr_time =
+ output->data.f + t_rel * output_step + output_offset;
+
+ LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
+ input_to_cell_weights->data.f, input_to_output_weights->data.f,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_forget_weights_ptr, aux_input_to_cell_weights_ptr,
+ aux_input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
+ recurrent_to_forget_weights->data.f, recurrent_to_cell_weights->data.f,
+ recurrent_to_output_weights->data.f, cell_to_input_weights_ptr,
+ cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
+ input_gate_bias_ptr, forget_gate_bias->data.f, cell_bias->data.f,
+ output_gate_bias->data.f, projection_weights_ptr, projection_bias_ptr,
+ params, n_batch, n_cell, n_input, aux_input_size, n_output,
+ activation_state->data.f, cell_state->data.f, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ output_ptr_time);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state,
+ TfLiteTensor* output) {
+ TF_LITE_ASSERT(input->dims->size >= 2 && input->dims->size <= 3);
+ const int max_time = (input->dims->size == 2) ? 1 : input->dims->data[0];
+ const int n_batch = input->dims->data[input->dims->size - 2];
+ const int n_input = input->dims->data[input->dims->size - 1];
+ const int aux_input_size =
+ (aux_input) ? aux_input->dims->data[aux_input->dims->size - 1] : 0;
+ // n_cell and n_output will be the same size when there is no projection.
+ const int n_cell = input_to_output_weights->dims->data[0];
+ const int n_output = recurrent_to_output_weights->dims->data[1];
+
+ // Since we have already checked that weights are all there or none, we can
+ // check the existence of only one to get the condition.
+ const bool use_cifg = (input_to_input_weights == nullptr);
+ const bool use_peephole = (cell_to_output_weights != nullptr);
+
+ float* input_gate_scratch = nullptr;
+ float* cell_scratch = nullptr;
+ float* forget_gate_scratch = nullptr;
+ float* output_gate_scratch = nullptr;
+ if (use_cifg) {
+ cell_scratch = scratch_buffer->data.f;
+ forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ } else {
+ input_gate_scratch = scratch_buffer->data.f;
+ cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+ forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+ output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+ }
+
+ // Check optional tensors, the respective pointers can be null.
+ int8_t* input_to_input_weights_ptr = nullptr;
+ float input_to_input_weights_scale = 1.0f;
+ int8_t* recurrent_to_input_weights_ptr = nullptr;
+ float recurrent_to_input_weights_scale = 1.0f;
+ float* input_gate_bias_ptr = nullptr;
+ if (!use_cifg) {
+ input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
+ recurrent_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
+ input_gate_bias_ptr = input_gate_bias->data.f;
+ input_to_input_weights_scale = input_to_input_weights->params.scale;
+ recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
+ }
+
+ int8_t* cell_to_input_weights_ptr = nullptr;
+ int8_t* cell_to_forget_weights_ptr = nullptr;
+ int8_t* cell_to_output_weights_ptr = nullptr;
+ float cell_to_input_weights_scale = 1.0f;
+ float cell_to_forget_weights_scale = 1.0f;
+ float cell_to_output_weights_scale = 1.0f;
+ if (use_peephole) {
+ if (!use_cifg) {
+ cell_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
+ cell_to_input_weights_scale = cell_to_input_weights->params.scale;
+ }
+ cell_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
+ cell_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
+ cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
+ cell_to_output_weights_scale = cell_to_output_weights->params.scale;
+ }
+
+ const int8_t* projection_weights_ptr =
+ (projection_weights == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
+ const float projection_weights_scale =
+ (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
+ const float* projection_bias_ptr =
+ (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
+
+ // Required tensors, pointers are non-null.
+ const int8_t* input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
+ const float input_to_forget_weights_scale =
+ input_to_forget_weights->params.scale;
+ const int8_t* input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
+ const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
+ const int8_t* input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
+ const float input_to_output_weights_scale =
+ input_to_output_weights->params.scale;
+ const int8_t* recurrent_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
+ const float recurrent_to_forget_weights_scale =
+ recurrent_to_forget_weights->params.scale;
+ const int8_t* recurrent_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
+ const float recurrent_to_cell_weights_scale =
+ recurrent_to_cell_weights->params.scale;
+ const int8_t* recurrent_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
+ const float recurrent_to_output_weights_scale =
+ recurrent_to_output_weights->params.scale;
+ const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
+ const float* cell_bias_ptr = cell_bias->data.f;
+ const float* output_gate_bias_ptr = output_gate_bias->data.f;
+
+ float* output_state_ptr = output_state->data.f;
+ float* cell_state_ptr = cell_state->data.f;
+
+ // Temporary storage for quantized values and scaling factors.
+ int8_t* quantized_input_ptr =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ int8_t* quantized_aux_input_ptr =
+ (aux_input_quantized == nullptr)
+ ? nullptr
+ : reinterpret_cast<int8_t*>(aux_input_quantized->data.uint8);
+ int8_t* quantized_output_state_ptr =
+ reinterpret_cast<int8_t*>(output_state_quantized->data.uint8);
+ int8_t* quantized_cell_state_ptr =
+ reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+ float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
+ float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
+
+ // Auxiliary input and weights.
+ float* aux_input_ptr = nullptr;
+ int8_t* aux_input_to_input_weights_ptr = nullptr;
+ int8_t* aux_input_to_forget_weights_ptr = nullptr;
+ int8_t* aux_input_to_cell_weights_ptr = nullptr;
+ int8_t* aux_input_to_output_weights_ptr = nullptr;
+ float aux_input_to_input_weights_scale = 0.0f;
+ float aux_input_to_forget_weights_scale = 0.0f;
+ float aux_input_to_cell_weights_scale = 0.0f;
+ float aux_input_to_output_weights_scale = 0.0f;
+ if (aux_input_size > 0) {
+ aux_input_ptr = aux_input->data.f;
+ aux_input_to_input_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_input_weights->data.uint8);
+ aux_input_to_forget_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_forget_weights->data.uint8);
+ aux_input_to_cell_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_cell_weights->data.uint8);
+ aux_input_to_output_weights_ptr =
+ reinterpret_cast<int8_t*>(aux_input_to_output_weights->data.uint8);
+ aux_input_to_input_weights_scale = aux_input_to_input_weights->params.scale;
+ aux_input_to_forget_weights_scale =
+ aux_input_to_forget_weights->params.scale;
+ aux_input_to_cell_weights_scale = aux_input_to_cell_weights->params.scale;
+ aux_input_to_output_weights_scale =
+ aux_input_to_output_weights->params.scale;
+ }
+
+ // Feed the sequence into the LSTM step-by-step.
+ const int input_step = n_batch * n_input;
+ const int output_step = n_batch * output->dims->data[output->dims->size - 1];
+ for (int t = 0; t < max_time; t++) {
+ // If this is the forward_sequence, step forward, otherwise step backwards.
+ const int t_rel = forward_sequence ? t : max_time - t - 1;
+ const float* input_ptr = input->data.f + t_rel * input_step;
+ float* output_ptr = output->data.f + t_rel * output_step + output_offset;
+
+ LstmStepWithAuxInput(
+ input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
+ input_to_forget_weights_ptr, input_to_forget_weights_scale,
+ input_to_cell_weights_ptr, input_to_cell_weights_scale,
+ input_to_output_weights_ptr, input_to_output_weights_scale,
+ aux_input_ptr, aux_input_to_input_weights_ptr,
+ aux_input_to_input_weights_scale, aux_input_to_forget_weights_ptr,
+ aux_input_to_forget_weights_scale, aux_input_to_cell_weights_ptr,
+ aux_input_to_cell_weights_scale, aux_input_to_output_weights_ptr,
+ aux_input_to_output_weights_scale, recurrent_to_input_weights_ptr,
+ recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
+ recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
+ recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
+ recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
+ cell_to_input_weights_scale, cell_to_forget_weights_ptr,
+ cell_to_forget_weights_scale, cell_to_output_weights_ptr,
+ cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
+ cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
+ projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
+ n_input, aux_input_size, n_output, input_gate_scratch,
+ forget_gate_scratch, cell_scratch, output_gate_scratch,
+ scaling_factors_ptr, prod_scaling_factors_ptr,
+ recovered_cell_weights_ptr, quantized_input_ptr,
+ quantized_aux_input_ptr, quantized_output_state_ptr,
+ quantized_cell_state_ptr, output_state_ptr, cell_state_ptr, output_ptr);
+ }
+
+ return kTfLiteOk;
+}
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/lstm_eval.h b/tensorflow/contrib/lite/kernels/lstm_eval.h
new file mode 100644
index 0000000000..adf8cf0f64
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/lstm_eval.h
@@ -0,0 +1,79 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
+
+#include "tensorflow/contrib/lite/c/builtin_op_data.h"
+#include "tensorflow/contrib/lite/c/c_api_internal.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace lstm_eval {
+
+TfLiteStatus EvalFloat(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
+ TfLiteTensor* cell_state, TfLiteTensor* output);
+
+TfLiteStatus EvalHybrid(
+ const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
+ const TfLiteTensor* input_to_forget_weights,
+ const TfLiteTensor* input_to_cell_weights,
+ const TfLiteTensor* input_to_output_weights,
+ const TfLiteTensor* recurrent_to_input_weights,
+ const TfLiteTensor* recurrent_to_forget_weights,
+ const TfLiteTensor* recurrent_to_cell_weights,
+ const TfLiteTensor* recurrent_to_output_weights,
+ const TfLiteTensor* cell_to_input_weights,
+ const TfLiteTensor* cell_to_forget_weights,
+ const TfLiteTensor* cell_to_output_weights, const TfLiteTensor* aux_input,
+ const TfLiteTensor* aux_input_to_input_weights,
+ const TfLiteTensor* aux_input_to_forget_weights,
+ const TfLiteTensor* aux_input_to_cell_weights,
+ const TfLiteTensor* aux_input_to_output_weights,
+ const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
+ const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
+ const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
+ TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
+ TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
+ TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
+ TfLiteTensor* output_state_quantized, TfLiteTensor* cell_state_quantized,
+ TfLiteTensor* output_state, TfLiteTensor* cell_state, TfLiteTensor* output);
+
+} // namespace lstm_eval
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_LSTM_EVAL_H_
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
new file mode 100644
index 0000000000..843ed0768c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// SparseOutputFullyConnected is a fully connected layer that uses a single
+// row in the weights and bias via a lookup.
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace sparse_output_fully_connected {
+
+// Input tensors of size {n_batch, n_input}
+constexpr int kInputTensor = 0;
+// Auxiliary input tensor of size { 1 }
+constexpr int kInputLookupTensor = 1;
+
+// Weights tensor of size { n_embeddings , n_input }
+constexpr int kWeightsTensor = 2;
+// Bias tensor of size { n_embeddings }
+constexpr int kBiasTensor = 3;
+
+// Output tensor.
+constexpr int kOutputTensor = 0;
+
+// Temporary tensors.
+enum TemporaryTensor {
+ kInputQuantized = 0,
+ kScalingFactors = 1,
+ kNumTemporaryTensors = 2
+};
+
+// Struct to hold op data.
+struct OpData {
+ int scratch_tensor_index;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* data = new OpData;
+ context->AddTensors(context, /*tensors_to_add=*/kNumTemporaryTensors,
+ &data->scratch_tensor_index);
+ return data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 2);
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+ TF_LITE_ENSURE_EQ(context, lookup->type, kTfLiteInt32);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(lookup), 1);
+ // Only support single lookup.
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(lookup, 0), 1);
+
+ const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, NumDimensions(weights), 2);
+ TF_LITE_ENSURE_EQ(context, SizeOfDimension(weights, 1), n_input);
+
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(weights, 0));
+
+ const bool is_hybrid_op =
+ (weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
+
+ if (is_hybrid_op) {
+ TfLiteIntArrayFree(node->temporaries);
+ node->temporaries = TfLiteIntArrayCreate(kNumTemporaryTensors);
+
+ // Allocate temporary tensors to store quantized values of input.
+ node->temporaries->data[kInputQuantized] = op_data->scratch_tensor_index;
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, /*index=*/kInputQuantized);
+ input_quantized->type = kTfLiteUInt8;
+ input_quantized->allocation_type = kTfLiteArenaRw;
+ if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
+ TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
+ input_quantized_size));
+ }
+
+ // Tell interpreter to allocate temporary tensors to store scaling factors.
+ node->temporaries->data[kScalingFactors] =
+ op_data->scratch_tensor_index + kScalingFactors;
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, /*index=*/kScalingFactors);
+ scaling_factors->type = kTfLiteFloat32;
+ scaling_factors->allocation_type = kTfLiteArenaRw;
+ TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
+ scaling_factors_size->data[0] = n_batch;
+ if (!TfLiteIntArrayEqual(scaling_factors->dims, scaling_factors_size)) {
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
+ scaling_factors_size));
+ }
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalFloat(const TfLiteTensor* input, const TfLiteTensor* lookup,
+ const TfLiteTensor* weights, const TfLiteTensor* bias,
+ TfLiteTensor* output) {
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const float* input_ptr_batch = input->data.f;
+
+ // Initialize pointer to right row according to lookup value.
+ int32 lookup_index = lookup->data.i32[0];
+ const float* weights_ptr = weights->data.f + lookup_index * n_input;
+
+ // Initialize output to bias.
+ if (bias) {
+ float* bias_ptr = bias->data.f + lookup_index;
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_ptr, /*m_rows=*/1, n_input, input_ptr_batch, n_batch,
+ output->data.f, /*result_stride=*/1);
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus EvalHybrid(const TfLiteTensor* input, const TfLiteTensor* lookup,
+ const TfLiteTensor* weights, const TfLiteTensor* bias,
+ TfLiteTensor* scaling_factors,
+ TfLiteTensor* input_quantized, TfLiteTensor* output) {
+ const int n_batch = SizeOfDimension(input, 0);
+ const int n_input = SizeOfDimension(input, 1);
+
+ const float* input_ptr_batch = input->data.f;
+ // Initialize the pointer to storage for quantized values and
+ // scaling factors.
+ int8_t* quantized_input_ptr_batch =
+ reinterpret_cast<int8_t*>(input_quantized->data.uint8);
+ float* scaling_factors_ptr = scaling_factors->data.f;
+
+ // Initialize pointer to right row according to lookup value.
+ int32 lookup_index = lookup->data.i32[0];
+ int8_t* weights_ptr =
+ reinterpret_cast<int8_t*>(weights->data.uint8) + lookup_index * n_input;
+
+ // Initialize output to bias.
+ if (bias) {
+ float* bias_ptr = bias->data.f + lookup_index;
+ tensor_utils::VectorBatchVectorAssign(bias_ptr, 1, n_batch, output->data.f);
+ } else {
+ tensor_utils::ZeroVector(output->data.f, n_batch * 1);
+ }
+
+ if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
+ // Quantize input from float to int8.
+ float unused_min, unused_max;
+ for (int b = 0; b < n_batch; ++b) {
+ const int offset = b * n_input;
+ tensor_utils::SymmetricQuantizeFloats(
+ input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
+ &unused_min, &unused_max, &scaling_factors_ptr[b]);
+ scaling_factors_ptr[b] *= weights->params.scale;
+ }
+
+ tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+ weights_ptr, /*m_rows=*/1, n_input, quantized_input_ptr_batch,
+ scaling_factors_ptr, n_batch, output->data.f, /*result_stride=*/1);
+ }
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* input = GetInput(context, node, kInputTensor);
+ const TfLiteTensor* lookup = GetInput(context, node, kInputLookupTensor);
+ const TfLiteTensor* weights = GetInput(context, node, kWeightsTensor);
+ const TfLiteTensor* bias = GetInput(context, node, kBiasTensor);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+ switch (weights->type) {
+ case kTfLiteFloat32: {
+ return EvalFloat(input, lookup, weights, bias, output);
+ }
+ case kTfLiteUInt8: {
+ TfLiteTensor* input_quantized =
+ GetTemporary(context, node, /*index=*/kInputQuantized);
+ TfLiteTensor* scaling_factors =
+ GetTemporary(context, node, /*index=*/kScalingFactors);
+ return EvalHybrid(input, lookup, weights, bias, scaling_factors,
+ input_quantized, output);
+ }
+ default:
+ context->ReportError(context, "Type %d is not currently supported.",
+ weights->type);
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace sparse_output_fully_connected
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED() {
+ static TfLiteRegistration r = {sparse_output_fully_connected::Init,
+ sparse_output_fully_connected::Free,
+ sparse_output_fully_connected::Prepare,
+ sparse_output_fully_connected::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
new file mode 100644
index 0000000000..365986a5c1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/sparse_output_fully_connected_test.cc
@@ -0,0 +1,158 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Unit test for TFLite sparse output fully connected op.
+#include <iomanip>
+#include <random>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+
+namespace tflite {
+
+namespace ops {
+namespace custom {
+
+TfLiteRegistration* Register_SPARSE_OUTPUT_FULLY_CONNECTED();
+
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseSparseOutputFullyConnectedOpModel : public SingleOpModel {
+ public:
+ BaseSparseOutputFullyConnectedOpModel(const TensorData& input,
+ const TensorData& weights,
+ const TensorData& output = {
+ TensorType_FLOAT32}) {
+ input_ = AddInput(input);
+ lookup_ = AddInput({TensorType_INT32, {1}});
+ weights_ = AddInput(weights);
+ int bias_size = GetShape(weights_)[0];
+ bias_ = AddInput({TensorType_FLOAT32, {bias_size}});
+ output_ = AddOutput(output);
+
+ // Create empty (required) options map.
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {});
+ fbb.Finish();
+
+ SetCustomOp("SPARSE_OUTPUT_FULLY_CONNECTED", fbb.GetBuffer(),
+ Register_SPARSE_OUTPUT_FULLY_CONNECTED);
+ BuildInterpreter({GetShape(input_), GetShape(lookup_), GetShape(weights_),
+ GetShape(bias_)});
+ }
+
+ void SetInput(const std::vector<float>& data) {
+ PopulateTensor(input_, data);
+ }
+
+ void SetLookup(const std::vector<int32>& f) { PopulateTensor(lookup_, f); }
+
+ void SetBias(const std::vector<float>& f) { PopulateTensor(bias_, f); }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input_;
+ int lookup_;
+ int weights_;
+ int bias_;
+ int output_;
+};
+
+class FloatSparseOutputFullyConnectedOpModel
+ : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+ using BaseSparseOutputFullyConnectedOpModel::
+ BaseSparseOutputFullyConnectedOpModel;
+
+ void SetWeights(const std::vector<float>& f) { PopulateTensor(weights_, f); }
+};
+
+class HybridSparseOutputFullyConnectedOpModel
+ : public BaseSparseOutputFullyConnectedOpModel {
+ public:
+ using BaseSparseOutputFullyConnectedOpModel::
+ BaseSparseOutputFullyConnectedOpModel;
+
+ void SetWeights(const std::vector<float>& f) {
+ SymmetricQuantizeAndPopulate(weights_, f);
+ }
+};
+
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestFloat) {
+ FloatSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+ {TensorType_FLOAT32, {3, 5}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+ m.SetLookup({2});
+
+ m.SetWeights({
+ -1.0, 0.0, 1.0, 2.0, 3.0, //
+ 0.0, 1.0, 2.0, 3.0, 4.0, //
+ 1.0, 2.0, 3.0, 4.0, 5.0, //
+ });
+
+ m.SetBias({1.0, 2.0, 3.0});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({28}));
+}
+
+TEST(SparseOutputFullyConnectedOpTest, SimpleTestHybrid) {
+ HybridSparseOutputFullyConnectedOpModel m({TensorType_FLOAT32, {1, 5}},
+ {TensorType_UINT8, {3, 5}},
+ {TensorType_FLOAT32, {}});
+
+ m.SetInput({-1.0, 0.0, 1.0, 2.0, 3.0});
+
+ m.SetLookup({2});
+
+ m.SetWeights({
+ -1.0, 0.0, 1.0, 2.0, 3.0, //
+ 0.0, 1.0, 2.0, 3.0, 4.0, //
+ 1.0, 2.0, 3.0, 4.0, 5.0, //
+ });
+
+ m.SetBias({1.0, 2.0, 3.0});
+
+ m.Invoke();
+
+ // We get 28.0552 instead of 28.
+ //
+ // Input -> -42, 0, 42, 85, 127 with scale factor of 127/3.
+ // Looked up weights -> 25, 51, 76, 102, 127 with scale factor of 127/5.
+ //
+ // (-42 * 25 + 0 * 51 + 42 * 76 + 85 * 102 + 127 * 127) * (3*5/127^2) + 3.0
+ // gives us the expected result.
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({28}, 0.0553)));
+}
+
+} // namespace
+} // namespace custom
+} // namespace ops
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
index 63817bd886..89d57e4599 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/lstm_eval.h"
#include "tensorflow/contrib/lite/kernels/op_macros.h"
namespace tflite {
@@ -429,275 +430,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
return kTfLiteOk;
}
-// The LSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* activation_state, TfLiteTensor* cell_state,
- TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_forget_weights_ptr, input_to_cell_weights_ptr,
- input_to_output_weights_ptr, recurrent_to_input_weights_ptr,
- recurrent_to_forget_weights_ptr, recurrent_to_cell_weights_ptr,
- recurrent_to_output_weights_ptr, cell_to_input_weights_ptr,
- cell_to_forget_weights_ptr, cell_to_output_weights_ptr,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_bias_ptr,
- params, n_batch, n_cell, n_input, n_output, activation_state_ptr,
- cell_state_ptr, input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, TfLiteTensor* scratch_buffer,
- TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors,
- TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized,
- TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int max_time = input->dims->data[0];
- const int n_batch = input->dims->data[1];
- const int n_input = input->dims->data[2];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
-
- // Required tensors, pointers are non-null.
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_cell_weights_ptr = recovered_cell_weights->data.f;
-
- // Feed the sequence into the LSTM step-by-step.
- for (int t = 0; t < max_time; t++) {
- const float* input_ptr_batch = input->data.f + t * n_batch * n_input;
- float* output_ptr_batch = output->data.f + t * n_batch * n_output;
-
- kernel_utils::LstmStep(
- input_ptr_batch, input_to_input_weights_ptr,
- input_to_input_weights_scale, input_to_forget_weights_ptr,
- input_to_forget_weights_scale, input_to_cell_weights_ptr,
- input_to_cell_weights_scale, input_to_output_weights_ptr,
- input_to_output_weights_scale, recurrent_to_input_weights_ptr,
- recurrent_to_input_weights_scale, recurrent_to_forget_weights_ptr,
- recurrent_to_forget_weights_scale, recurrent_to_cell_weights_ptr,
- recurrent_to_cell_weights_scale, recurrent_to_output_weights_ptr,
- recurrent_to_output_weights_scale, cell_to_input_weights_ptr,
- cell_to_input_weights_scale, cell_to_forget_weights_ptr,
- cell_to_forget_weights_scale, cell_to_output_weights_ptr,
- cell_to_output_weights_scale, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_weights_scale, projection_bias_ptr, params, n_batch, n_cell,
- n_input, n_output, input_gate_scratch, forget_gate_scratch,
- cell_scratch, output_gate_scratch, scaling_factors_ptr,
- prod_scaling_factors_ptr, recovered_cell_weights_ptr,
- quantized_input_ptr, quantized_activation_state_ptr,
- quantized_cell_state_ptr, activation_state_ptr, cell_state_ptr,
- output_ptr_batch);
- }
- return kTfLiteOk;
-}
-
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params =
+ reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input_to_input_weights =
@@ -748,17 +484,29 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ // Copy out the LSTM specific params so they can be passed in the function.
+ TfLiteLSTMParams lstm_params;
+ lstm_params.activation = params->activation;
+ lstm_params.cell_clip = params->cell_clip;
+ lstm_params.proj_clip = params->proj_clip;
+
switch (input_to_output_weights->type) {
case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params,
- scratch_buffer, activation_state, cell_state, output);
+ return lstm_eval::EvalFloat(
+ input, input_to_input_weights, input_to_forget_weights,
+ input_to_cell_weights, input_to_output_weights,
+ recurrent_to_input_weights, recurrent_to_forget_weights,
+ recurrent_to_cell_weights, recurrent_to_output_weights,
+ cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, activation_state, cell_state,
+ output);
}
case kTfLiteUInt8: {
TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
@@ -771,17 +519,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetTemporary(context, node, /*index=*/5);
TfLiteTensor* recovered_cell_weights =
GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
+ return lstm_eval::EvalHybrid(
input, input_to_input_weights, input_to_forget_weights,
input_to_cell_weights, input_to_output_weights,
recurrent_to_input_weights, recurrent_to_forget_weights,
recurrent_to_cell_weights, recurrent_to_output_weights,
cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_gate_bias, forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, params, scratch_buffer,
- scaling_factors, prod_scaling_factors, recovered_cell_weights,
- input_quantized, activation_state_quantized, cell_state_quantized,
- activation_state, cell_state, output);
+ /*aux_input=*/nullptr,
+ /*aux_input_to_input_weights=*/nullptr,
+ /*aux_input_to_forget_weights=*/nullptr,
+ /*aux_input_to_cell_weights=*/nullptr,
+ /*aux_input_to_output_weights=*/nullptr, input_gate_bias,
+ forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
+ projection_bias, &lstm_params, /*forward_sequence=*/true,
+ /*output_offset=*/0, scratch_buffer, scaling_factors,
+ prod_scaling_factors, recovered_cell_weights, input_quantized,
+ /*aux_input_quantized=*/nullptr, activation_state_quantized,
+ cell_state_quantized, activation_state, cell_state, output);
}
default:
context->ReportError(context, "Type %d is not currently supported.",
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
index cd3aac0532..c97b0fdd61 100644
--- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -110,11 +110,12 @@ class UnidirectionalLSTMOpModel : public SingleOpModel {
output_ = AddOutput(TensorType_FLOAT32);
- SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
- .Union());
+ SetBuiltinOp(
+ BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions,
+ CreateUnidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip, proj_clip)
+ .Union());
BuildInterpreter(input_shapes);
}
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index d50c345194..d7b109ac1a 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -27,9 +27,6 @@ limitations under the License.
#ifndef TFLITE_MCU
#include "tensorflow/contrib/lite/nnapi_delegate.h"
#endif
-#if defined(TFLITE_FLEX)
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif
#include "tensorflow/contrib/lite/version.h"
namespace tflite {
@@ -43,6 +40,25 @@ ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
const char* kEmptyTensorName = "";
+// Normally we'd use ABSL_HAVE_ATTRIBUTE_WEAK and ABSL_ATTRIBUTE_WEAK, but
+// we avoid the absl dependency for binary size reasons.
+#ifdef __has_attribute
+#define TFLITE_HAS_ATTRIBUTE(x) __has_attribute(x)
+#else
+#define TFLITE_HAS_ATTRIBUTE(x) 0
+#endif
+
+#if TFLITE_HAS_ATTRIBUTE(weak) || (defined(__GNUC__) && !defined(__clang__))
+// Using weak symbols for the flex delegate allows automatic injection of the
+// delegate simply by adding it as a dependency. See also the strong override in
+// lite/delegates/flex/delegate.cc.
+__attribute__((weak)) Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
+ return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
+}
+#else
+Interpreter::TfLiteDelegatePtr (*AcquireFlexDelegate)() = nullptr;
+#endif
+
#ifndef TFLITE_MCU
// Loads a model from `filename`. If `mmap_file` is true then use mmap,
// otherwise make a copy of the model in a buffer.
@@ -450,13 +466,14 @@ TfLiteStatus InterpreterBuilder::operator()(
}
(**interpreter).SetVariables(std::move(variables));
-#if defined(TFLITE_FLEX)
- if (auto delegate = FlexDelegate::Create()) {
- (**interpreter)
- .ModifyGraphWithDelegate(std::move(delegate),
- /*allow_dynamic_tensors=*/true);
+ // TODO(b/116667551): Only create the flex delegate if the model has flex ops.
+ if (AcquireFlexDelegate != nullptr) {
+ if (auto flex_delegate = AcquireFlexDelegate()) {
+ (**interpreter)
+ .ModifyGraphWithDelegate(std::move(flex_delegate),
+ /*allow_dynamic_tensors=*/true);
+ }
}
-#endif
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/model_flex_test.cc b/tensorflow/contrib/lite/model_flex_test.cc
new file mode 100644
index 0000000000..52e76bee49
--- /dev/null
+++ b/tensorflow/contrib/lite/model_flex_test.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/model.h"
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/testing/util.h"
+
+namespace tflite {
+
+// Ensures that a model with TensorFlow ops can be imported as long as the
+// appropriate delegate is linked into the client.
+TEST(FlexModel, WithFlexDelegate) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+ ASSERT_TRUE(model);
+
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model,
+ ops::builtin::BuiltinOpResolver{})(&interpreter),
+ kTfLiteOk);
+ ASSERT_TRUE(interpreter);
+
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc
index ec7d46af7c..b969bea5dc 100644
--- a/tensorflow/contrib/lite/model_test.cc
+++ b/tensorflow/contrib/lite/model_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <gtest/gtest.h>
#include "tensorflow/contrib/lite/core/api/error_reporter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/testing/util.h"
// Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object,
@@ -193,6 +194,27 @@ TEST(BasicFlatBufferModel, TestModelInInterpreter) {
}
}
+// Test that loading a model with TensorFlow ops fails when the flex delegate is
+// not linked into the target.
+TEST(FlexModel, FailureWithoutFlexDelegate) {
+ auto model = FlatBufferModel::BuildFromFile(
+ "tensorflow/contrib/lite/testdata/multi_add_flex.bin");
+ ASSERT_TRUE(model);
+
+ // Note that creation will succeed when using the BuiltinOpResolver, but
+ // unless the appropriate delegate is linked into the target or the client
+ // explicitly installs the delegate, execution will fail.
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*model,
+ ops::builtin::BuiltinOpResolver{})(&interpreter),
+ kTfLiteOk);
+ ASSERT_TRUE(interpreter);
+
+ // As the flex ops weren't resolved implicitly by the flex delegate, runtime
+ // allocation and execution will fail.
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteError);
+}
+
// This tests on a flatbuffer that defines a shape of 2 to be a memory mapped
// buffer. But the buffer is provided to be only 1 element.
TEST(BasicFlatBufferModel, TestBrokenMmap) {
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index ff8430827c..cb7a282743 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -250,6 +250,7 @@ union BuiltinOptions {
FillOptions,
BidirectionalSequenceLSTMOptions,
BidirectionalSequenceRNNOptions,
+ UnidirectionalSequenceLSTMOptions,
}
enum Padding : byte { SAME, VALID }
@@ -394,6 +395,13 @@ table LSTMOptions {
kernel_type: LSTMKernelType = FULL;
}
+// An implementation of TensorFlow dynamic_rnn with LSTMCell.
+table UnidirectionalSequenceLSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+}
+
table BidirectionalSequenceLSTMOptions {
fused_activation_function:ActivationFunctionType;
cell_clip: float; // Optional, 0.0 means no clipping
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index f3cb113c9c..e7b7a59def 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT;
struct LSTMOptions;
struct LSTMOptionsT;
+struct UnidirectionalSequenceLSTMOptions;
+struct UnidirectionalSequenceLSTMOptionsT;
+
struct BidirectionalSequenceLSTMOptions;
struct BidirectionalSequenceLSTMOptionsT;
@@ -681,11 +684,12 @@ enum BuiltinOptions {
BuiltinOptions_FillOptions = 68,
BuiltinOptions_BidirectionalSequenceLSTMOptions = 69,
BuiltinOptions_BidirectionalSequenceRNNOptions = 70,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions = 71,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions
+ BuiltinOptions_MAX = BuiltinOptions_UnidirectionalSequenceLSTMOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[72] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -757,7 +761,8 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
BuiltinOptions_ZerosLikeOptions,
BuiltinOptions_FillOptions,
BuiltinOptions_BidirectionalSequenceLSTMOptions,
- BuiltinOptions_BidirectionalSequenceRNNOptions
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ BuiltinOptions_UnidirectionalSequenceLSTMOptions
};
return values;
}
@@ -835,6 +840,7 @@ inline const char * const *EnumNamesBuiltinOptions() {
"FillOptions",
"BidirectionalSequenceLSTMOptions",
"BidirectionalSequenceRNNOptions",
+ "UnidirectionalSequenceLSTMOptions",
nullptr
};
return names;
@@ -1129,6 +1135,10 @@ template<> struct BuiltinOptionsTraits<BidirectionalSequenceRNNOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions;
};
+template<> struct BuiltinOptionsTraits<UnidirectionalSequenceLSTMOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_UnidirectionalSequenceLSTMOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1720,6 +1730,14 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
}
+ UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() {
+ return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
+ const UnidirectionalSequenceLSTMOptionsT *AsUnidirectionalSequenceLSTMOptions() const {
+ return type == BuiltinOptions_UnidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<const UnidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -3469,6 +3487,84 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct UnidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
+ typedef UnidirectionalSequenceLSTMOptions TableType;
+ ActivationFunctionType fused_activation_function;
+ float cell_clip;
+ float proj_clip;
+ UnidirectionalSequenceLSTMOptionsT()
+ : fused_activation_function(ActivationFunctionType_NONE),
+ cell_clip(0.0f),
+ proj_clip(0.0f) {
+ }
+};
+
+struct UnidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef UnidirectionalSequenceLSTMOptionsT NativeTableType;
+ enum {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8
+ };
+ ActivationFunctionType fused_activation_function() const {
+ return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const {
+ return GetField<float>(VT_CELL_CLIP, 0.0f);
+ }
+ float proj_clip() const {
+ return GetField<float>(VT_PROJ_CLIP, 0.0f);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ verifier.EndTable();
+ }
+ UnidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct UnidirectionalSequenceLSTMOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+ fbb_.AddElement<int8_t>(UnidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip) {
+ fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip) {
+ fbb_.AddElement<float>(UnidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ explicit UnidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ UnidirectionalSequenceLSTMOptionsBuilder &operator=(const UnidirectionalSequenceLSTMOptionsBuilder &);
+ flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<UnidirectionalSequenceLSTMOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ float cell_clip = 0.0f,
+ float proj_clip = 0.0f) {
+ UnidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
typedef BidirectionalSequenceLSTMOptions TableType;
ActivationFunctionType fused_activation_function;
@@ -6488,6 +6584,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const {
return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast<const BidirectionalSequenceRNNOptions *>(builtin_options()) : nullptr;
}
+ const UnidirectionalSequenceLSTMOptions *builtin_options_as_UnidirectionalSequenceLSTMOptions() const {
+ return builtin_options_type() == BuiltinOptions_UnidirectionalSequenceLSTMOptions ? static_cast<const UnidirectionalSequenceLSTMOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6799,6 +6898,10 @@ template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_optio
return builtin_options_as_BidirectionalSequenceRNNOptions();
}
+template<> inline const UnidirectionalSequenceLSTMOptions *Operator::builtin_options_as<UnidirectionalSequenceLSTMOptions>() const {
+ return builtin_options_as_UnidirectionalSequenceLSTMOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7809,6 +7912,38 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
_kernel_type);
}
+inline UnidirectionalSequenceLSTMOptionsT *UnidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new UnidirectionalSequenceLSTMOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void UnidirectionalSequenceLSTMOptions::UnPackTo(UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = cell_clip(); _o->cell_clip = _e; };
+ { auto _e = proj_clip(); _o->proj_clip = _e; };
+}
+
+inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> UnidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateUnidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<UnidirectionalSequenceLSTMOptions> CreateUnidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const UnidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const UnidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _fused_activation_function = _o->fused_activation_function;
+ auto _cell_clip = _o->cell_clip;
+ auto _proj_clip = _o->proj_clip;
+ return tflite::CreateUnidirectionalSequenceLSTMOptions(
+ _fbb,
+ _fused_activation_function,
+ _cell_clip,
+ _proj_clip);
+}
+
inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new BidirectionalSequenceLSTMOptionsT();
UnPackTo(_o, _resolver);
@@ -9620,6 +9755,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9918,6 +10057,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -10204,6 +10347,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value);
return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const UnidirectionalSequenceLSTMOptionsT *>(value);
+ return CreateUnidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -10490,6 +10637,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ value = new UnidirectionalSequenceLSTMOptionsT(*reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10847,6 +10998,11 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_UnidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<UnidirectionalSequenceLSTMOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/contrib/lite/testdata/multi_add_flex.bin b/tensorflow/contrib/lite/testdata/multi_add_flex.bin
new file mode 100644
index 0000000000..9aac2155fe
--- /dev/null
+++ b/tensorflow/contrib/lite/testdata/multi_add_flex.bin
Binary files differ
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index 3e57d3f467..f7e5aa6609 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -191,14 +191,6 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, binary_op->outputs[0])) {
- return false;
- }
-
// Test for binary ops of types that we know how to resolve
if (binary_op->type != OperatorType::kAdd &&
binary_op->type != OperatorType::kMul &&
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
index c6c5035a51..d916ae0ddf 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -144,13 +144,6 @@ bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
const auto* concat_op =
static_cast<const ConcatenationOperator*>(concat_base_op);
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, concat_op->outputs[0])) {
- return false;
- }
-
for (const string& input_name : concat_op->inputs) {
// We only expect constant unquantized arrays as input, otherwise we return.
// We also make sure the shapes of the input arrays are known and they are
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 3d797533c9..f5f2f77460 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -69,13 +69,6 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const auto* fakequant_op =
static_cast<const FakeQuantOperator*>(fakequant_base_op);
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, fakequant_op->outputs[0])) {
- return false;
- }
-
// Yield until the fakequant MinMax has been resolved.
if (!fakequant_op->minmax) {
return false;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
index 2cb1e64f3a..f6f95481b5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fill.cc
@@ -52,13 +52,6 @@ bool ResolveConstantFill::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
index 4dfe203a25..36d7dad0ce 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_gather.cc
@@ -71,14 +71,6 @@ bool ResolveConstantGather::Run(Model* model, std::size_t op_index) {
CHECK_GE(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc
index 6f44025dd4..e86616574d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_pack.cc
@@ -59,14 +59,6 @@ bool ResolveConstantPack::Run(Model* model, std::size_t op_index) {
CHECK_GE(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
index c9f2b95d09..88d06d7dc7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_random_uniform.cc
@@ -70,13 +70,6 @@ bool ResolveConstantRandomUniform::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 1);
CHECK_EQ(op->outputs.size(), 1);
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
index e347286dd4..1a0ba9e2bc 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_range.cc
@@ -28,14 +28,6 @@ bool ResolveConstantRange::Run(Model* model, std::size_t op_index) {
auto* op = static_cast<RangeOperator*>(base_op);
CHECK_EQ(op->inputs.size(), 3);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
const auto& start_array = model->GetArray(op->inputs[0]);
if (!start_array.has_shape()) {
// Yield until all input dims have been resolved.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index bfdaa8aafd..a6f665b5f0 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -33,13 +33,6 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
index 3a95d39cd4..e880a3f44d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_select.cc
@@ -37,14 +37,6 @@ bool ResolveConstantSelect::Run(Model* model, std::size_t op_index) {
CHECK_GE(op->inputs.size(), 3);
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
index 452bef1f16..8a0e3e8995 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_shape_or_rank.cc
@@ -27,14 +27,6 @@ bool ResolveConstantShapeOrRank::Run(Model* model, std::size_t op_index) {
}
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been resolved
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
index 58d6797e1c..b35c3e19c4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc
@@ -96,14 +96,6 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) {
const SliceOperator* op = static_cast<const SliceOperator*>(base_op);
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index e275447a0c..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -114,14 +114,6 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
static_cast<const StridedSliceOperator*>(base_op);
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
index 378a38f14b..5cfa1a5582 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tile.cc
@@ -105,13 +105,6 @@ bool ResolveConstantTile::Run(Model* model, std::size_t op_index) {
}
const auto* op = static_cast<const TensorFlowTileOperator*>(base_op);
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
CHECK_GE(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
auto& output_array = model->GetArray(op->outputs[0]);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
index 5d3f4a6240..fe15dfa06f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_transpose.cc
@@ -111,14 +111,6 @@ bool ResolveConstantTranspose::Run(Model* model, std::size_t op_index) {
CHECK_EQ(op->inputs.size(), 2);
CHECK_EQ(op->outputs.size(), 1);
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, op->outputs[0])) {
- return false;
- }
-
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index e35ed0898b..5364eebbc9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -27,6 +27,73 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
namespace toco {
+namespace {
+
+// Using the function reducer, reduce input along all axes in axes.
+// Put the reduced data in output, which should aleady be appropriately sized.
+// check_output_shape is set to what this code computes the final shape
+// to be, so it can be cross checked with the shape computation logic.
+void ReduceGeneric(bool keep_dims, const std::vector<int>& axes,
+ const Shape& input_shape, const std::vector<float>& input,
+ Shape* check_output_shape, std::vector<float>* output,
+ const std::function<float(float, float)>& reducer) {
+ if (!IsNonEmpty(input_shape)) {
+ // Zero-dimensions will break the NextIndices() logic, so just early out if
+ // we have an empty shape.
+ return;
+ }
+
+ // Set up output_shape to be the same length as input_shape, with
+ // appropriate dimensions squashed to 1. If keep_dims is false, we'll strip
+ // out the one dimensions at the end, but it's convenient to leave them for
+ // now. We recompute the shape because we need the output shape to have
+ // 1-dims in all the squashed dimensions; the shape from shape computation may
+ // remove those squashed dimensions, depending on the options used.
+ Shape output_shape = input_shape;
+
+ // Reduction mask will be elementwise multiplied against the input
+ // indices to figure out the output index for the element.
+ std::vector<int> reduction_mask(input_shape.dimensions_count(), 1);
+ for (int axis : axes) {
+ CHECK_GE(axis, 0);
+ CHECK_LT(axis, input_shape.dimensions_count());
+ reduction_mask[axis] = 0;
+ output_shape.mutable_dims()->at(axis) = 1;
+ }
+
+ std::vector<int> output_indices(input_shape.dimensions_count());
+ for (int input_offset = 0; input_offset < input.size(); ++input_offset) {
+ std::vector<int> input_indices = ReverseOffset(input_shape, input_offset);
+ // Calculate the output location by squashing input indices to 0
+ // in reduced axes.
+ for (int i = 0; i < input_shape.dimensions_count(); ++i) {
+ output_indices[i] = input_indices[i] * reduction_mask[i];
+ }
+ int output_offset = Offset(output_shape, output_indices);
+ if (input_indices == output_indices) {
+ // Base element for the reduced axes
+ output->at(output_offset) = input.at(input_offset);
+ } else {
+ // Reduce with existing element.
+ output->at(output_offset) =
+ reducer(output->at(output_offset), input.at(input_offset));
+ }
+ }
+
+ if (!keep_dims) {
+ // Strip out the dims from output_shape.
+ std::vector<int> new_dims;
+ for (int i = 0; i < output_shape.dimensions_count(); ++i) {
+ if (reduction_mask[i]) {
+ new_dims.push_back(output_shape.dims(i));
+ }
+ }
+ output_shape.mutable_dims()->swap(new_dims);
+ }
+ *check_output_shape = output_shape;
+}
+
+} // namespace
bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
auto& output_array = model->GetArray(op.outputs[0]);
@@ -48,14 +115,6 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
const auto unary_it = model->operators.begin() + op_index;
const auto* unary_op = unary_it->get();
-
- // If the output of this op is a non-discardable array such as an input_array
- // or a state array of the model, then this is a job for RemoveUnusedOp, not
- // for constants-propagation.
- if (!IsDiscardableArray(*model, unary_op->outputs[0])) {
- return false;
- }
-
// Test for unary ops of types that we know how to resolve.
switch (unary_op->type) {
case OperatorType::kCast:
@@ -184,27 +243,19 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
- int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
- CHECK_LT(axis, input_shape.dimensions_count()) << "Axis out of bounds";
- // We currently only handle reduction on axis 0.
- CHECK_EQ(axis, 0) << "Only reduction along axis 0 is supported";
- // We currently only handle 1-D and 2-D input tensors.
- CHECK_LE(input_shape.dimensions_count(), 2) << "Rank >2 not yet supported";
// We only support keep_dims=true; shape prop will need to change otherwise.
auto sum_op = static_cast<const TensorFlowSumOperator*>(unary_op);
- CHECK(sum_op->keep_dims) << "Only keep_dims=true is supported";
+ Shape check_output_shape;
- std::vector<int> indices(input_shape.dimensions_count());
- for (int i = 0; i < input_shape.dims(1); ++i) {
- indices[1] = i;
- float sum = 0.f;
- for (int j = 0; j < input_shape.dims(0); ++j) {
- indices[0] = j;
- sum += (*input_float_data)[Offset(input_shape, indices)];
- }
- output_float_data[i] = sum;
- }
+ ReduceGeneric(
+ sum_op->keep_dims, axis_array.GetBuffer<ArrayDataType::kInt32>().data,
+ input_shape, *input_float_data, &check_output_shape, &output_float_data,
+ [](float existing, float current) -> float {
+ return existing + current;
+ });
+ CHECK(check_output_shape == output_shape)
+ << "Shape propagation output shape doesn't match output shape from op";
} else if (unary_op->type == OperatorType::kReduceMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
index acf1e3ede5..6f1be298ca 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -30,3 +30,16 @@ tf_cc_test(
"@com_google_googletest//:gtest_main",
],
)
+
+tf_cc_test(
+ name = "resolve_constant_unary_test",
+ srcs = ["resolve_constant_unary_test.cc"],
+ tags = ["no_oss"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
new file mode 100644
index 0000000000..a53abc9941
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_unary_test.cc
@@ -0,0 +1,140 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <tuple>
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+void RunResolveSum(const std::vector<float>& input,
+ const std::vector<int>& input_shape,
+ const std::vector<int>& axis,
+ const std::vector<int>& output_shape,
+ const std::vector<float>& expected_output) {
+ Model model;
+ Array& input0 = model.GetOrCreateArray("input0");
+ Array& input1 = model.GetOrCreateArray("input1");
+ Array& output = model.GetOrCreateArray("output");
+
+ *input0.mutable_shape()->mutable_dims() = input_shape;
+ input0.data_type = ArrayDataType::kFloat;
+ input0.GetMutableBuffer<ArrayDataType::kFloat>().data = input;
+
+ *input1.mutable_shape()->mutable_dims() = {static_cast<int>(axis.size())};
+ input1.GetMutableBuffer<ArrayDataType::kInt32>().data = axis;
+ input1.data_type = ArrayDataType::kInt32;
+
+ *output.mutable_shape()->mutable_dims() = output_shape;
+
+ auto sum_op = absl::make_unique<TensorFlowSumOperator>();
+ sum_op->keep_dims = true;
+ sum_op->inputs = {"input0", "input1"};
+ sum_op->outputs = {"output"};
+ model.operators.push_back(std::move(sum_op));
+ ResolveConstantUnaryOperator().Run(&model, 0);
+ EXPECT_EQ(model.GetArray("output").GetBuffer<ArrayDataType::kFloat>().data,
+ expected_output);
+ EXPECT_EQ(model.GetArray("output").shape().dims(), output_shape);
+}
+
+// Reduce a 2d array across axis 0
+TEST(ResolveConstantUnary, ResolveSumAxis0_2D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ {3, 1, 4, 1,
+ 5, 9, 2, 6,
+ 5, 3, 5, 8},
+
+ // Input shape
+ {3, 4},
+
+ // Axes
+ {0},
+
+ // Expected output shape,
+ {1, 4},
+
+ // Expected output
+ {13, 13, 11, 15});
+ // clang-format on
+}
+
+// Reduce a 2d array across axis 1
+TEST(ResolveConstantUnary, ResolveSumAxis1_2D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ {3, 1, 4, 1,
+ 5, 9, 2, 6,
+ 5, 3, 5, 8},
+
+ // Input shape
+ {3, 4},
+
+ // Axes
+ {1},
+
+ // Expected output shape,
+ {3, 1},
+
+ // Expected output
+ {9, 22, 21});
+ // clang-format on
+}
+
+// Reduce a 3d tensor across axes 0 and 2.
+TEST(ResolveConstantUnary, ResolveSumAxis0_2_3D) {
+ // clang-format off
+ RunResolveSum(
+ // Input data
+ { 0, 1, 2,
+ 3, 10, 11,
+ 12, 13, 20,
+ 21, 22, 23,
+
+ 100, 101, 102,
+ 103, 110, 111,
+ 112, 113, 120,
+ 121, 122, 123,
+
+ 200, 201, 202,
+ 203, 210, 211,
+ 212, 213, 220,
+ 221, 222, 223 },
+
+ // Input shape
+ {3, 4, 3},
+
+ // Axes
+ {0, 2},
+
+ // Expected output shape,
+ {1, 4, 1},
+
+ // Expected output, generated using octave.
+ { 909, 972, 1035, 1098});
+ // clang-format on
+}
+
+} // namespace
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 5eaf6e27fc..133ef79a34 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -477,6 +477,30 @@ string CreateConstArray(Model* model, string const& name,
return array_name;
}
+// Retain TensorFlow NodeDef in Toco Operator.
+//
+// If an op is supported by Toco but not supported by TFLite, TFLite exporter
+// will use the retained NodeDef to populate a Flex op when Flex mode is
+// enabled.
+//
+// This can't be easily applied to all operations, because a TensorFlow node
+// may become multiple Toco operators. Thus we need to call this function in
+// operator conversion functions one by one whenever feasible.
+//
+// This may cause problems if a graph transformation rule changes parameters
+// of the node. When calling this function, please check if any existing
+// graph transformation rule will change an existing operator with the same
+// type.
+//
+// This provides a route to handle Toco-supported & TFLite-unsupported ops
+// in Flex mode. However it's not a solid solution. Eventually we should
+// get rid of this.
+// TODO(b/117327937): Implement all Toco-supported ops in TFLite, and remove
+// this function.
+void RetainTensorFlowNodeDef(const NodeDef& node, Operator* op) {
+ node.SerializeToString(&op->tensorflow_node_def);
+}
+
tensorflow::Status ConvertConstOperator(
const NodeDef& node, const TensorFlowImportFlags& tf_import_flags,
Model* model) {
@@ -990,6 +1014,10 @@ tensorflow::Status ConvertBatchMatMulOperator(
auto* batch_matmul = new BatchMatMulOperator;
batch_matmul->inputs = {node.input(0), node.input(1)};
batch_matmul->outputs = {node.name()};
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, batch_matmul);
+
model->operators.emplace_back(batch_matmul);
return tensorflow::Status::OK();
}
@@ -1081,7 +1109,10 @@ tensorflow::Status ConvertUnsupportedOperator(
auto* op = new TensorFlowUnsupportedOperator;
op->tensorflow_op = node.op();
- node.SerializeToString(&op->tensorflow_node_def);
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
// Parse inputs.
@@ -1605,6 +1636,10 @@ tensorflow::Status ConvertRangeOperator(
op->inputs.push_back(node.input(1));
op->inputs.push_back(node.input(2));
op->outputs.push_back(node.name());
+
+ // For Flex mode. Please read the comments of the function.
+ RetainTensorFlowNodeDef(node, op);
+
model->operators.emplace_back(op);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h
index 6e207fdf54..61f1f095e9 100644
--- a/tensorflow/contrib/lite/toco/model.h
+++ b/tensorflow/contrib/lite/toco/model.h
@@ -376,6 +376,13 @@ struct Operator {
// looks unused.
bool unresolved_outputs = false;
+ // A serialized tensorflow::NodeDef string.
+ // The field is filled only when importing from TensorFlow.
+ // It's guaranteed to be filled for `TensorFlowUnsupportedOperator`.
+ // It's not guaranteed to be filled for other ops. Ops created by graph
+ // transformations won't have TensorFlow NodeDef.
+ string tensorflow_node_def;
+
protected:
// Constructor used by subclasses for specific OperatorType's.
explicit Operator(OperatorType t)
@@ -1535,8 +1542,6 @@ struct TensorFlowUnsupportedOperator : Operator {
// The original TF operation type. Used for diagnostic purposes.
string tensorflow_op;
- // A serialized tensorflow::NodeDef string.
- string tensorflow_node_def;
// A boolean indicating if the unsupported op should be treated as quantized.
bool quantized = false;
// A boolean indicating if the unsupported op output should allow float values
diff --git a/tensorflow/contrib/lite/toco/tflite/export.cc b/tensorflow/contrib/lite/toco/tflite/export.cc
index 45ca7f7f0c..3b34cd6285 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export.cc
@@ -63,21 +63,21 @@ bool IsControlFlowOp(const string& tensorflow_op) {
return false;
}
-details::OperatorKey GetOperatorKey(
- const ::toco::Operator& op,
- const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
- bool allow_flex_ops) {
- string custom_code;
- if (op.type == OperatorType::kUnsupported) {
- const TensorFlowUnsupportedOperator& unsupported_op =
- static_cast<const TensorFlowUnsupportedOperator&>(op);
- custom_code = unsupported_op.tensorflow_op;
- }
- int version = 1;
- if (ops_by_type.count(op.type) != 0) {
- version = ops_by_type.at(op.type)->GetVersion(op);
+// Map from operator name to TF Lite enum value, for all builtins.
+const std::map<string, BuiltinOperator>& GetBuiltinOpsMap() {
+ static std::map<string, BuiltinOperator>* builtin_ops = nullptr;
+ if (builtin_ops == nullptr) {
+ builtin_ops = new std::map<string, BuiltinOperator>();
+
+ for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
+ BuiltinOperator op = static_cast<BuiltinOperator>(i);
+ string name = EnumNameBuiltinOperator(op);
+ if (op != BuiltinOperator_CUSTOM && !name.empty()) {
+ (*builtin_ops)[name] = op;
+ }
+ }
}
- return details::OperatorKey(op.type, custom_code, version, allow_flex_ops);
+ return *builtin_ops;
}
void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
@@ -91,27 +91,70 @@ void WriteModelToString(const flatbuffers::FlatBufferBuilder& builder,
namespace details {
-OperatorKey::OperatorKey(OperatorType type, const std::string& custom_code,
- int version, bool allow_flex_ops) {
- this->type = type;
- this->custom_code = custom_code;
- this->version = version;
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops) {
+ // Get the op name (by Toco definition).
+ string name = HelpfulOperatorTypeName(op);
+
+ bool is_builtin = false;
+ OperatorKey key;
+
+ const auto& builtin_ops = GetBuiltinOpsMap();
+ if (ops_by_type.count(op.type) != 0) {
+ key.version = ops_by_type.at(op.type)->GetVersion(op);
+ name = ops_by_type.at(op.type)->name();
+ is_builtin = (builtin_ops.count(name) > 0);
+ }
+
+ if (is_builtin) {
+ // For TFLite supported builtin ops, find out its BuiltinOperator enum used
+ // in FlatBuffer.
+ key.type = builtin_ops.at(name);
+ return key;
+ }
+
+ // The logic below is all for custom ops.
+ key.is_custom_op = true;
+ key.type = BuiltinOperator_CUSTOM;
+
+ if (op.type == OperatorType::kUnsupported) {
+ const TensorFlowUnsupportedOperator& unsupported_op =
+ static_cast<const TensorFlowUnsupportedOperator&>(op);
+ const auto tensorflow_op = unsupported_op.tensorflow_op;
- if (type == OperatorType::kUnsupported) {
// TODO(b/113715895): When `allow_flex_ops` is on, for now there's no way
// to populate a regular custom op. We need to find a way to fix this.
if (allow_flex_ops) {
- // Memorize the original TensorFlow op name.
- this->flex_tensorflow_op = custom_code;
- // Prefix the custom code of the flex op.
- this->custom_code = string(::tflite::kFlexCustomCodePrefix) + custom_code;
- this->is_flex_op = true;
-
- if (IsControlFlowOp(this->flex_tensorflow_op)) {
- is_unsupported_flex_op = true;
- }
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = tensorflow_op;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
+ } else {
+ key.custom_code = tensorflow_op;
+ }
+ } else if (allow_flex_ops && !op.tensorflow_node_def.empty()) {
+ // For Toco-supported/TFLite-unsupported ops, if the TensorFlow NodeDef
+ // is retained in the Toco Operator, we produce a Flex op if Flex mode
+ // is enabled.
+ key.is_flex_op = true;
+ key.flex_tensorflow_op = name;
+ key.custom_code =
+ string(::tflite::kFlexCustomCodePrefix) + key.flex_tensorflow_op;
+ } else {
+ // If Flex is disabled or the original TensorFlow NodeDef isn't available,
+ // we produce a custom op. This gives developers a chance to implemenr
+ // custom ops.
+ key.custom_code = name;
+ }
+
+ if (key.is_flex_op) {
+ if (IsControlFlowOp(key.flex_tensorflow_op)) {
+ key.is_unsupported_flex_op = true;
}
}
+ return key;
}
void LoadTensorsMap(const Model& model, TensorsMap* tensors_map) {
@@ -145,6 +188,7 @@ void LoadOperatorsMap(
++index;
}
}
+
} // namespace details
Offset<Vector<Offset<Tensor>>> ExportTensors(
@@ -230,7 +274,7 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
const Model& model,
const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
const details::OperatorsMap& operators_map, FlatBufferBuilder* builder,
- std::set<string>* unsupported_ops, const ExportParams& params) {
+ const ExportParams& params) {
// Map from operator name to TF Lite enum value, for all builtins.
std::map<string, BuiltinOperator> builtin_ops;
for (int i = BuiltinOperator_MIN; i <= BuiltinOperator_MAX; ++i) {
@@ -247,37 +291,16 @@ Offset<Vector<Offset<OperatorCode>>> ExportOperatorCodes(
for (const auto& op : model.operators) {
const details::OperatorKey operator_key =
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
int op_index = operators_map.at(operator_key);
- int op_version = operator_key.version;
- string name = HelpfulOperatorTypeName(*op);
- bool is_builtin = false;
- if (ops_by_type.count(op->type) != 0) {
- name = ops_by_type.at(op->type)->name();
- is_builtin = (builtin_ops.count(name) > 0);
+ flatbuffers::Offset<flatbuffers::String> custom_code = 0;
+ if (!operator_key.custom_code.empty()) {
+ custom_code = builder->CreateString(operator_key.custom_code);
}
- if (is_builtin) {
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, builtin_ops[name], 0, op_version);
- } else {
- // This could be a kUnsupported, in which case we should be
- // able to retrieve the original Tensorflow name from the OperatorKey, or
- // this could be a proper TOCO operator that is completely unknown to TF
- // Lite.
- if (!operator_key.custom_code.empty()) {
- name = operator_key.custom_code;
- }
- // Either way, this is an operator that is not supported by TF Lite,
- // so we output it as a custom op and add it to the error summary.
- if (unsupported_ops) {
- unsupported_ops->insert(name);
- }
- ordered_opcodes[op_index] =
- CreateOperatorCode(*builder, BuiltinOperator_CUSTOM,
- builder->CreateString(name), op_version);
- }
+ ordered_opcodes[op_index] = CreateOperatorCode(
+ *builder, operator_key.type, custom_code, operator_key.version);
}
std::vector<Offset<OperatorCode>> opcode_vector;
@@ -311,8 +334,9 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
outputs.push_back(tensors_map.at(output));
}
- int op_index = operators_map.at(
- GetOperatorKey(*op, ops_by_type, params.allow_flex_ops));
+ const auto key =
+ details::GetOperatorKey(*op, ops_by_type, params.allow_flex_ops);
+ int op_index = operators_map.at(key);
auto tflite_op_it = ops_by_type.find(op->type);
BaseOperator* tflite_op = tflite_op_it == ops_by_type.end()
@@ -337,6 +361,11 @@ Offset<Vector<Offset<Operator>>> ExportOperators(
variable_tensor_indices->insert(variable_tensor_index);
}
}
+ } else if (key.is_flex_op && !op->tensorflow_node_def.empty()) {
+ auto fbb = WriteFlexOpOptions(op->tensorflow_node_def);
+ if (fbb) {
+ options = Options::Custom(builder->CreateVector(fbb->GetBuffer()));
+ }
}
// The only supported CustomOptionFormat is FLEXBUFFERS now.
op_vector.push_back(CreateOperator(
@@ -386,9 +415,8 @@ void Export(
Array empty_array;
buffers_to_write.push_back(&empty_array);
- std::set<string> unsupported_ops;
- auto op_codes = ExportOperatorCodes(model, ops_by_type, operators_map,
- &builder, &unsupported_ops, params);
+ auto op_codes =
+ ExportOperatorCodes(model, ops_by_type, operators_map, &builder, params);
for (const auto& op : model.operators) {
if (op->type == OperatorType::kFakeQuant) {
@@ -398,7 +426,20 @@ void Export(
"for --std_values and --mean_values.";
}
}
- if (!unsupported_ops.empty()) {
+
+ std::set<string> custom_ops;
+ std::set<string> unsupported_flex_ops;
+ for (const auto& it : operators_map) {
+ const details::OperatorKey& key = it.first;
+ if (key.is_custom_op) {
+ custom_ops.insert(key.custom_code);
+ }
+ if (key.is_unsupported_flex_op) {
+ unsupported_flex_ops.insert(key.flex_tensorflow_op);
+ }
+ }
+
+ if (!custom_ops.empty()) {
if (!params.allow_custom_ops) {
// Remove ExpandDims and ReorderAxes from unimplemented list unless they
// compose the list. Both ops are removed during graph transformations.
@@ -406,14 +447,14 @@ void Export(
// transformation is unable to run because the output shape is not
// defined. This causes unnecessary confusion during model conversion
// time.
- std::set<string> unsupported_ops_final;
- for (const auto& op_type : unsupported_ops) {
+ std::set<string> custom_ops_final;
+ for (const auto& op_type : custom_ops) {
if (op_type != "ReorderAxes" && op_type != "ExpandDims") {
- unsupported_ops_final.insert(op_type);
+ custom_ops_final.insert(op_type);
}
}
- if (unsupported_ops_final.empty()) {
- unsupported_ops_final = unsupported_ops;
+ if (custom_ops_final.empty()) {
+ custom_ops_final = custom_ops;
}
LOG(QFATAL)
@@ -423,13 +464,13 @@ void Export(
"--allow_custom_ops, or by setting allow_custom_ops=True "
"when calling tf.contrib.lite.TFLiteConverter(). Here is a list "
"of operators for which you will need custom implementations: "
- << absl::StrJoin(unsupported_ops_final, ", ") << ".";
+ << absl::StrJoin(custom_ops_final, ", ") << ".";
}
std::set<string> unsupported_control_flow_ops;
// Check if unsupported ops contains control flow ops. It's impossible
// to implement these ops as custom ops at the moment.
- for (const auto& op : unsupported_ops) {
+ for (const auto& op : custom_ops) {
if (IsControlFlowOp(op)) {
unsupported_control_flow_ops.insert(op);
}
@@ -441,14 +482,6 @@ void Export(
}
}
- std::set<string> unsupported_flex_ops;
- for (const auto& it : operators_map) {
- const details::OperatorKey& key = it.first;
- if (key.is_unsupported_flex_op) {
- unsupported_flex_ops.insert(key.custom_code);
- }
- }
-
if (!unsupported_flex_ops.empty()) {
LOG(QFATAL) << "Some of the operators in the model are not supported by "
"TensorFlow Flex runtime: "
diff --git a/tensorflow/contrib/lite/toco/tflite/export.h b/tensorflow/contrib/lite/toco/tflite/export.h
index 9efb282c6c..c627f48086 100644
--- a/tensorflow/contrib/lite/toco/tflite/export.h
+++ b/tensorflow/contrib/lite/toco/tflite/export.h
@@ -81,16 +81,20 @@ using TensorsMap = std::unordered_map<string, int>;
// Only when `type` is `kUnsupported`, `custom_code` is filled to
// identify which operation is used.
struct OperatorKey {
- OperatorKey(OperatorType type, const std::string& custom_code, int version,
- bool allow_flex_ops = false);
+ OperatorKey() {}
+ OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
+ int version)
+ : type(type), custom_code(custom_code), version(version) {}
// Only `type`, `custom_code` and `version` is used to compute hash and
// identity.
- OperatorType type;
+ ::tflite::BuiltinOperator type = ::tflite::BuiltinOperator_CUSTOM;
std::string custom_code;
- int version;
+ int version = 1;
- // THe fields below are not used to compute hash and identity.
+ // The fields below are not used to compute hash and identity.
+ // TODO(ycling): Consider to change these fields to accessor functions.
+ bool is_custom_op = false;
bool is_flex_op = false;
bool is_unsupported_flex_op = false;
// The original TensorFlow op name for the flex op. Filled only when
@@ -124,6 +128,11 @@ struct OperatorKey {
};
};
+OperatorKey GetOperatorKey(
+ const ::toco::Operator& op,
+ const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
+ bool allow_flex_ops);
+
// A maps from operator type to its final position in the TF Lite buffer.
using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
diff --git a/tensorflow/contrib/lite/toco/tflite/export_test.cc b/tensorflow/contrib/lite/toco/tflite/export_test.cc
index a71a64d56f..eda1aa78a3 100644
--- a/tensorflow/contrib/lite/toco/tflite/export_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/export_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/core/framework/node_def.pb.h"
namespace toco {
namespace tflite {
@@ -105,13 +106,15 @@ TEST_F(ExportTest, LoadOperatorsMap) {
details::OperatorsMap operators;
const auto ops_by_type = BuildOperatorByTypeMap();
- // TODO(ycling): Add a test for allow_flex_ops.
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
- EXPECT_EQ(0, operators[details::OperatorKey(OperatorType::kAdd, "", 1)]);
- EXPECT_EQ(1, operators[details::OperatorKey(OperatorType::kConv, "", 1)]);
- EXPECT_EQ(2, operators[details::OperatorKey(OperatorType::kSub, "", 1)]);
- EXPECT_EQ(3, operators[details::OperatorKey(OperatorType::kUnsupported,
+ EXPECT_EQ(
+ 0, operators[details::OperatorKey(::tflite::BuiltinOperator_ADD, "", 1)]);
+ EXPECT_EQ(1, operators[details::OperatorKey(::tflite::BuiltinOperator_CONV_2D,
+ "", 1)]);
+ EXPECT_EQ(2, operators[details::OperatorKey(::tflite::BuiltinOperator_CUSTOM,
"MyCrazyOp", 1)]);
+ EXPECT_EQ(
+ 3, operators[details::OperatorKey(::tflite::BuiltinOperator_SUB, "", 1)]);
}
TEST_F(ExportTest, Export) {
@@ -133,7 +136,7 @@ TEST_F(ExportTest, Export) {
}
EXPECT_THAT(names, ElementsAre("builtin:ADD", "builtin:CONV_2D",
- "builtin:SUB", "custom:MyCrazyOp"));
+ "custom:MyCrazyOp", "builtin:SUB"));
std::vector<uint32_t> indices;
auto operators = (*model->subgraphs())[0]->operators();
@@ -142,7 +145,7 @@ TEST_F(ExportTest, Export) {
indices.push_back(op->opcode_index());
}
- EXPECT_THAT(indices, ElementsAre(1, 0, 3, 2));
+ EXPECT_THAT(indices, ElementsAre(1, 0, 2, 3));
}
TEST_F(ExportTest, QuantizeWeights) {
@@ -257,7 +260,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV1) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 1)));
}
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
@@ -268,7 +272,8 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithOpV2) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(1, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}
TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
@@ -280,8 +285,10 @@ TEST_F(VersionedOpExportTest, LoadOperatorsMapWithBothVersions) {
details::LoadOperatorsMap(input_model_, &operators, ops_by_type, false);
EXPECT_EQ(2, operators.size());
- EXPECT_EQ(0, operators.at(details::OperatorKey(OperatorType::kConv, "", 1)));
- EXPECT_EQ(1, operators.at(details::OperatorKey(OperatorType::kConv, "", 2)));
+ EXPECT_EQ(0, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 1)));
+ EXPECT_EQ(1, operators.at(details::OperatorKey(
+ ::tflite::BuiltinOperator_CONV_2D, "", 2)));
}
TEST_F(VersionedOpExportTest, Export) {
@@ -314,38 +321,61 @@ TEST_F(VersionedOpExportTest, Export) {
}
TEST(OperatorKeyTest, TestBuiltinOp) {
- details::OperatorKey key(OperatorType::kConv, "", 2);
- EXPECT_EQ(key.type, OperatorType::kConv);
+ auto op = absl::make_unique<ConvOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CONV_2D);
EXPECT_EQ(key.custom_code, "");
- EXPECT_EQ(key.version, 2);
+ EXPECT_EQ(key.version, 1);
+}
+
+TEST(OperatorKeyTest, TestCustomOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "MyCrazyCustomOp";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "MyCrazyCustomOp");
+ EXPECT_EQ(key.version, 1);
}
TEST(OperatorKeyTest, TestFlexOp) {
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "BatchMatMul";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
{
- details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1,
- false);
- EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ const auto key = details::GetOperatorKey(*op, ops_by_type, false);
// It shouldn't be converted to Flex op if `allow_flex_op` is false.
- EXPECT_EQ(key.custom_code, "SomeUnsupportedOp");
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "BatchMatMul");
EXPECT_EQ(key.version, 1);
EXPECT_FALSE(key.is_flex_op);
}
{
- details::OperatorKey key(OperatorType::kUnsupported, "SomeUnsupportedOp", 1,
- true);
- EXPECT_EQ(key.type, OperatorType::kUnsupported);
// Verify that the custom op name is prefixed by "Flex" and `is_flex_op`
// is true.
- EXPECT_EQ(key.custom_code, "FlexSomeUnsupportedOp");
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexBatchMatMul");
EXPECT_EQ(key.version, 1);
EXPECT_TRUE(key.is_flex_op);
}
}
TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
- details::OperatorKey key(OperatorType::kUnsupported, "Merge", 1, true);
- EXPECT_EQ(key.type, OperatorType::kUnsupported);
+ auto op = absl::make_unique<TensorFlowUnsupportedOperator>();
+ op->tensorflow_op = "Merge";
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
EXPECT_EQ(key.custom_code, "FlexMerge");
EXPECT_EQ(key.version, 1);
EXPECT_TRUE(key.is_flex_op);
@@ -353,6 +383,39 @@ TEST(OperatorKeyTest, TestFlexWithControlFlowOp) {
EXPECT_TRUE(key.is_unsupported_flex_op);
}
+TEST(OperatorKeyTest, TestFlexWithPartiallySupportedOps) {
+ // Test Toco-supported/TFLite-unsupported operators.
+ // TODO(ycling): The test will be broken if Range is implemented in TFLite.
+ // Find a more robust way to test the fallback logic.
+ auto op = absl::make_unique<RangeOperator>();
+
+ const auto ops_by_type = BuildOperatorByTypeMap();
+
+ {
+ // If NodeDef isn't retained in the Toco op, a regular custom op
+ // will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "Range");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_FALSE(key.is_flex_op);
+ }
+
+ ::tensorflow::NodeDef node_def;
+ node_def.set_name("Range");
+ node_def.set_op("Range");
+ node_def.SerializeToString(&op->tensorflow_node_def);
+
+ {
+ // If NodeDef is retained in the Toco op, a Flex op will be exported.
+ const auto key = details::GetOperatorKey(*op, ops_by_type, true);
+ EXPECT_EQ(key.type, ::tflite::BuiltinOperator_CUSTOM);
+ EXPECT_EQ(key.custom_code, "FlexRange");
+ EXPECT_EQ(key.version, 1);
+ EXPECT_TRUE(key.is_flex_op);
+ }
+}
+
// TODO(ahentz): tests for tensors, inputs, outputs, opcodes and operators.
} // namespace
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 9addbb81e7..ed37535fe0 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -1157,6 +1157,25 @@ class Unpack : public BuiltinOperator<UnpackOperator, ::tflite::UnpackOptions,
int GetVersion(const Operator& op) const override { return 1; }
};
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def) {
+ auto fbb = absl::make_unique<flexbuffers::Builder>();
+
+ ::tensorflow::NodeDef node_def;
+ if (!node_def.ParseFromString(tensorflow_node_def)) {
+ LOG(ERROR) << "Failed to parse TensorFlow NodeDef";
+ return {};
+ }
+
+ fbb->Vector([&]() {
+ fbb->String(node_def.op());
+ fbb->String(tensorflow_node_def);
+ });
+ fbb->Finish();
+ LOG(INFO) << "Writing flex op: " << node_def.op();
+ return std::unique_ptr<flexbuffers::Builder>(fbb.release());
+}
+
class TensorFlowUnsupported : public BaseOperator {
public:
TensorFlowUnsupported(const string& name, OperatorType type,
@@ -1192,6 +1211,9 @@ class TensorFlowUnsupported : public BaseOperator {
std::unique_ptr<flexbuffers::Builder> WriteOptions(
const TensorFlowUnsupportedOperator& op) const {
+ if (allow_flex_ops_) {
+ return WriteFlexOpOptions(op.tensorflow_node_def);
+ }
auto fbb = absl::make_unique<flexbuffers::Builder>();
::tensorflow::NodeDef node_def;
@@ -1200,16 +1222,6 @@ class TensorFlowUnsupported : public BaseOperator {
return std::unique_ptr<flexbuffers::Builder>();
}
- if (allow_flex_ops_) {
- fbb->Vector([&]() {
- fbb->String(node_def.op());
- fbb->String(op.tensorflow_node_def);
- });
- fbb->Finish();
- LOG(INFO) << "Writing flex op: " << node_def.op();
- return std::unique_ptr<flexbuffers::Builder>(fbb.release());
- }
-
bool has_valid_attr = false;
size_t map_start = fbb->StartMap();
for (const auto& pair : node_def.attr()) {
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.h b/tensorflow/contrib/lite/toco/tflite/operator.h
index 13d9f6c49a..6e4e0a16d1 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.h
+++ b/tensorflow/contrib/lite/toco/tflite/operator.h
@@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_TOCO_TFLITE_OPERATOR_H_
#include "flatbuffers/flatbuffers.h"
+#include "flatbuffers/flexbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"
#include "tensorflow/contrib/lite/toco/model.h"
@@ -36,6 +37,11 @@ std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
bool allow_flex_ops = false);
+// Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
+// for a Flex op.
+std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
+ const string& tensorflow_node_def);
+
// These are the flatbuffer types for custom and builtin options.
using CustomOptions = flatbuffers::Vector<uint8_t>;
using BuiltinOptions = void;
diff --git a/tensorflow/contrib/lite/tools/benchmark/BUILD b/tensorflow/contrib/lite/tools/benchmark/BUILD
index 502e181139..71bf61657e 100644
--- a/tensorflow/contrib/lite/tools/benchmark/BUILD
+++ b/tensorflow/contrib/lite/tools/benchmark/BUILD
@@ -40,7 +40,7 @@ cc_binary(
srcs = [
"benchmark_main.cc",
],
- copts = common_copts + ["-DTFLITE_FLEX"],
+ copts = common_copts,
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie", # Android 5.0 and later supports only PIE
@@ -49,8 +49,9 @@ cc_binary(
"//conditions:default": [],
}),
deps = [
- ":benchmark_tflite_model_plus_flex_lib",
+ ":benchmark_tflite_model_lib",
":logging",
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
],
)
@@ -111,25 +112,6 @@ cc_library(
)
cc_library(
- name = "benchmark_tflite_model_plus_flex_lib",
- srcs = [
- "benchmark_tflite_model.cc",
- "logging.h",
- ],
- hdrs = ["benchmark_tflite_model.h"],
- copts = common_copts + ["-DTFLITE_FLEX"],
- deps = [
- ":benchmark_model_lib",
- ":logging",
- "//tensorflow/contrib/lite:framework",
- "//tensorflow/contrib/lite:string_util",
- "//tensorflow/contrib/lite/delegates/flex:delegate",
- "//tensorflow/contrib/lite/kernels:builtin_ops",
- "//tensorflow/contrib/lite/profiling:profile_summarizer",
- ],
-)
-
-cc_library(
name = "benchmark_params",
srcs = [
"benchmark_params.cc",
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
index 463d5993f4..2a3df7f289 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -23,9 +23,6 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#ifdef TFLITE_FLEX
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/op_resolver.h"
@@ -305,15 +302,6 @@ void BenchmarkTfLiteModel::Init() {
interpreter->UseNNAPI(use_nnapi);
-#ifdef TFLITE_FLEX
- TFLITE_LOG(INFO) << "Instantiating Flex Delegate";
- delegate_ = FlexDelegate::Create();
- if (delegate_) {
- interpreter->ModifyGraphWithDelegate(delegate_.get(),
- /*allow_dynamic_tensors=*/true);
- }
-#endif // TFLITE_FLEX
-
auto interpreter_inputs = interpreter->inputs();
if (!inputs.empty()) {
diff --git a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
index b091e18a29..25a302b2aa 100644
--- a/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/contrib/lite/tools/benchmark/benchmark_tflite_model.h
@@ -20,9 +20,6 @@ limitations under the License.
#include <string>
#include <vector>
-#ifdef TFLITE_FLEX
-#include "tensorflow/contrib/lite/delegates/flex/delegate.h"
-#endif // TFLITE_FLEX
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/profiling/profile_summarizer.h"
#include "tensorflow/contrib/lite/tools/benchmark/benchmark_model.h"
@@ -73,9 +70,6 @@ class BenchmarkTfLiteModel : public BenchmarkModel {
void PrepareInputsAndOutputs() override;
private:
-#ifdef TFLITE_FLEX
- std::unique_ptr<FlexDelegate> delegate_;
-#endif // TFLITE_FLEX
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
std::vector<InputLayerInfo> inputs;
diff --git a/tensorflow/contrib/opt/python/training/shampoo.py b/tensorflow/contrib/opt/python/training/shampoo.py
index f161521b97..e542f46892 100644
--- a/tensorflow/contrib/opt/python/training/shampoo.py
+++ b/tensorflow/contrib/opt/python/training/shampoo.py
@@ -108,7 +108,8 @@ class ShampooOptimizer(optimizer.Optimizer):
precond_update_interval: We should update the preconditioners after
this many steps. Default = 1. Usually less than
svd_interval.
- epsilon: epsilon * I_n is added to each mat_gbar_j for stability
+ epsilon: epsilon * I_n is added to each mat_gbar_j for stability for
+ non-diagonal version of shampoo.
alpha: total power of the preconditioners.
use_iterative_root: should the optimizer use SVD (faster) or the
iterative root method (for TPU) for finding the
@@ -394,15 +395,20 @@ class ShampooOptimizer(optimizer.Optimizer):
assert self._mat_gbar_decay == 1.0
mat_g_updated = state_ops.scatter_add(mat_g, indices,
mat_gbar_weight_t * grad_outer)
- mat_h = math_ops.pow(
- array_ops.gather(mat_g_updated, indices) + self._epsilon,
- neg_alpha)
+ mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
+ mat_h = array_ops.where(
+ math_ops.greater(mat_g_updated_slice, 0),
+ math_ops.pow(mat_g_updated_slice, neg_alpha),
+ array_ops.zeros_like(mat_g_updated_slice))
else:
mat_g_updated = self._weighted_average(mat_g,
self._mat_gbar_decay,
mat_gbar_decay_t,
mat_gbar_weight_t * grad_outer)
- mat_h = math_ops.pow(mat_g_updated + self._epsilon, neg_alpha)
+ mat_h = array_ops.where(
+ math_ops.greater(mat_g_updated, 0),
+ math_ops.pow(mat_g_updated, neg_alpha),
+ array_ops.zeros_like(mat_g_updated))
# Need to do the transpose to ensure that the tensor becomes
# a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index a2fd8fbd87..e88c8221a0 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -279,7 +279,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
# Update rule is var = var - lr * gg^{-0.5} * grad
# lr = 1
mat_g = (grad_np * grad_np)
- new_val_np = init_var_np - np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np
+ new_val_np = init_var_np - np.power(mat_g, -0.5) * grad_np
self.assertAllCloseAccordingToType(
new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
@@ -288,7 +288,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
new_val = sess.run(var)
mat_g += (grad_np_2 * grad_np_2)
- new_val_np -= np.power(mat_g + RIDGE_EPSILON, -0.5) * grad_np_2
+ new_val_np -= np.power(mat_g, -0.5) * grad_np_2
self.assertAllCloseAccordingToType(
new_val_np, new_val, atol=TOLERANCE, rtol=TOLERANCE)
@@ -339,7 +339,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 = np.sum(
grad_np * grad_np, axis=1, keepdims=True) / grad_np.shape[0]
- mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_left = np.power(mat_g1, -0.25)
mat_g2 = np.dot(grad_np.transpose(), grad_np) / grad_np.shape[1]
mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np = init_var_np - np.dot(grad_np * mat_left, mat_right)
@@ -353,7 +353,7 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g1 += np.sum(
grad_np_2 * grad_np_2, axis=1, keepdims=True) / grad_np_2.shape[0]
- mat_left = np.power(mat_g1 + RIDGE_EPSILON, -0.25)
+ mat_left = np.power(mat_g1, -0.25)
mat_g2 += np.dot(grad_np_2.transpose(), grad_np_2) / grad_np_2.shape[1]
mat_right = np_power(mat_g2 + RIDGE_EPSILON * np.eye(size[1]), -0.25)
new_val_np -= np.dot(grad_np_2 * mat_left, mat_right)
diff --git a/tensorflow/contrib/optimizer_v2/BUILD b/tensorflow/contrib/optimizer_v2/BUILD
index 3ba3ee29ec..2cf445a85e 100644
--- a/tensorflow/contrib/optimizer_v2/BUILD
+++ b/tensorflow/contrib/optimizer_v2/BUILD
@@ -47,15 +47,8 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:distribute",
- "//tensorflow/python:framework",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:state_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variable_scope",
- "//tensorflow/python:variables",
+ "//tensorflow/python:util",
+ "//tensorflow/python/keras:optimizer_v2",
],
)
diff --git a/tensorflow/contrib/optimizer_v2/adadelta.py b/tensorflow/contrib/optimizer_v2/adadelta.py
index b206f9f61b..9d73bddd1c 100644
--- a/tensorflow/contrib/optimizer_v2/adadelta.py
+++ b/tensorflow/contrib/optimizer_v2/adadelta.py
@@ -18,17 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.util import deprecation
-class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
+class AdadeltaOptimizer(adadelta.Adadelta):
"""Optimizer that implements the Adadelta algorithm.
See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, rho=0.95, epsilon=1e-8,
use_locking=False, name="Adadelta"):
"""Construct a new Adadelta optimizer.
@@ -48,66 +52,5 @@ class AdadeltaOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "Adadelta".
"""
- super(AdadeltaOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("rho", rho)
- self._set_hyper("epsilon", epsilon)
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "accum")
- state.zeros_slot(v, "accum_update")
-
- def _apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.sparse_apply_adadelta(
- var,
- accum,
- accum_update,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- accum = state.get_slot(var, "accum")
- accum_update = state.get_slot(var, "accum_update")
- return training_ops.resource_sparse_apply_adadelta(
- var.handle,
- accum.handle,
- accum_update.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("rho", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdadeltaOptimizer, self).__init__(
+ learning_rate=learning_rate, rho=rho, epsilon=epsilon, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index dab1e02716..716361e29c 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -18,15 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import init_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.util import deprecation
-class AdagradOptimizer(optimizer_v2.OptimizerV2):
+class AdagradOptimizer(adagrad.Adagrad):
"""Optimizer that implements the Adagrad algorithm.
See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
@@ -34,6 +30,10 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
[intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, initial_accumulator_value=0.1,
use_locking=False, name="Adagrad"):
"""Construct a new Adagrad optimizer.
@@ -54,64 +54,7 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
Raises:
ValueError: If the `initial_accumulator_value` is invalid.
"""
- if initial_accumulator_value <= 0.0:
- raise ValueError("initial_accumulator_value must be positive: %s" %
- initial_accumulator_value)
- super(AdagradOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- self._initial_accumulator_value = initial_accumulator_value
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- dtype = v.dtype.base_dtype
- if v.get_shape().is_fully_defined():
- init = init_ops.constant_initializer(self._initial_accumulator_value,
- dtype=dtype)
- else:
- def init(v=v, dtype=dtype):
- # Use a Tensor instead of initializer if variable does not have
- # static shape.
- init_constant = gen_array_ops.fill(array_ops.shape(v),
- self._initial_accumulator_value)
- return math_ops.cast(init_constant, dtype)
- state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
- "accumulator")
-
- def _apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _resource_apply_dense(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.sparse_apply_adagrad(
- var,
- acc,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- acc = state.get_slot(var, "accumulator")
- return training_ops.resource_sparse_apply_adagrad(
- var.handle,
- acc.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- use_locking=self._use_locking)
+ super(AdagradOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ initial_accumulator_value=initial_accumulator_value,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad_test.py b/tensorflow/contrib/optimizer_v2/adagrad_test.py
index debaaaeeba..320e41567f 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad_test.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad_test.py
@@ -68,9 +68,6 @@ class AdagradOptimizerTest(test.TestCase):
def testBasicResource(self):
self.doTestBasic(use_locking=False, use_resource=True)
- def testBasicLocked(self):
- self.doTestBasic(use_locking=True)
-
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
with self.cached_session():
diff --git a/tensorflow/contrib/optimizer_v2/adam.py b/tensorflow/contrib/optimizer_v2/adam.py
index 04b1552b61..363e020757 100644
--- a/tensorflow/contrib/optimizer_v2/adam.py
+++ b/tensorflow/contrib/optimizer_v2/adam.py
@@ -18,22 +18,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.util import deprecation
-class AdamOptimizer(optimizer_v2.OptimizerV2):
+class AdamOptimizer(adam.Adam):
"""Optimizer that implements the Adam algorithm.
See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
use_locking=False, name="Adam"):
"""Construct a new Adam optimizer.
@@ -87,111 +86,9 @@ class AdamOptimizer(optimizer_v2.OptimizerV2):
name: Optional name for the operations created when applying gradients.
Defaults to "Adam".
"""
- super(AdamOptimizer, self).__init__(use_locking, name)
-
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("beta1", beta1)
- self._set_hyper("beta2", beta2)
- self._set_hyper("epsilon", epsilon)
-
- def _get_beta_accumulators(self, state=None):
- if state is None:
- state = self._get_per_graph_state()
- return (state.get_non_slot("beta1_power"),
- state.get_non_slot("beta2_power"))
-
- def _create_vars(self, var_list, state):
- # Non-slot variables end up on the same device(s).
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta1"),
- name="beta1_power")
- state.create_non_slot(initial_value=lambda: state.get_hyper("beta2"),
- name="beta2_power")
-
- # Create slots for the first and second moments.
- for v in var_list:
- state.zeros_slot(v, "m")
- state.zeros_slot(v, "v")
-
- def _apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.apply_adam(
- var, m, v,
- math_ops.cast(beta1_power, var.dtype.base_dtype),
- math_ops.cast(beta2_power, var.dtype.base_dtype),
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("beta1", var.dtype.base_dtype),
- state.get_hyper("beta2", var.dtype.base_dtype),
- state.get_hyper("epsilon", var.dtype.base_dtype),
- grad, use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- m = state.get_slot(var, "m")
- v = state.get_slot(var, "v")
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- return training_ops.resource_apply_adam(
- var.handle, m.handle, v.handle,
- math_ops.cast(beta1_power, grad.dtype.base_dtype),
- math_ops.cast(beta2_power, grad.dtype.base_dtype),
- state.get_hyper("learning_rate", grad.dtype.base_dtype),
- state.get_hyper("beta1", grad.dtype.base_dtype),
- state.get_hyper("beta2", grad.dtype.base_dtype),
- state.get_hyper("epsilon", grad.dtype.base_dtype),
- grad, use_locking=self._use_locking)
-
- def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
- beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
- lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
- beta1_t = state.get_hyper("beta1", var.dtype.base_dtype)
- beta2_t = state.get_hyper("beta2", var.dtype.base_dtype)
- epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
- lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
- # m_t = beta1 * m + (1 - beta1) * g_t
- m = state.get_slot(var, "m")
- m_scaled_g_values = grad * (1 - beta1_t)
- m_t = state_ops.assign(m, m * beta1_t,
- use_locking=self._use_locking)
- with ops.control_dependencies([m_t]):
- m_t = scatter_add(m, indices, m_scaled_g_values)
- # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
- v = state.get_slot(var, "v")
- v_scaled_g_values = (grad * grad) * (1 - beta2_t)
- v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
- with ops.control_dependencies([v_t]):
- v_t = scatter_add(v, indices, v_scaled_g_values)
- v_sqrt = math_ops.sqrt(v_t)
- var_update = state_ops.assign_sub(var,
- lr * m_t / (v_sqrt + epsilon_t),
- use_locking=self._use_locking)
- return control_flow_ops.group(*[var_update, m_t, v_t])
-
- def _apply_sparse(self, grad, var, state):
- return self._apply_sparse_shared(
- grad.values, var, grad.indices,
- lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
- x, i, v, use_locking=self._use_locking),
- state)
-
- def _resource_scatter_add(self, x, i, v):
- with ops.control_dependencies(
- [resource_variable_ops.resource_scatter_add(
- x.handle, i, v)]):
- return x.value()
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- return self._apply_sparse_shared(
- grad, var, indices, self._resource_scatter_add, state)
-
- def _finish(self, state):
- # Update the power accumulators.
- beta1_power, beta2_power = self._get_beta_accumulators(state)
- update_beta1 = beta1_power.assign(
- beta1_power * state.get_hyper("beta1"),
- use_locking=self._use_locking)
- update_beta2 = beta2_power.assign(
- beta2_power * state.get_hyper("beta2"),
- use_locking=self._use_locking)
- return control_flow_ops.group(update_beta1, update_beta2)
+ super(AdamOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ beta_1=beta1,
+ beta_2=beta2,
+ epsilon=epsilon,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index e13b82d1d2..3c68ef995a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -130,8 +130,8 @@ class CheckpointingTests(test.TestCase):
# non-Layer dependency of the model
"model/_non_layer/a_variable",
# The optimizer creates two non-slot variables
- "optimizer/beta1_power",
- "optimizer/beta2_power",
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
# Slot variables
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
@@ -161,21 +161,20 @@ class CheckpointingTests(test.TestCase):
"my_model/dense/kernel",
named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power",
- named_variables["optimizer/beta1_power" + suffix].full_name)
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power",
- named_variables["optimizer/beta2_power" + suffix].full_name)
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
1].node_id]
- self.assertEqual("beta1_power",
- optimizer_node.children[0].local_name)
- self.assertEqual("beta1_power",
- serialized_graph.nodes[optimizer_node.children[0].node_id]
- .attributes[0].full_name)
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
self.assertEqual(
"my_model/dense/kernel",
serialized_graph.nodes[optimizer_node.slot_variables[0]
@@ -241,9 +240,10 @@ class CheckpointingTests(test.TestCase):
on_create_model = MyModel()
on_create_optimizer = adam.AdamOptimizer(
0.001,
- # Preserve beta1_power and beta2_power when appying gradients so we can
- # test that they've been restored correctly.
- beta1=1.0, beta2=1.0)
+ # Preserve beta_1_power and beta_2_power when appying gradients
+ # so we can test that they've been restored correctly.
+ beta1=1.0,
+ beta2=1.0)
on_create_root = util.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
@@ -263,9 +263,9 @@ class CheckpointingTests(test.TestCase):
dummy_var = resource_variable_ops.ResourceVariable([1.])
on_create_optimizer.minimize(loss=dummy_var.read_value)
status.assert_consumed()
- beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators()
- self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power))
- self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power))
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
# TODO(allenl): Debug garbage created by this test in python3.
def testDeferredRestorationUsageEager(self):
@@ -477,7 +477,7 @@ class CheckpointingTests(test.TestCase):
no_slot_status.run_restore_ops()
self.assertEqual(12., self.evaluate(new_root.var))
new_root.optimizer = adam.AdamOptimizer(0.1)
- with self.assertRaisesRegexp(AssertionError, "beta1_power"):
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
slot_status.assert_consumed()
self.assertEqual(12., self.evaluate(new_root.var))
if context.executing_eagerly():
@@ -556,8 +556,8 @@ class CheckpointingTests(test.TestCase):
self.evaluate(first_variable.assign([1.]))
self.evaluate(optimizer.get_slot(
var=first_variable, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
# Save and load in a second graph
second_graph = ops.Graph()
@@ -571,29 +571,29 @@ class CheckpointingTests(test.TestCase):
self.evaluate(second_variable.assign([4.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([5.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(6.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
save_path = second_root_checkpointable.save(checkpoint_prefix)
self.evaluate(second_variable.assign([7.]))
self.evaluate(optimizer.get_slot(
var=second_variable, name="m").assign([8.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
status = second_root_checkpointable.restore(save_path)
status.assert_consumed().run_restore_ops()
self.assertAllEqual([4.], self.evaluate(second_variable))
self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
var=second_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(6., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
# Check that the first graph is unmolested
with first_graph.as_default(), first_session.as_default():
self.assertAllEqual([1.], self.evaluate(first_variable))
self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
var=first_variable, name="m")))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
class TemplateTests(test.TestCase):
@@ -659,8 +659,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.evaluate(model._named_dense.bias.assign([1.]))
self.evaluate(optimizer.get_slot(
var=model._named_dense.bias, name="m").assign([2.]))
- beta1_power, _ = optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(3.))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
return root_checkpointable
def _set_sentinels(self, root_checkpointable):
@@ -669,8 +669,8 @@ class CheckpointCompatibilityTests(test.TestCase):
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")
.assign([102.]))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.evaluate(beta1_power.assign(103.))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
def _check_sentinels(self, root_checkpointable):
self.assertAllEqual(
@@ -678,8 +678,8 @@ class CheckpointCompatibilityTests(test.TestCase):
self.assertAllEqual([2.], self.evaluate(
root_checkpointable.optimizer.get_slot(
var=root_checkpointable.model._named_dense.bias, name="m")))
- beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
- self.assertAllEqual(3., self.evaluate(beta1_power))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
def _write_name_based_checkpoint(self):
checkpoint_directory = self.get_temp_dir()
diff --git a/tensorflow/contrib/optimizer_v2/gradient_descent.py b/tensorflow/contrib/optimizer_v2/gradient_descent.py
index 945c8de559..8bdf408217 100644
--- a/tensorflow/contrib/optimizer_v2/gradient_descent.py
+++ b/tensorflow/contrib/optimizer_v2/gradient_descent.py
@@ -18,15 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
+class GradientDescentOptimizer(sgd.SGD):
"""Optimizer that implements the gradient descent algorithm."""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, use_locking=False, name="GradientDescent"):
"""Construct a new gradient descent optimizer.
@@ -41,29 +43,5 @@ class GradientDescentOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "GradientDescent".
"""
- super(GradientDescentOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
-
- def _apply_dense(self, grad, var, state):
- return training_ops.apply_gradient_descent(
- var,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, handle, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return training_ops.resource_apply_gradient_descent(
- handle.handle, lr, grad, use_locking=self._use_locking)
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
- return resource_variable_ops.resource_scatter_add(
- handle.handle, indices, -grad * lr)
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- delta = ops.IndexedSlices(
- grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.indices, grad.dense_shape)
- return var.scatter_sub(delta, use_locking=self._use_locking)
+ super(GradientDescentOptimizer, self).__init__(
+ learning_rate=learning_rate, name=name)
diff --git a/tensorflow/contrib/optimizer_v2/momentum.py b/tensorflow/contrib/optimizer_v2/momentum.py
index 0a5aadc2d1..0636f7e356 100644
--- a/tensorflow/contrib/optimizer_v2/momentum.py
+++ b/tensorflow/contrib/optimizer_v2/momentum.py
@@ -18,11 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.training import training_ops
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.util import deprecation
-class MomentumOptimizer(optimizer_v2.OptimizerV2):
+class MomentumOptimizer(sgd.SGD):
"""Optimizer that implements the Momentum algorithm.
Computes (if `use_nesterov = False`):
@@ -39,6 +39,10 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
when that part of the variable was used in the forward pass.
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, learning_rate, momentum,
use_locking=False, name="Momentum", use_nesterov=False):
"""Construct a new Momentum optimizer.
@@ -68,57 +72,8 @@ class MomentumOptimizer(optimizer_v2.OptimizerV2):
optimizer functions.
@end_compatibility
"""
- super(MomentumOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("momentum", momentum)
- self._use_nesterov = use_nesterov
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_dense(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
-
- def _apply_sparse(self, grad, var, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.sparse_apply_momentum(
- var,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad.values,
- grad.indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov).op
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- mom = state.get_slot(var, "momentum")
- return training_ops.resource_sparse_apply_momentum(
- var.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- grad,
- indices,
- state.get_hyper("momentum", var.dtype.base_dtype),
- use_locking=self._use_locking,
- use_nesterov=self._use_nesterov)
+ super(MomentumOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ momentum=momentum,
+ name=name,
+ nesterov=use_nesterov)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 53e27c08c4..9c98dd93b4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -20,462 +20,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import abc
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.util import deprecation
-from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
-from tensorflow.python.training import distribute as distribute_lib
-from tensorflow.python.training import distribution_strategy_context
-from tensorflow.python.training import optimizer as optimizer_v1
-from tensorflow.python.training import slot_creator
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.util import nest
-
-class _OptimizableVariable(object):
- """Interface for abstracting over variables in the optimizers."""
-
- @abc.abstractmethod
- def target(self):
- """Returns the optimization target for this variable."""
- raise NotImplementedError("Calling an abstract method.")
-
- @abc.abstractmethod
- def update_op(self, optimizer, g, *args):
- """Returns the update ops for updating the variable."""
- raise NotImplementedError("Calling an abstract method.")
-
-
-class _RefVariableProcessor(_OptimizableVariable):
- """Processor for Variable."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v._ref() # pylint: disable=protected-access
-
- def update_op(self, optimizer, g, *args):
- if isinstance(g, ops.Tensor):
- update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
- else:
- assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
- "tensor nor IndexedSlices.")
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- # pylint: disable=protected-access
- return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
-
-
-class _DenseReadResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _DenseResourceVariableProcessor(_OptimizableVariable):
- """Processor for dense ResourceVariables."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- # pylint: disable=protected-access
- if isinstance(g, ops.IndexedSlices):
- if self._v.constraint is not None:
- raise RuntimeError(
- "Cannot use a constraint function on a sparse variable.")
- return optimizer._resource_apply_sparse_duplicate_indices(
- g.values, self._v, g.indices, *args)
- update_op = optimizer._resource_apply_dense(g, self._v, *args)
- if self._v.constraint is not None:
- with ops.control_dependencies([update_op]):
- return self._v.assign(self._v.constraint(self._v))
- else:
- return update_op
-
-
-class _TensorProcessor(_OptimizableVariable):
- """Processor for ordinary Tensors.
-
- Even though a Tensor can't really be updated, sometimes it is useful to
- compute the gradients with respect to a Tensor using the optimizer. Updating
- the Tensor is, of course, unsupported.
- """
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- raise NotImplementedError("Trying to update a Tensor ", self._v)
-
-
-def _get_processor(v):
- """The processor of v."""
- if context.executing_eagerly():
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- else:
- return _DenseResourceVariableProcessor(v)
- if v.op.type == "VarHandleOp":
- return _DenseResourceVariableProcessor(v)
- if isinstance(v, variables.Variable):
- return _RefVariableProcessor(v)
- if isinstance(v, ops.Tensor):
- return _TensorProcessor(v)
- raise NotImplementedError("Trying to optimize unsupported type ", v)
-
-
-def _var_key_v2(var):
- """Key for representing a primary variable, for looking up slots."""
- # pylint: disable=protected-access
- if hasattr(var, "_distributed_container"):
- distributed_container = var._distributed_container()
- assert distributed_container is not None
- if context.executing_eagerly():
- return distributed_container._unique_id
- return distributed_container._shared_name
- if context.executing_eagerly():
- return var._unique_id
- return var.op.name
-
-
-def _resolve(value, name):
- if callable(value):
- value = value()
- return ops.convert_to_tensor(value, name=name)
-
-
-def _is_dynamic(value):
- """Returns true if __init__ arg `value` should be re-evaluated each step."""
- if callable(value): return True
- # Don't need to do anything special in graph mode, since dynamic values
- # will propagate correctly automatically.
- # TODO(josh11b): Add per-device caching across steps using variables for
- # truly static values once we add distributed support.
- if context.executing_eagerly() and isinstance(
- value, resource_variable_ops.ResourceVariable):
- return True
- return False
-
-
-class _OptimizerV2State(object):
- """Holds per-graph and per-step optimizer state.
-
- Use _init_with_static_hyper() to create the state for a graph, and then
- _copy_with_dynamic_hyper() to convert that to state for a particular step.
- The difference between the two is that the former only has hyper
- parameter values that are static and the latter also has values that
- can change every step (according to _is_dynamic()).
- """
-
- def __init__(self, op_name):
- self._op_name = op_name
-
- def _init_with_static_hyper(self, hyper):
- """Initialize a fresh state object from hyper dict."""
- # self._hyper contains a dict from name to a dict with the Tensor values.
- # This dict starts with a single item with key "None" with the hyper
- # parameter value converted to a Tensor. Other items have dtype keys
- # with that Tensor cast to that dtype.
- with ops.init_scope():
- self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if not dynamic}
- self._slots = {}
- self._non_slot_dict = {}
- # Extra state to help Optimizers implement Checkpointable. Holds information
- # about variables which will be restored as soon as they're created.
- self._deferred_dependencies = {} # Non-slot variables
- self._deferred_slot_restorations = {} # Slot variables
-
- def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
- """Create a new state object for a particular step."""
- ret = _OptimizerV2State(self._op_name)
- # pylint: disable=protected-access
- ret._slots = self._slots
- ret._non_slot_dict = self._non_slot_dict
- ret._deferred_dependencies = self._deferred_dependencies
- ret._deferred_slot_restorations = self._deferred_slot_restorations
- ret._hyper = {name: {None: _resolve(value, name)}
- for name, (dynamic, value) in sorted(hyper.items())
- if dynamic}
- ret._hyper.update(self._hyper)
- ret._non_slot_devices = non_slot_devices
- ret._distribution = distribution
- return ret
-
- def _variables(self):
- """Returns a list of all variables held by self."""
- optimizer_variables = list(self._non_slot_dict.values())
- for variable_dict in self._slots.values():
- for slot_for_variable in variable_dict.values():
- optimizer_variables.append(slot_for_variable)
- # Sort variables by name so that the return is deterministic.
- return sorted(optimizer_variables, key=lambda v: v.name)
-
- def _slot_dict(self, slot_name):
- """Returns a dict for caching slots created under the given name.
-
- Args:
- slot_name: Name for the slot.
-
- Returns:
- A dict that maps primary `Variable` objects to the slot created
- for that variable, under the given slot name.
- """
- named_slots = self._slots.get(slot_name, None)
- if named_slots is None:
- named_slots = {}
- self._slots[slot_name] = named_slots
- return named_slots
-
- def create_slot(self, var, val, slot_name, optional_op_name=None):
- """Find or create a slot for a variable.
-
- Args:
- var: A `Variable` object.
- val: A `Tensor`. The initial value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot(
- var, val, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def create_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, optional_op_name=None):
- """Find or create a slot for a variable, using an Initializer.
-
- Args:
- var: A `Variable` object.
- initializer: An `Initializer`. The initial value of the slot.
- shape: Shape of the initial value of the slot.
- dtype: Type of the value of the slot.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_slot_with_initializer(
- var, initializer, shape, dtype, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def zeros_slot(self, var, slot_name, optional_op_name=None):
- """Find or create a slot initialized with 0.0.
-
- Args:
- var: A `Variable` object.
- slot_name: Name for the slot.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
-
- Returns:
- A `Variable` object.
- """
- named_slots = self._slot_dict(slot_name)
- var_key = _var_key_v2(var)
- if var_key not in named_slots:
- new_slot_variable = slot_creator.create_zeros_slot(
- var, optional_op_name or self._op_name)
- self._restore_slot_variable(
- slot_name=slot_name, variable=var,
- slot_variable=new_slot_variable)
- named_slots[var_key] = new_slot_variable
- return named_slots[var_key]
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable,
- optional_op_name=None):
- """Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored. When executing eagerly, we create the slot variable with a
- restoring initializer.
-
- No new variables are created when graph building. Instead,
- _restore_slot_variable catches these after normal creation and adds restore
- ops to the graph. This method is nonetheless important when graph building
- for the case when a slot variable has already been created but `variable`
- has just been added to a dependency graph (causing us to realize that the
- slot variable needs to be restored).
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- optional_op_name: Name to use when scoping the Variable that
- needs to be created for the slot.
- """
- slot_variable = self.get_slot(var=variable, name=slot_name)
- if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()
- # Defer slot variable creation if there is an active variable creator
- # scope. Generally we'd like to eagerly create/restore slot variables
- # when possible, but this may mean that scopes intended to catch
- # `variable` also catch its eagerly created slot variable
- # unintentionally (specifically make_template would add a dependency on
- # a slot variable if not for this case). Deferring is mostly harmless
- # (aside from double initialization), and makes variable creator scopes
- # behave the same way they do when graph building.
- and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
- initializer = checkpointable.CheckpointInitialValue(
- checkpoint_position=slot_variable_position)
- slot_variable = self.create_slot(
- var=variable,
- val=initializer,
- slot_name=slot_name,
- optional_op_name=optional_op_name)
- # Optimizers do not have unconditional dependencies on their slot
- # variables (nor do any other objects). They are only saved if the
- # variables they were created for are also saved.
- if slot_variable is not None:
- # If we've either made this slot variable, or if we've pulled out an
- # existing slot variable, we should restore it.
- slot_variable_position.restore(slot_variable)
- else:
- # We didn't make the slot variable. Defer restoring until it gets created
- # normally. We keep a list rather than the one with the highest restore
- # UID in case slot variables have their own dependencies, in which case
- # those could differ between restores.
- variable_key = _var_key_v2(variable)
- self._deferred_slot_restorations.setdefault(
- slot_name, {}).setdefault(variable_key, []).append(
- slot_variable_position)
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- named_slots = self._slots.get(name, None)
- if not named_slots:
- return None
- return named_slots.get(_var_key_v2(var), None)
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- return sorted(self._slots.keys())
-
- def create_non_slot(self, initial_value, name, colocate_with=None):
- """Add an extra variable, not associated with a slot."""
- v = self._non_slot_dict.get(name, None)
- if v is None:
- if colocate_with is None: colocate_with = self._non_slot_devices
- with self._distribution.colocate_vars_with(colocate_with):
- # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
- v = variable_scope.variable(initial_value, name=name, trainable=False)
- self._non_slot_dict[name] = v
- deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
- for checkpoint_position in sorted(
- deferred_dependencies_list,
- key=lambda restore: restore.checkpoint.restore_uid,
- reverse=True):
- checkpoint_position.restore(v)
- return v
-
- def _restore_slot_variable(self, slot_name, variable, slot_variable):
- """Restore a newly created slot variable's value."""
- variable_key = _var_key_v2(variable)
- deferred_restorations = self._deferred_slot_restorations.get(
- slot_name, {}).pop(variable_key, [])
- # Iterate over restores, highest restore UID first to minimize the number
- # of assignments.
- deferred_restorations.sort(key=lambda position: position.restore_uid,
- reverse=True)
- for checkpoint_position in deferred_restorations:
- checkpoint_position.restore(slot_variable)
-
- def get_non_slot(self, name):
- """Returns the non-slot variable identified by `name`."""
- return self._non_slot_dict.get(name, None)
-
- def get_hyper(self, name, dtype=None):
- """Returns the `name` hyper parameter, optionally cast to `dtype`."""
- dtype_dict = self._hyper[name]
- # Do we have the value cast to dtype already cached? This should always
- # succeed when dtype is None.
- if dtype in dtype_dict:
- return dtype_dict[dtype]
- # Not cached, cast to dtype and save the result in the cache.
- result = math_ops.cast(dtype_dict[None], dtype)
- dtype_dict[dtype] = result
- return result
-
-
-class OptimizerV2(optimizer_v1.Optimizer):
+class OptimizerV2(optimizer_v2.OptimizerV2):
"""Updated base class for optimizers.
This class defines the API to add Ops to train a model. You never use this
@@ -586,6 +135,10 @@ class OptimizerV2(optimizer_v1.Optimizer):
GATE_OP = 1
GATE_GRAPH = 2
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self, use_locking, name):
"""Create a new Optimizer.
@@ -606,746 +159,4 @@ class OptimizerV2(optimizer_v1.Optimizer):
RuntimeError: If _create_slots has been overridden instead of
_create_vars.
"""
- # Note: We intentionally don't call parent __init__.
-
- # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
- if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
- OptimizerV2._create_slots.__code__):
- raise RuntimeError("Override _create_vars instead of _create_slots when "
- "descending from OptimizerV2 (class %s)" %
- self.__class__.__name__)
- if not name:
- raise ValueError("Must specify the optimizer name")
-
- self._use_locking = use_locking
- self._name = name
- # Map from graph_key to state for that graph. We use the graph_key
- # since it works in both eager and graph mode, and gives the outer
- # graph inside functions.
- tower_context = distribution_strategy_context.get_tower_context()
- if tower_context is None:
- # In a cross-tower context for a DistributionStrategy, which means
- # only one Optimizer will be created, not one per tower.
- self._per_graph_state = {}
- else:
- # We use get_tower_context().merge_call() to get a single dict
- # shared across all model replicas when running with a
- # DistributionStrategy.
- self._per_graph_state = tower_context.merge_call(lambda _: {})
-
- # Hyper parameters, and whether they should be re-evaluated every step.
- self._hyper = {}
-
- def _set_hyper(self, name, value):
- self._hyper[name] = (_is_dynamic(value), value)
-
- def minimize(self, loss, global_step=None, var_list=None,
- gate_gradients=GATE_OP, aggregation_method=None,
- colocate_gradients_with_ops=False, name=None,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Add operations to minimize `loss` by updating `var_list`.
-
- This method simply combines calls `compute_gradients()` and
- `apply_gradients()`. If you want to process the gradient before applying
- them call `compute_gradients()` and `apply_gradients()` explicitly instead
- of using this function.
-
- Args:
- loss: A `Tensor` containing the value to minimize.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- var_list: Optional list or tuple of `Variable` objects to update to
- minimize `loss`. Defaults to the list of variables collected in
- the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- name: Optional name for the returned operation.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- An Operation that updates the variables in `var_list`. If `global_step`
- was not `None`, that operation also increments `global_step`.
-
- Raises:
- ValueError: If some of the variables are not `Variable` objects.
-
- @compatibility(eager)
- When eager execution is enabled, `loss` should be a Python function that
- takes elements of `var_list` as arguments and computes the value to be
- minimized. If `var_list` is None, `loss` should take no arguments.
- Minimization (and gradient computation) is done with respect to the
- elements of `var_list` if not None, else with respect to any trainable
- variables created during the execution of the `loss` function.
- `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
- `grad_loss` are ignored when eager execution is enabled.
- @end_compatibility
- """
- grads_and_vars = self.compute_gradients(
- loss, var_list=var_list, gate_gradients=gate_gradients,
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- grad_loss=grad_loss, stop_gradients=stop_gradients,
- scale_loss_by_num_towers=scale_loss_by_num_towers)
-
- vars_with_grad = [v for g, v in grads_and_vars if g is not None]
- if not vars_with_grad:
- raise ValueError(
- "No gradients provided for any variable, check your graph for ops"
- " that do not support gradients, between variables %s and loss %s." %
- ([str(v) for _, v in grads_and_vars], loss))
-
- return self.apply_gradients(grads_and_vars, global_step=global_step,
- name=name)
-
- def compute_gradients(self, loss, var_list=None,
- gate_gradients=GATE_OP,
- aggregation_method=None,
- colocate_gradients_with_ops=False,
- grad_loss=None, stop_gradients=None,
- scale_loss_by_num_towers=None):
- """Compute gradients of `loss` for the variables in `var_list`.
-
- This is the first part of `minimize()`. It returns a list
- of (gradient, variable) pairs where "gradient" is the gradient
- for "variable". Note that "gradient" can be a `Tensor`, an
- `IndexedSlices`, or `None` if there is no gradient for the
- given variable.
-
- Args:
- loss: A Tensor containing the value to minimize or a callable taking
- no arguments which returns the value to minimize. When eager execution
- is enabled it must be a callable.
- var_list: Optional list or tuple of `tf.Variable` to update to minimize
- `loss`. Defaults to the list of variables collected in the graph
- under the key `GraphKeys.TRAINABLE_VARIABLES`.
- gate_gradients: How to gate the computation of gradients. Can be
- `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
- aggregation_method: Specifies the method used to combine gradient terms.
- Valid values are defined in the class `AggregationMethod`.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate
- through.
- scale_loss_by_num_towers: Optional boolean. If true, scale the loss
- down by the number of towers. By default, auto-detects whether this
- is needed.
-
- Returns:
- A list of (gradient, variable) pairs. Variable is always present, but
- gradient can be `None`.
-
- Raises:
- TypeError: If `var_list` contains anything else than `Variable` objects.
- ValueError: If some arguments are invalid.
- RuntimeError: If called with eager execution enabled and `loss` is
- not callable.
-
- @compatibility(eager)
- When eager execution is enabled, `gate_gradients`, `aggregation_method`,
- and `colocate_gradients_with_ops` are ignored.
- @end_compatibility
- """
- # TODO(josh11b): Test that we handle weight decay in a reasonable way.
- if callable(loss):
- with backprop.GradientTape() as tape:
- if var_list is not None:
- tape.watch(var_list)
- loss_value = loss()
-
- # Scale loss for number of towers (callable-loss case). In this case,
- # we have to be careful to call distribute_lib.get_loss_reduction()
- # *after* loss() is evaluated, so we know what loss reduction it uses.
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss_value *= 1. / num_towers
-
- if var_list is None:
- var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
- return list(zip(grads, var_list))
- if context.executing_eagerly():
- raise RuntimeError(
- "`loss` passed to Optimizer.compute_gradients should "
- "be a function when eager execution is enabled.")
-
- # Scale loss for number of towers (non-callable-loss case).
- if scale_loss_by_num_towers is None:
- scale_loss_by_num_towers = (
- distribute_lib.get_loss_reduction() ==
- variable_scope.VariableAggregation.MEAN)
- if scale_loss_by_num_towers:
- num_towers = distribution_strategy_context.get_distribution_strategy(
- ).num_towers
- if num_towers > 1:
- loss *= 1. / num_towers
-
- if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
- optimizer_v1.Optimizer.GATE_OP,
- optimizer_v1.Optimizer.GATE_GRAPH]:
- raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
- "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
- gate_gradients)
- self._assert_valid_dtypes([loss])
- if grad_loss is not None:
- self._assert_valid_dtypes([grad_loss])
- if var_list is None:
- var_list = (
- variables.trainable_variables() +
- ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
- else:
- var_list = nest.flatten(var_list)
- # pylint: disable=protected-access
- var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
- # pylint: enable=protected-access
- processors = [_get_processor(v) for v in var_list]
- if not var_list:
- raise ValueError("No variables to optimize.")
- var_refs = [p.target() for p in processors]
- grads = gradients.gradients(
- loss, var_refs, grad_ys=grad_loss,
- gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
- aggregation_method=aggregation_method,
- colocate_gradients_with_ops=colocate_gradients_with_ops,
- stop_gradients=stop_gradients)
- if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
- grads = control_flow_ops.tuple(grads)
- grads_and_vars = list(zip(grads, var_list))
- self._assert_valid_dtypes(
- [v for g, v in grads_and_vars
- if g is not None and v.dtype != dtypes.resource])
- return grads_and_vars
-
- def apply_gradients(self, grads_and_vars, global_step=None, name=None):
- """Apply gradients to variables.
-
- This is the second part of `minimize()`. It returns an `Operation` that
- applies gradients.
-
- Args:
- grads_and_vars: List of (gradient, variable) pairs as returned by
- `compute_gradients()`.
- global_step: Optional `Variable` to increment by one after the
- variables have been updated.
- name: Optional name for the returned operation. Default to the
- name passed to the `Optimizer` constructor.
-
- Returns:
- An `Operation` that applies the specified gradients. If `global_step`
- was not None, that operation also increments `global_step`.
-
- Raises:
- TypeError: If `grads_and_vars` is malformed.
- ValueError: If none of the variables have gradients.
- """
- # This is a default implementation of apply_gradients() that can be shared
- # by most optimizers. It relies on the subclass implementing the following
- # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
-
- # Filter out variables with gradients of `None`.
- grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
- if not grads_and_vars:
- raise ValueError("No variables provided.")
- filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
- if not filtered:
- raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, v in grads_and_vars],))
- return distribution_strategy_context.get_tower_context().merge_call(
- self._distributed_apply, filtered, global_step=global_step, name=name)
-
- def _get_or_create_state(self, var_list=None):
- """Either looks up or creates `_OptimizerV2State`.
-
- If any variables are available, they should be passed via the `var_list`
- argument, and these will be used to determine the graph to create/retrieve
- state for. Otherwise the returned state is for the current default graph.
-
- Args:
- var_list: A list of variables to extract a graph from.
-
- Returns:
- An `_OptimizerV2State` object.
- """
- # Determine the graph_key from the current graph.
- eager_execution = context.executing_eagerly()
- if eager_execution or var_list is None:
- graph = ops.get_default_graph()
- else:
- graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
- assert graph is not None
- graph_key = graph._graph_key # pylint: disable=protected-access
-
- # Get the per graph state by looking up the graph_key.
- if graph_key in self._per_graph_state:
- per_graph_state = self._per_graph_state[graph_key]
- else:
- per_graph_state = _OptimizerV2State(self._name)
- per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
- self._per_graph_state[graph_key] = per_graph_state
- return per_graph_state
-
- def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
- """`apply_gradients` for use with a `DistributionStrategy`."""
- reduced_grads = distribution.batch_reduce(
- variable_scope.VariableAggregation.SUM, grads_and_vars)
- var_list = [v for _, v in grads_and_vars]
- grads_and_vars = zip(reduced_grads, var_list)
-
- unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
- eager_execution = context.executing_eagerly()
- if eager_execution:
- # Give a clear error in this case instead of "name not supported
- # for Eager Tensors" when we compute non_slot_devices.
- for v in unwrapped_var_list:
- if isinstance(v, ops.Tensor):
- raise NotImplementedError("Trying to update a Tensor ", v)
-
- with ops.name_scope(name, self._name) as name:
- per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
- # Include the current value of any dynamic hyper parameters in `state`.
- non_slot_devices = distribution.non_slot_devices(var_list)
- state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
- self._hyper, distribution, non_slot_devices)
-
- # Create any slot and non-slot variables we need in `state`.
- with ops.init_scope():
- self._create_vars(var_list, state)
-
- with ops.name_scope(name): # Re-enter name_scope created above
- # Give the child class a chance to do something before we start
- # applying gradients.
- self._prepare(state)
-
- def update(v, g):
- """Update variable `v` using gradient `g`."""
- assert v is not None
-
- # Convert the grad to Tensor or IndexedSlices if necessary, and
- # look up a processor for each variable's type.
- try:
- g = ops.convert_to_tensor_or_indexed_slices(g)
- except TypeError:
- raise TypeError(
- "Gradient must be convertible to a Tensor"
- " or IndexedSlices, or None: %s" % g)
- if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
- raise TypeError(
- "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
- processor = _get_processor(v)
-
- # We colocate all ops created in _apply_dense or _apply_sparse
- # on the same device as the variable.
- # TODO(apassos): figure out how to get the variable name here.
- scope_name = "" if eager_execution else v.op.name
- # device_policy is set because non-mirrored tensors will be read in
- # `update_op`.
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with ops.name_scope("update_" + scope_name), \
- context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return processor.update_op(self, g, state)
-
- # Use the processors to update the variables.
- update_ops = []
- for grad, var in grads_and_vars:
- update_ops.extend(distribution.update(var, update, grad, grouped=False))
-
- # Give the child class a chance to do something after applying
- # gradients
- def finish():
- # TODO(josh11b): Make different state objects for each device to
- # avoid needing to set the device_policy.
- with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- return self._finish(state)
-
- update_ops = control_flow_ops.group(update_ops)
- with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, grouped=False)
- # We said grouped=False, which means finish_updates is always a list.
- # It will be [None] when finish() returns None.
- if finish_updates == [None]:
- finish_updates = [update_ops]
-
- # Update `global_step` (if any).
- if global_step is None:
- apply_updates = distribution.group(finish_updates, name=name)
- else:
- with ops.control_dependencies(finish_updates):
-
- def update_global_step(global_step, name):
- return global_step.assign_add(1, read_value=False, name=name)
-
- apply_updates = distribution.update(
- global_step, update_global_step, name)
-
- # Add the training op to the TRAIN_OP graph collection in graph mode.
- if not eager_execution:
- if isinstance(apply_updates, ops.Tensor):
- apply_updates = apply_updates.op
- train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
- if apply_updates not in train_op:
- train_op.append(apply_updates)
-
- return apply_updates
-
- def get_slot(self, var, name):
- """Return a slot named `name` created for `var` by the Optimizer.
-
- Some `Optimizer` subclasses use additional variables. For example
- `Momentum` and `Adagrad` use variables to accumulate updates. This method
- gives access to these `Variable` objects if for some reason you need them.
-
- Use `get_slot_names()` to get the list of slot names created by the
- `Optimizer`.
-
- Args:
- var: A variable passed to `minimize()` or `apply_gradients()`.
- name: A string.
-
- Returns:
- The `Variable` for the slot if it was created, `None` otherwise.
- """
- state = self._get_state_for_var(var)
- return state.get_slot(var, name) if state is not None else None
-
- def get_slot_names(self):
- """Return a list of the names of slots created by the `Optimizer`.
-
- See `get_slot()`.
-
- Returns:
- A list of strings.
- """
- state = self._get_per_graph_state()
- return state.get_slot_names() if state is not None else []
-
- def variables(self):
- """A list of variables which encode the current state of `Optimizer`.
-
- Includes slot variables and additional global variables created by the
- optimizer in the current default graph.
-
- Returns:
- A list of variables.
- """
- state = self._get_per_graph_state()
- return state._variables() if state is not None else [] # pylint: disable=protected-access
-
- # --------------
- # Methods to be implemented by subclasses if they want to use the
- # inherited implementation of apply_gradients() or compute_gradients().
- # --------------
- def _create_vars(self, var_list, state):
- """Create all slots needed by the variables and any non-slot variables.
-
- Args:
- var_list: A list of `Variable` objects.
- state: An object with these methods:
- `create_slot(var, val, slot_name, optional_op_name)`,
- `create_slot_with_initializer(`
- `var, initializer, shape, dtype, slot_name, optional_op_name)`,
- `zeros_slot(var, slot_name, optional_op_name)`,
- `create_non_slot_variable(initial_value, name, colocate_with)`,
- `get_hyper(name)`
- """
- # No slots needed by default
- pass
-
- def _prepare(self, state):
- """Code to execute before applying gradients.
-
- Note that most uses of _prepare() in Optimizer have been subsumed
- by explicit support for hyper parameters in OptimizerV2
-
- Args:
- state: An object with a `get_hyper(name)` method.
-
- Returns:
- Return value will be ignored.
- """
- pass
-
- def _apply_dense(self, grad, var, state):
- """Add ops to apply dense gradients to `var`.
-
- Args:
- grad: A `Tensor`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _resource_apply_dense(self, grad, handle, state):
- """Add ops to apply dense gradients to the variable `handle`.
-
- Args:
- grad: a `Tensor` representing the gradient.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _resource_apply_sparse_duplicate_indices(
- self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to `handle`, with repeated indices.
-
- Optimizers which override this method must deal with repeated indices. See
- the docstring of `_apply_sparse_duplicate_indices` for details. By default
- the correct behavior, to sum non-unique indices and their associated
- gradients, is enforced by first pre-processing `grad` and `indices` and
- passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
- with duplicate indices may instead override this method to avoid the
- overhead of summing.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices may be repeated.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- # pylint: disable=protected-access
- summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad, indices=indices)
- # pylint: enable=protected-access
- return self._resource_apply_sparse(
- summed_grad, handle, unique_indices, state)
-
- def _resource_apply_sparse(self, grad, handle, indices, state):
- """Add ops to apply sparse gradients to the variable `handle`.
-
- Similar to `_apply_sparse`, the `indices` argument to this method has been
- de-duplicated. Optimizers which deal correctly with non-unique indices may
- instead override `_resource_apply_sparse_duplicate_indices` to avoid this
- overhead.
-
- Args:
- grad: a `Tensor` representing the gradient for the affected indices.
- handle: a `Tensor` of dtype `resource` which points to the variable
- to be updated.
- indices: a `Tensor` of integral type representing the indices for
- which the gradient is nonzero. Indices are unique.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation` which updates the value of the variable.
- """
- raise NotImplementedError()
-
- def _apply_sparse_duplicate_indices(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
-
- Optimizers which override this method must deal with IndexedSlices objects
- such as the following:
-
- IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
-
- The correct interpretation is:
-
- IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
-
- Many optimizers deal incorrectly with repeated indices when updating based
- on sparse gradients (e.g. summing squares rather than squaring the sum, or
- applying momentum terms multiple times). Adding first is always the correct
- behavior, so this is enforced here by reconstructing the IndexedSlices to
- have only unique indices, then calling _apply_sparse.
-
- Optimizers which deal correctly with repeated indices may instead override
- this method to avoid the overhead of summing indices.
-
- Args:
- grad: `IndexedSlices`.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- # pylint: disable=protected-access
- summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
- values=grad.values, indices=grad.indices)
- # pylint: enable=protected-access
- gradient_no_duplicate_indices = ops.IndexedSlices(
- indices=unique_indices,
- values=summed_values,
- dense_shape=grad.dense_shape)
- return self._apply_sparse(gradient_no_duplicate_indices, var, state)
-
- def _apply_sparse(self, grad, var, state):
- """Add ops to apply sparse gradients to `var`.
-
- The IndexedSlices object passed to `grad` in this function is by default
- pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
- indices (see its docstring for details). Optimizers which can tolerate or
- have correct special cases for duplicate sparse indices may override
- `_apply_sparse_duplicate_indices` instead of this function, avoiding that
- overhead.
-
- Args:
- grad: `IndexedSlices`, with no repeated indices.
- var: A `Variable` object.
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- An `Operation`.
- """
- raise NotImplementedError()
-
- def _finish(self, state):
- """Do what is needed to finish the update.
-
- This is called inside a scope colocated with any non-slot variables.
-
- Args:
- state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
- and `get_hyper(name)` methods.
-
- Returns:
- The operation to apply updates, or None if no updates.
- """
- return None
-
- # --------------
- # Utility methods for subclasses.
- # --------------
- def _get_per_graph_state(self):
- # pylint: disable=protected-access
- return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
-
- def _get_state_for_var(self, var):
- # pylint: disable=protected-access
- return self._per_graph_state.get(var._graph_key, None)
-
- # --------------
- # Overridden methods from Checkpointable.
- # --------------
-
- def _track_checkpointable(self, *args, **kwargs):
- """Optimizers may not track dependencies. Raises an error."""
- raise NotImplementedError(
- "Optimizers may not have dependencies. File a feature request if this "
- "limitation bothers you.")
-
- @property
- def _checkpoint_dependencies(self):
- """From Checkpointable. Gather graph-specific non-slot variables to save."""
- current_graph_non_slot_variables = []
- state = self._get_per_graph_state()
- if state is not None:
- for name, variable_object in sorted(
- state._non_slot_dict.items(), # pylint: disable=protected-access
- # Avoid comparing variables
- key=lambda item: item[0]):
- current_graph_non_slot_variables.append(
- checkpointable.CheckpointableReference(
- name=name, ref=variable_object))
- # Note: ignores super(); Optimizers may not have any dependencies outside of
- # state objects.
- return current_graph_non_slot_variables
-
- def _lookup_dependency(self, name):
- """From Checkpointable. Find a non-slot variable in the current graph."""
- state = self._get_per_graph_state()
- if state is None:
- return None
- else:
- return state.get_non_slot(name)
-
- @property
- def _deferred_dependencies(self):
- """Lets Checkpointable know where non-slot variables are created.
-
- If necessary, creates a new state object for the current default graph.
- Checkpointable will then add entries to that state's deferred dependency
- dictionary. The state object will check that dictionary when creating
- non-slot variables, restoring their value if an entry is found.
-
- Returns:
- A dictionary which holds deferred dependencies for the current default
- graph.
- """
- state = self._get_or_create_state()
- return state._deferred_dependencies # pylint: disable=protected-access
-
- def _create_or_restore_slot_variable(
- self, slot_variable_position, slot_name, variable):
- """Checkpointable: Restore a slot variable's value, possibly creating it.
-
- Called when a variable which has an associated slot variable is created or
- restored.
-
- Args:
- slot_variable_position: A `checkpointable._CheckpointPosition` object
- indicating the slot variable `Checkpointable` object to be restored.
- slot_name: The name of this `Optimizer`'s slot to restore into.
- variable: The variable object this slot is being created for.
- """
- state = self._get_or_create_state(var_list=[variable])
- state._create_or_restore_slot_variable( # pylint: disable=protected-access
- slot_variable_position=slot_variable_position,
- slot_name=slot_name,
- variable=variable,
- optional_op_name=self._name)
-
- # --------------
- # Unsupported parent methods
- # --------------
- def _slot_dict(self, slot_name):
- raise NotImplementedError(
- "_slot_dict() method unsupported in OptimizerV2")
-
- def _get_or_make_slot(self, var, val, slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot() method unsupported in OptimizerV2")
-
- def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
- slot_name, op_name):
- raise NotImplementedError(
- "_get_or_make_slot_with_initializer() method unsupported in "
- "OptimizerV2")
-
- def _create_non_slot_variable(self, initial_value, name, colocate_with):
- raise NotImplementedError(
- "_create_non_slot_variable() method unsupported in OptimizerV2")
-
- def _get_non_slot_variable(self, name, graph=None):
- raise NotImplementedError(
- "_get_non_slot_variable() method unsupported in OptimizerV2")
-
- def _non_slot_variables(self):
- raise NotImplementedError(
- "_non_slot_variables() method unsupported in OptimizerV2")
+ super(OptimizerV2, self).__init__(name)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop.py b/tensorflow/contrib/optimizer_v2/rmsprop.py
index 3de53405ec..090e257ddc 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop.py
@@ -41,19 +41,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.optimizer_v2 import optimizer_v2
-from tensorflow.python.ops import array_ops
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.util import deprecation
-from tensorflow.python.training import training_ops
-
-class RMSPropOptimizer(optimizer_v2.OptimizerV2):
+class RMSPropOptimizer(rmsprop.RMSProp):
"""Optimizer that implements the RMSProp algorithm.
See the
[paper](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf).
"""
+ @deprecation.deprecated_args(
+ "2018-10-01",
+ "`use_locking = True` is no longer supported and will be ignored.",
+ ("use_locking", [False]))
def __init__(self,
learning_rate,
decay=0.9,
@@ -96,138 +98,10 @@ class RMSPropOptimizer(optimizer_v2.OptimizerV2):
name: Optional name prefix for the operations created when applying
gradients. Defaults to "RMSProp".
"""
- super(RMSPropOptimizer, self).__init__(use_locking, name)
- self._set_hyper("learning_rate", learning_rate)
- self._set_hyper("decay", decay)
- self._set_hyper("momentum", momentum)
- self._set_hyper("epsilon", epsilon)
-
- self._centered = centered
-
- def _create_vars(self, var_list, state):
- for v in var_list:
- init_rms = state.get_hyper(
- "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
- state.create_slot_with_initializer(v, init_rms, v.get_shape(),
- v.dtype.base_dtype, "rms")
- if self._centered:
- state.zeros_slot(v, "mg")
- state.zeros_slot(v, "momentum")
-
- def _apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- # epsilon is now the rms initial value and is not added to the
- # denominator anymore, hence calling the kernel op with epsilon=0.
- 0,
- grad,
- use_locking=self._use_locking).op
- else:
- return training_ops.apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking).op
-
- def _resource_apply_dense(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.resource_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- use_locking=self._use_locking)
-
- def _apply_sparse(self, grad, var, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = state.get_slot(var, "mg")
- return training_ops.sparse_apply_centered_rms_prop(
- var,
- mg,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
- else:
- return training_ops.sparse_apply_rms_prop(
- var,
- rms,
- mom,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad.values,
- grad.indices,
- use_locking=self._use_locking)
-
- def _resource_apply_sparse(self, grad, var, indices, state):
- rms = state.get_slot(var, "rms")
- mom = state.get_slot(var, "momentum")
- if self._centered:
- mg = self.get_slot(var, "mg")
- return training_ops.resource_sparse_apply_centered_rms_prop(
- var.handle,
- mg.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
- else:
- return training_ops.resource_sparse_apply_rms_prop(
- var.handle,
- rms.handle,
- mom.handle,
- state.get_hyper("learning_rate", var.dtype.base_dtype),
- state.get_hyper("decay", var.dtype.base_dtype),
- state.get_hyper("momentum", var.dtype.base_dtype),
- 0,
- grad,
- indices,
- use_locking=self._use_locking)
+ super(RMSPropOptimizer, self).__init__(
+ learning_rate=learning_rate,
+ rho=decay,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered,
+ name=name)
diff --git a/tensorflow/contrib/optimizer_v2/rmsprop_test.py b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
index 44301ffe9e..83f5971039 100644
--- a/tensorflow/contrib/optimizer_v2/rmsprop_test.py
+++ b/tensorflow/contrib/optimizer_v2/rmsprop_test.py
@@ -157,8 +157,11 @@ class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
- self.assertAllCloseAccordingToType(var0_np, var0.eval())
- self.assertAllCloseAccordingToType(var1_np, var1.eval())
+ # TODO(b/117393988): Reduce tolerances for float16.
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=3e-3, half_atol=3e-3)
+ self.assertAllCloseAccordingToType(
+ var1_np, var1.eval(), half_rtol=3e-3, half_atol=3e-3)
@parameterized.parameters([dtypes.float32, dtypes.float64])
def testMinimizeSparseResourceVariable(self, dtype):
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 0a27200015..aa1d7d2b01 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -1120,6 +1120,71 @@ class RNNCellTest(test.TestCase):
r"input size \(3\) must be divisible by number_of_groups \(2\)"):
gcell(glstm_input, gcell_zero_state)
+ def testCFNCell(self):
+ with self.cached_session() as sess:
+ with variable_scope.variable_scope("root"):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.CFNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.17188203, 0.17188203]])
+ with variable_scope.variable_scope("other"):
+ # Test CFN with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.CFNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.15535763, 0.15535763]])
+
+ def testCFNCellEndToEnd(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = utils.to_categorical(y_train)
+ cell = contrib_rnn_cell.CFNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
def testMinimalRNNCell(self):
with self.cached_session() as sess:
with variable_scope.variable_scope(
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 59a61af7b3..78cea8feb4 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -3510,3 +3510,132 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
new_h = u * state + (1 - u) * feedforward
return new_h, new_h
+
+
+class CFNCell(rnn_cell_impl.LayerRNNCell):
+ """Chaos Free Network cell.
+
+ The implementation is based on:
+
+ https://openreview.net/pdf?id=S1dIzvclg
+
+ Thomas Laurent, James von Brecht.
+ "A recurrent neural network without chaos." ICLR, 2017.
+
+ A CFN cell first projects the input to the hidden space. The hidden state
+ goes through a contractive mapping. The new hidden state is then calcuated
+ as a linear combination of the projected input and the contracted previous
+ hidden state, using decoupled input and forget gates.
+ """
+
+ def __init__(self,
+ units,
+ activation="tanh",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="ones",
+ name=None,
+ dtype=None,
+ **kwargs):
+ """Initialize the parameters for a CFN cell.
+
+ Args:
+ units: int, The number of units in the CFN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ kernel_initializer: Initializer for the `kernel` weights
+ matrix. Default: `glorot_uniform`.
+ bias_initializer: The initializer to use for the bias in the
+ gates. Default: `ones`.
+ name: String, the name of the cell.
+ dtype: Default dtype of the cell.
+ **kwargs: Dict, keyword named properties for common cell attributes.
+ """
+ super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self.units = units
+ self.activation = activations.get(activation)
+ self.kernel_initializer = initializers.get(kernel_initializer)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ @property
+ def state_size(self):
+ return self.units
+
+ @property
+ def output_size(self):
+ return self.units
+
+ def build(self, inputs_shape):
+ if inputs_shape[-1] is None:
+ raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+ % str(inputs_shape))
+
+ input_size = inputs_shape[-1]
+ # pylint: disable=protected-access
+ # `self.kernel` contains V_{\theta}, V_{\eta}, W.
+ # `self.recurrent_kernel` contains U_{\theta}, U_{\eta}.
+ # `self.bias` contains b_{\theta}, b_{\eta}.
+ self.kernel = self.add_weight(
+ shape=[input_size, 3 * self.units],
+ name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ initializer=self.kernel_initializer)
+ self.recurrent_kernel = self.add_weight(
+ shape=[self.units, 2 * self.units],
+ name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ initializer=self.kernel_initializer)
+ self.bias = self.add_weight(
+ shape=[2 * self.units],
+ name=rnn_cell_impl._BIAS_VARIABLE_NAME,
+ initializer=self.bias_initializer)
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Run one step of CFN.
+
+ Args:
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
+ state: state Tensor, must be 2-D, `[batch, state_size]`.
+
+ Returns:
+ A tuple containing:
+
+ - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
+ - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ input_size = inputs.get_shape()[-1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ # The variable names u, v, w, b are consistent with the notations in the
+ # original paper.
+ v, w = array_ops.split(
+ value=self.kernel,
+ num_or_size_splits=[2 * self.units, self.units],
+ axis=1)
+ u = self.recurrent_kernel
+ b = self.bias
+
+ gates = math_ops.matmul(state, u) + math_ops.matmul(inputs, v)
+ gates = nn_ops.bias_add(gates, b)
+ gates = math_ops.sigmoid(gates)
+ theta, eta = array_ops.split(value=gates,
+ num_or_size_splits=2,
+ axis=1)
+
+ proj_input = math_ops.matmul(inputs, w)
+
+ # The input gate is (1 - eta), which is different from the original paper.
+ # This is for the propose of initialization. With the default
+ # bias_initializer `ones`, the input gate is initialized to a small number.
+ new_h = theta * self.activation(state) + (1 - eta) * self.activation(
+ proj_input)
+
+ return new_h, new_h
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
index 360e7dbe75..7743f5b4a7 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_loss_test.py
@@ -109,6 +109,42 @@ class SparsemaxLossTest(test.TestCase):
np_loss, tf_loss_out, half_atol=1e-2, half_rtol=5e-3)
self.assertShapeEqual(np_loss, tf_loss_op)
+ def _test_sparsemax_loss_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax-loss transfers nan"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_nan = np.asarray([[0, np.nan, 0], [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]]).astype(dtype)
+
+ _, tf_loss_nan = self._tf_sparsemax_loss(z_nan, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan], tf_loss_nan)
+
+ def _test_sparsemax_loss_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax-loss is infinity safe"""
+ q = np.asarray([[0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]])
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, 0],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf,
+ np.inf], [np.inf, np.inf, 0],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, 0], [-np.inf, np.inf,
+ -np.inf]]).astype(dtype)
+
+ _, tf_loss_neg = self._tf_sparsemax_loss(z_neg, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([0.25, np.inf, 0, np.nan], tf_loss_neg)
+
+ _, tf_loss_pos = self._tf_sparsemax_loss(z_pos, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_pos)
+
+ _, tf_loss_mix = self._tf_sparsemax_loss(z_mix, q, dtype, use_gpu)
+ self.assertAllCloseAccordingToType([np.nan, np.nan, np.nan, np.nan],
+ tf_loss_mix)
+
def _test_constant_add(self, dtype, random, use_gpu):
"""check sparsemax-loss proposition 3"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -198,6 +234,10 @@ class SparsemaxLossTest(test.TestCase):
self._test_sparsemax_loss_against_numpy(dtype, random, use_gpu=False)
+ self._test_sparsemax_loss_of_nan(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_loss_of_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_sparsemax_loss_positive(dtype, random, use_gpu=False)
diff --git a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
index 259e62bd86..c95b9da1e4 100644
--- a/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
+++ b/tensorflow/contrib/sparsemax/python/kernel_tests/sparsemax_test.py
@@ -87,6 +87,46 @@ class SparsemaxTest(test.TestCase):
p_sparemax, tf_sparsemax_out, half_atol=5e-3)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
+ def _test_sparsemax_of_nan(self, dtype, random, use_gpu):
+ """check sparsemax transfers nan"""
+ z_nan = np.asarray([
+ [0, np.nan, 0],
+ [0, np.nan, np.nan],
+ [np.nan, np.nan, np.nan],
+ ]).astype(dtype)
+
+ _, tf_sparsemax_nan = self._tf_sparsemax(z_nan, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_nan)
+
+ def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ """check sparsemax is infinity safe"""
+ z_neg = np.asarray([
+ [0, -np.inf, 0],
+ [0, -np.inf, -np.inf],
+ [-np.inf, -np.inf, -np.inf],
+ ]).astype(dtype)
+ z_pos = np.asarray([[0, np.inf, 0], [0, np.inf, np.inf],
+ [np.inf, np.inf, np.inf]]).astype(dtype)
+ z_mix = np.asarray([[0, np.inf, 0], [0, np.inf, -np.inf],
+ [-np.inf, np.inf, -np.inf]]).astype(dtype)
+
+ _, tf_sparsemax_neg = self._tf_sparsemax(z_neg, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[0.5, 0, 0.5], [1, 0, 0], [np.nan, np.nan, np.nan]], tf_sparsemax_neg)
+
+ _, tf_sparsemax_pos = self._tf_sparsemax(z_pos, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_pos)
+
+ _, tf_sparsemax_mix = self._tf_sparsemax(z_mix, dtype, use_gpu)
+ self.assertAllCloseAccordingToType(
+ [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
+ [np.nan, np.nan, np.nan]], tf_sparsemax_mix)
+
+
def _test_sparsemax_of_zero(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 1"""
z = np.zeros((1, 10))
@@ -97,7 +137,7 @@ class SparsemaxTest(test.TestCase):
self.assertAllCloseAccordingToType(p_sparemax, tf_sparsemax_out)
self.assertShapeEqual(p_sparemax, tf_sparsemax_op)
- def _test_sparsemax_of_inf(self, dtype, random, use_gpu):
+ def _test_sparsemax_of_to_inf(self, dtype, random, use_gpu):
"""check sparsemax proposition 1, part 2"""
z = random.uniform(low=-3, high=3, size=(test_obs, 10))
@@ -210,10 +250,14 @@ class SparsemaxTest(test.TestCase):
self._test_sparsemax_against_numpy(dtype, random, use_gpu=False)
- self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_nan(dtype, random, use_gpu=False)
self._test_sparsemax_of_inf(dtype, random, use_gpu=False)
+ self._test_sparsemax_of_zero(dtype, random, use_gpu=False)
+
+ self._test_sparsemax_of_to_inf(dtype, random, use_gpu=False)
+
self._test_constant_add(dtype, random, use_gpu=False)
self._test_permutation(dtype, random, use_gpu=False)
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
index e617af2ff1..f79c93f347 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax.py
@@ -49,7 +49,14 @@ def sparsemax(logits, name=None):
obs = array_ops.shape(logits)[0]
dims = array_ops.shape(logits)[1]
- z = logits - math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # The mean(logits) can be substracted from logits to make the algorithm
+ # more numerically stable. the instability in this algorithm comes mostly
+ # from the z_cumsum. Substacting the mean will cause z_cumsum to be close
+ # to zero. However, in practise the numerical instability issues are very
+ # minor and substacting the mean causes extra issues with inf and nan
+ # input.
+ z = logits
# sort z
z_sorted, _ = nn.top_k(z, k=dims)
@@ -64,10 +71,24 @@ def sparsemax(logits, name=None):
k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1)
# calculate tau(z)
- indices = array_ops.stack([math_ops.range(0, obs), k_z - 1], axis=1)
+ # If there are inf values or all values are -inf, the k_z will be zero,
+ # this is mathematically invalid and will also cause the gather_nd to fail.
+ # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then
+ # fixed later (see p_safe) by returning p = nan. This results in the same
+ # behavior as softmax.
+ k_z_safe = math_ops.maximum(k_z, 1)
+ indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1)
tau_sum = array_ops.gather_nd(z_cumsum, indices)
tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype)
# calculate p
- return math_ops.maximum(
+ p = math_ops.maximum(
math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis])
+ # If k_z = 0 or if z = nan, then the input is invalid
+ p_safe = array_ops.where(
+ math_ops.logical_or(
+ math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])),
+ array_ops.fill([obs, dims], math_ops.cast(float("nan"), logits.dtype)),
+ p)
+
+ return p_safe
diff --git a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
index 582d1e6136..c0438f16bc 100644
--- a/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
+++ b/tensorflow/contrib/sparsemax/python/ops/sparsemax_loss.py
@@ -47,14 +47,30 @@ def sparsemax_loss(logits, sparsemax, labels, name=None):
sparsemax = ops.convert_to_tensor(sparsemax, name="sparsemax")
labels = ops.convert_to_tensor(labels, name="labels")
- shifted_logits = logits - \
- math_ops.reduce_mean(logits, axis=1)[:, array_ops.newaxis]
+ # In the paper, they call the logits z.
+ # A constant can be substracted from logits to make the algorithm
+ # more numerically stable in theory. However, there are really no major
+ # source numerical instability in this algorithm.
+ z = logits
# sum over support
- support = math_ops.cast(sparsemax > 0, sparsemax.dtype)
- sum_s = support * sparsemax * (shifted_logits - 0.5 * sparsemax)
+ # Use a conditional where instead of a multiplication to support z = -inf.
+ # If z = -inf, and there is no support (sparsemax = 0), a multiplication
+ # would cause 0 * -inf = nan, which is not correct in this case.
+ sum_s = array_ops.where(
+ math_ops.logical_or(sparsemax > 0, math_ops.is_nan(sparsemax)),
+ sparsemax * (z - 0.5 * sparsemax), array_ops.zeros_like(sparsemax))
# - z_k + ||q||^2
- q_part = labels * (0.5 * labels - shifted_logits)
+ q_part = labels * (0.5 * labels - z)
+ # Fix the case where labels = 0 and z = -inf, where q_part would
+ # otherwise be 0 * -inf = nan. But since the lables = 0, no cost for
+ # z = -inf should be consideredself.
+ # The code below also coveres the case where z = inf. Howeverm in this
+ # caose the sparsemax will be nan, which means the sum_s will also be nan,
+ # therefor this case doesn't need addtional special treatment.
+ q_part_safe = array_ops.where(
+ math_ops.logical_and(math_ops.equal(labels, 0), math_ops.is_inf(z)),
+ array_ops.zeros_like(z), q_part)
- return math_ops.reduce_sum(sum_s + q_part, axis=1)
+ return math_ops.reduce_sum(sum_s + q_part_safe, axis=1)
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index a217397c1a..e9ddec8889 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -11,7 +11,10 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
py_library(
name = "stateless",
- srcs = ["__init__.py"],
+ srcs = [
+ "__init__.py",
+ "python/stateless_ops.py",
+ ],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index fe23fe0dd8..30d0a7ab6a 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -32,16 +32,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import ops
-
# pylint: disable=wildcard-import
-from tensorflow.python.ops.gen_stateless_random_ops import *
+from tensorflow.contrib.stateless.python.stateless_ops import *
from tensorflow.python.util.all_util import remove_undocumented
-ops.NotDifferentiable("StatelessMultinomial")
-ops.NotDifferentiable("StatelessRandomNormal")
-ops.NotDifferentiable("StatelessRandomUniform")
-ops.NotDifferentiable("StatelessTruncatedNormal")
-
remove_undocumented(__name__)
diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
index d724a5c014..ec5a13b7c6 100644
--- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
+++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
import numpy as np
from tensorflow.contrib import stateless
from tensorflow.python.framework import constant_op
@@ -27,10 +29,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-CASES = [(stateless.stateless_random_uniform, random_ops.random_uniform),
- (stateless.stateless_random_normal, random_ops.random_normal),
- (stateless.stateless_truncated_normal, random_ops.truncated_normal)]
-
def invert_philox(key, value):
"""Invert the Philox bijection."""
@@ -51,90 +49,30 @@ def invert_philox(key, value):
class StatelessOpsTest(test.TestCase):
- def testMatchStateful(self):
+ def _test_match(self, cases):
# Stateless ops should be the same as stateful ops on the first call
# after seed scrambling.
+ cases = tuple(cases)
key = 0x3ec8f720, 0x02461e29
for seed in (7, 17), (11, 5), (2, 3):
preseed = invert_philox(key, (seed[0], 0, seed[1], 0)).astype(np.uint64)
preseed = preseed[::2] | preseed[1::2] << 32
random_seed.set_random_seed(seed[0])
with self.test_session(use_gpu=True):
- for stateless_op, stateful_op in CASES:
- for shape in (), (3,), (2, 5):
- stateful = stateful_op(shape, seed=seed[1])
- pure = stateless_op(shape, seed=preseed)
- self.assertAllEqual(stateful.eval(), pure.eval())
+ for stateless_op, stateful_op in cases:
+ stateful = stateful_op(seed=seed[1])
+ pure = stateless_op(seed=preseed)
+ self.assertAllEqual(stateful.eval(), pure.eval())
- def testDeterminism(self):
+ def _test_determinism(self, cases):
# Stateless values should be equal iff the seeds are equal (roughly)
+ cases = tuple(cases)
with self.test_session(use_gpu=True):
for seed_type in [dtypes.int32, dtypes.int64]:
seed_t = array_ops.placeholder(seed_type, shape=[2])
seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(shape, seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testShapeType(self):
- with self.test_session(use_gpu=True):
- for shape_dtype in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(dtypes.int64, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for stateless_op, _ in CASES:
- for shape in (), (3,), (2, 5):
- pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype),
- seed=seed_t)
- values = [(seed, pure.eval(feed_dict={seed_t: seed}))
- for seed in seeds]
- for s0, v0 in values:
- for s1, v1 in values:
- self.assertEqual(s0 == s1, np.all(v0 == v1))
-
- def testMatchStatefulMultinomial(self):
- # Stateless ops should be the same as stateful ops on the first call
- # after seed scrambling.
- key = 0x3ec8f720, 0x02461e29
- num_samples = 4
- for logits_dtype in np.float16, np.float32, np.float64:
- for output_dtype in dtypes.int32, dtypes.int64:
- for seed in (7, 17), (11, 5), (2, 3):
- preseed = invert_philox(key,
- (seed[0], 0, seed[1], 0)).astype(np.uint64)
- preseed = preseed[::2] | preseed[1::2] << 32
- random_seed.set_random_seed(seed[0])
- with self.test_session(use_gpu=True):
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- logits_t = constant_op.constant(logits, dtype=logits_dtype)
- stateful = random_ops.multinomial(
- logits_t,
- num_samples,
- seed=seed[1],
- output_dtype=output_dtype)
- pure = stateless.stateless_multinomial(
- logits_t,
- num_samples,
- seed=preseed,
- output_dtype=output_dtype)
- self.assertAllEqual(stateful.eval(), pure.eval())
-
- def testDeterminismMultinomial(self):
- # Stateless values should be equal iff the seeds are equal (roughly)
- num_samples = 10
- with self.test_session(use_gpu=True):
- for seed_type in [dtypes.int32, dtypes.int64]:
- seed_t = array_ops.placeholder(seed_type, shape=[2])
- seeds = [(x, y) for x in range(5) for y in range(5)] * 3
- for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
- [0.25, 0.75]]):
- pure = stateless.stateless_multinomial(
- logits, num_samples, seed=seed_t)
+ for stateless_op, _ in cases:
+ pure = stateless_op(seed=seed_t)
values = [
(seed, pure.eval(feed_dict={seed_t: seed})) for seed in seeds
]
@@ -142,6 +80,74 @@ class StatelessOpsTest(test.TestCase):
for s1, v1 in values:
self.assertEqual(s0 == s1, np.all(v0 == v1))
+ def _float_cases(self, shape_dtypes=(None,)):
+ float_cases = (
+ # Uniform distribution, with and without range
+ (stateless.stateless_random_uniform, random_ops.random_uniform, {}),
+ (stateless.stateless_random_uniform, random_ops.random_uniform,
+ dict(minval=2.2, maxval=7.1)),
+ # Normal distribution, with and without mean+stddev
+ (stateless.stateless_random_normal, random_ops.random_normal, {}),
+ (stateless.stateless_random_normal, random_ops.random_normal,
+ dict(mean=2, stddev=3)),
+ # Truncated normal distribution, with and without mean+stddev
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal, {}),
+ (stateless.stateless_truncated_normal, random_ops.truncated_normal,
+ dict(mean=3, stddev=4)),
+ )
+ for dtype in dtypes.float16, dtypes.float32, dtypes.float64:
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for stateless_op, stateful_op, kwds in float_cases:
+ kwds = dict(shape=shape, dtype=dtype, **kwds)
+ yield (functools.partial(stateless_op, **kwds),
+ functools.partial(stateful_op, **kwds))
+
+ def _int_cases(self, shape_dtypes=(None,)):
+ for shape_dtype in shape_dtypes:
+ for shape in (), (3,), (2, 5):
+ if shape_dtype is not None:
+ shape = constant_op.constant(shape, dtype=shape_dtype)
+ for dtype in dtypes.int32, dtypes.int64:
+ kwds = dict(minval=2, maxval=11111, dtype=dtype, shape=shape)
+ yield (functools.partial(stateless.stateless_random_uniform, **kwds),
+ functools.partial(random_ops.random_uniform, **kwds))
+
+ def _multinomial_cases(self):
+ num_samples = 10
+ for logits_dtype in np.float16, np.float32, np.float64:
+ for output_dtype in dtypes.int32, dtypes.int64:
+ for logits in ([[0.1, 0.25, 0.5, 0.15]], [[0.5, 0.5], [0.8, 0.2],
+ [0.25, 0.75]]):
+ kwds = dict(
+ logits=constant_op.constant(logits, dtype=logits_dtype),
+ num_samples=num_samples,
+ output_dtype=output_dtype)
+ yield (functools.partial(stateless.stateless_multinomial, **kwds),
+ functools.partial(random_ops.multinomial, **kwds))
+
+ def testMatchFloat(self):
+ self._test_match(self._float_cases())
+
+ def testMatchInt(self):
+ self._test_match(self._int_cases())
+
+ def testMatchMultinomial(self):
+ self._test_match(self._multinomial_cases())
+
+ def testDeterminismFloat(self):
+ self._test_determinism(
+ self._float_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismInt(self):
+ self._test_determinism(
+ self._int_cases(shape_dtypes=(dtypes.int32, dtypes.int64)))
+
+ def testDeterminismMultinomial(self):
+ self._test_determinism(self._multinomial_cases())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/stateless/python/stateless_ops.py b/tensorflow/contrib/stateless/python/stateless_ops.py
new file mode 100644
index 0000000000..1449825c83
--- /dev/null
+++ b/tensorflow/contrib/stateless/python/stateless_ops.py
@@ -0,0 +1,214 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Stateless random ops which take seed as a tensor input."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_stateless_random_ops
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import math_ops
+
+ops.NotDifferentiable("StatelessMultinomial")
+ops.NotDifferentiable("StatelessRandomNormal")
+ops.NotDifferentiable("StatelessRandomUniform")
+ops.NotDifferentiable("StatelessRandomUniformInt")
+ops.NotDifferentiable("StatelessTruncatedNormal")
+
+
+def stateless_random_uniform(shape,
+ seed,
+ minval=0,
+ maxval=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a uniform distribution.
+
+ This is a stateless version of `tf.random_uniform`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a uniform distribution in the range
+ `[minval, maxval)`. The lower bound `minval` is included in the range, while
+ the upper bound `maxval` is excluded.
+
+ For floats, the default range is `[0, 1)`. For ints, at least `maxval` must
+ be specified explicitly.
+
+ In the integer case, the random integers are slightly biased unless
+ `maxval - minval` is an exact power of two. The bias is small for values of
+ `maxval - minval` significantly smaller than the range of the output (either
+ `2**32` or `2**64`).
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ minval: A 0-D Tensor or Python value of type `dtype`. The lower bound on the
+ range of random values to generate. Defaults to 0.
+ maxval: A 0-D Tensor or Python value of type `dtype`. The upper bound on the
+ range of random values to generate. Defaults to 1 if `dtype` is floating
+ point.
+ dtype: The type of the output: `float16`, `float32`, `float64`, `int32`, or
+ `int64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random uniform values.
+
+ Raises:
+ ValueError: If `dtype` is integral and `maxval` is not specified.
+ """
+ dtype = dtypes.as_dtype(dtype)
+ if dtype not in (dtypes.float16, dtypes.bfloat16, dtypes.float32,
+ dtypes.float64, dtypes.int32, dtypes.int64):
+ raise ValueError("Invalid dtype %r" % dtype)
+ if maxval is None:
+ if dtype.is_integer:
+ raise ValueError("Must specify maxval for integer dtype %r" % dtype)
+ maxval = 1
+ with ops.name_scope(name, "stateless_random_uniform",
+ [shape, seed, minval, maxval]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
+ maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
+ if dtype.is_integer:
+ return gen_stateless_random_ops.stateless_random_uniform_int(
+ shape, seed=seed, minval=minval, maxval=maxval, name=name)
+ else:
+ rnd = gen_stateless_random_ops.stateless_random_uniform(
+ shape, seed=seed, dtype=dtype)
+ return math_ops.add(rnd * (maxval - minval), minval, name=name)
+
+
+def stateless_random_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values from a normal distribution.
+
+ This is a stateless version of `tf.random_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the normal
+ distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random normal values.
+ """
+ with ops.name_scope(name, "stateless_random_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_random_normal(shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_truncated_normal(shape,
+ seed,
+ mean=0.0,
+ stddev=1.0,
+ dtype=dtypes.float32,
+ name=None):
+ """Outputs deterministic pseudorandom values, truncated normally distributed.
+
+ This is a stateless version of `tf.truncated_normal`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ The generated values follow a normal distribution with specified mean and
+ standard deviation, except that values whose magnitude is more than 2 standard
+ deviations from the mean are dropped and re-picked.
+
+ Args:
+ shape: A 1-D integer Tensor or Python array. The shape of the output tensor.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ mean: A 0-D Tensor or Python value of type `dtype`. The mean of the
+ truncated normal distribution.
+ stddev: A 0-D Tensor or Python value of type `dtype`. The standard deviation
+ of the normal distribution, before truncation.
+ dtype: The type of the output.
+ name: A name for the operation (optional).
+
+ Returns:
+ A tensor of the specified shape filled with random truncated normal values.
+ """
+ with ops.name_scope(name, "stateless_truncated_normal",
+ [shape, seed, mean, stddev]) as name:
+ shape = random_ops._ShapeTensor(shape) # pylint: disable=protected-access
+ mean = ops.convert_to_tensor(mean, dtype=dtype, name="mean")
+ stddev = ops.convert_to_tensor(stddev, dtype=dtype, name="stddev")
+ rnd = gen_stateless_random_ops.stateless_truncated_normal(
+ shape, seed, dtype)
+ return math_ops.add(rnd * stddev, mean, name=name)
+
+
+def stateless_multinomial(logits,
+ num_samples,
+ seed,
+ output_dtype=dtypes.int64,
+ name=None):
+ """Draws deterministic pseudorandom samples from a multinomial distribution.
+
+ This is a stateless version of `tf.multinomial`: if run twice with the
+ same seeds, it will produce the same pseudorandom numbers. The output is
+ consistent across multiple runs on the same hardware (and between CPU
+ and GPU), but may change between versions of TensorFlow or on non-CPU/GPU
+ hardware.
+
+ Example:
+
+ ```python
+ # samples has shape [1, 5], where each value is either 0 or 1 with equal
+ # probability.
+ samples = tf.contrib.stateless.stateless_multinomial(
+ tf.log([[10., 10.]]), 5, seed=[7, 17])
+ ```
+
+ Args:
+ logits: 2-D Tensor with shape `[batch_size, num_classes]`. Each slice
+ `[i, :]` represents the unnormalized log-probabilities for all classes.
+ num_samples: 0-D. Number of independent samples to draw for each row slice.
+ seed: A shape [2] integer Tensor of seeds to the random number generator.
+ name: Optional name for the operation.
+ output_dtype: integer type to use for the output. Defaults to int64.
+
+ Returns:
+ The drawn samples of shape `[batch_size, num_samples]`.
+ """
+ with ops.name_scope(name, "stateless_multinomial", [logits, seed]):
+ logits = ops.convert_to_tensor(logits, name="logits")
+ return gen_stateless_random_ops.stateless_multinomial(
+ logits, num_samples, seed, output_dtype=output_dtype)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 9e8979bce4..5c16fcb760 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -455,7 +455,6 @@ cuda_py_tests(
"test/multi_connection_neighbor_engine_test.py",
"test/neighboring_engine_test.py",
"test/rank_two_test.py",
- "test/unary_test.py",
"test/vgg_block_nchw_test.py",
"test/vgg_block_test.py",
],
@@ -471,6 +470,25 @@ cuda_py_tests(
],
)
+cuda_py_tests(
+ name = "tf_trt_integration_test_no_oss",
+ srcs = [
+ "test/unary_test.py",
+ ],
+ additional_deps = [
+ ":tf_trt_integration_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_oss", # TODO(b/117274186): re-enable in OSS after crash fixed
+ "no_pip", # TODO(b/117274186): re-enable in OSS after crash fixed
+ "no_windows",
+ "nomac",
+ ],
+)
+
cc_library(
name = "utils",
srcs = ["convert/utils.cc"],
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 10ed1c2891..8c36d5a297 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -302,6 +302,7 @@ tf_py_test(
"//tensorflow/python:client_testlib",
":datasets",
],
+ flaky = 1, # TODO(b/117363808): fails 1/1000 OSS runs
grpc_enabled = True,
)
diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
index f88dc51636..1e66801efd 100644
--- a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
+++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto
@@ -168,6 +168,12 @@ message RunEnvironmentResult {
optional HostIndependentJobInfoResult host_independent_job_info = 5;
// Host-dependent job information.
repeated HostDependentJobInfoResult host_dependent_job_info = 6;
+ // The number of replicas, corresponds to input parallelism.
+ // If there is no model parallelism, replica_count = tpu_core_count
+ optional int32 replica_count = 7;
+ // The number of cores used for a single replica, e.g. model parallelism.
+ // If there is no model parallelism, then num_cores_per_replica = 1
+ optional int32 num_cores_per_replica = 8;
}
// The types of host operations that are tracked.
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index 8529b48c15..c2e3be03db 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -62,9 +62,9 @@ message FtrlParameters {
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
// order to get correct results; a warning will be printed otherwise (which may
-// change to an error in the future). If use_max_with_epsilon is set, the Adam
+// change to an error in the future). If use_sum_inside_sqrt is set, the Adam
// variable update formula will be changed from m / (sqrt(v) + epsilon) to
-// m / max(sqrt(v), abs(epsilon)); this option improves the performance of TPU
+// m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
float beta1 = 3;
@@ -73,7 +73,7 @@ message AdamParameters {
float initial_m = 6;
float initial_v = 7;
bool use_non_lazy_adam = 8;
- bool use_max_with_epsilon = 9;
+ bool use_sum_inside_sqrt = 10;
}
// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index a3a7fd8bb0..af183b3232 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -1998,6 +1998,9 @@ class KerasTPUModel(models.Model):
logging.info('Setting weights on TPU model.')
cloned_model.set_weights(weights)
+ if self._tpu_model.optimizer is None:
+ # tpu_model may not be compiled, e.g., loading weights and then predict.
+ return
for k, v in six.iteritems(cpu_optimizer_config):
opt_var = getattr(self._tpu_model.optimizer, k)
if isinstance(opt_var, variables.Variable):
@@ -2052,6 +2055,10 @@ class KerasTPUModel(models.Model):
self._cpu_model.set_weights(weights)
self._tpu_weights_initialized = False
+ def load_weights(self, filepath, by_name=False):
+ self._cpu_model.load_weights(filepath, by_name)
+ self._tpu_weights_initialized = False
+
# pylint: disable=bad-continuation
def _validate_shapes(model):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 3aa5b6efa1..8d15c857f8 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -177,14 +177,29 @@ def _create_or_get_iterations_per_loop():
use_resource=True)
-def _sync_variables_ops():
- # Gets the variables back from TPU nodes. This means the variables updated
- # by TPU will now be *synced* to host memory.
- return [
- array_ops.check_numerics(v.read_value(),
- 'Gradient for %s is NaN' % v.name).op
- for v in variables.trainable_variables()
- ]
+def _sync_variables_ops(ctx):
+ """Create varriables synchronization ops.
+
+ Gets the variables back from TPU nodes. This means the variables updated
+ by TPU will now be *synced* to host memory.
+ In BROADCAST mode, we skip this sync since the variables are ususally too
+ big to transmit via RPC.
+
+ Args:
+ ctx: A `_InternalTPUContext` instance with mode.
+
+ Returns:
+ A list of sync ops.
+ """
+
+ if not ctx.is_input_broadcast_with_iterators():
+ return [
+ array_ops.check_numerics(v.read_value(),
+ 'Gradient for %s is NaN' % v.name).op
+ for v in variables.trainable_variables()
+ ]
+ else:
+ return [control_flow_ops.no_op()]
def _increase_eval_step_op(iterations_per_loop):
@@ -2567,7 +2582,7 @@ class TPUEstimator(estimator_lib.Estimator):
summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)
with ops.control_dependencies([loss]):
- update_ops = _sync_variables_ops()
+ update_ops = _sync_variables_ops(ctx)
# Validate the TPU training graph to catch basic errors
_validate_tpu_training_graph()
@@ -2600,7 +2615,7 @@ class TPUEstimator(estimator_lib.Estimator):
# After TPU evaluation computation is done (the mean_loss tensor),
# reads all variables back from TPU and updates the eval step
# counter properly
- internal_ops_to_run = _sync_variables_ops()
+ internal_ops_to_run = _sync_variables_ops(ctx)
internal_ops_to_run.append(
_increase_eval_step_op(iterations_per_loop_var))
with ops.control_dependencies(internal_ops_to_run):
@@ -2645,7 +2660,7 @@ class TPUEstimator(estimator_lib.Estimator):
scaffold, prediction_hooks) = _predict_on_tpu_system(
ctx, model_fn_wrapper, dequeue_fn)
with ops.control_dependencies([dummy_predict_op]):
- internal_ops_to_run = _sync_variables_ops()
+ internal_ops_to_run = _sync_variables_ops(ctx)
with ops.control_dependencies(internal_ops_to_run):
dummy_predict_op = control_flow_ops.no_op()
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6a3ee3c1cb..900a0e11c4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1242,6 +1242,7 @@ cc_library(
srcs = [
"ops/math_grad.cc",
"ops/random_grad.cc",
+ "ops/stateless_random_grad.cc",
],
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
new file mode 100644
index 0000000000..280148e032
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyRelu.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "LeakyRelu"
+ visibility: HIDDEN
+ summary: "Computes rectified linear: `max(features, features * alpha)`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
new file mode 100644
index 0000000000..e427526602
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LeakyReluGrad.pbtxt
@@ -0,0 +1,24 @@
+op {
+ graph_op_name: "LeakyReluGrad"
+ visibility: HIDDEN
+ in_arg {
+ name: "gradients"
+ description: <<END
+The backpropagated gradients to the corresponding LeakyRelu operation.
+END
+ }
+ in_arg {
+ name: "features"
+ description: <<END
+The features passed as input to the corresponding LeakyRelu operation,
+OR the outputs of that operation (both work equivalently).
+END
+ }
+ out_arg {
+ name: "backprops"
+ description: <<END
+`gradients * (features > 0) + alpha * gradients * (featurs <= 0)`.
+END
+ }
+ summary: "Computes rectified linear gradients for a LeakyRelu operation."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
index 4433693759..d158f4b502 100644
--- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
@@ -4,16 +4,23 @@ op {
in_arg {
name: "arguments"
description: <<END
- A list of tensors whose types are Targuments, corresponding to the inputs the
- function should be mapped over.
+ A list of tensors whose types are `Targuments`, corresponding to the inputs
+ the function should be mapped over.
+END
+ }
+ in_arg {
+ name: "captured_inputs"
+ description: <<END
+ A list of tensors whose types are `Tcaptured`, corresponding to the captured
+ inputs of the defun.
END
}
out_arg {
name: "output"
description: <<END
- A list of output tensors whose types are output_types and whose dimensions 0
- are the same as the dimensions 0 of the tensors in arguments, and whose
- remaining dimensions correspond to those in output_shapes.
+ A list of output tensors whose types are `output_types` and whose dimensions
+ 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
+ remaining dimensions correspond to those in `output_shapes`.
END
}
attr {
@@ -21,6 +28,10 @@ END
description: "A list of types."
}
attr {
+ name: "Tcaptured"
+ description: "A list of types."
+ }
+ attr {
name: "output_types"
description: "A list of types."
}
@@ -29,6 +40,6 @@ END
description: "A list of shapes."
}
summary: <<END
- Maps a function on the list of tensors unpacked from inputs on dimension 0.
+ Maps a function on the list of tensors unpacked from arguments on dimension 0.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000000..b6a6dbdf54
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,46 @@
+op {
+ graph_op_name: "StatelessRandomUniformInt"
+ visibility: HIDDEN
+ in_arg {
+ name: "shape"
+ description: <<END
+The shape of the output tensor.
+END
+ }
+ in_arg {
+ name: "seed"
+ description: <<END
+2 seeds (shape [2]).
+END
+ }
+ in_arg {
+ name: "minval"
+ description: <<END
+Minimum value (inclusive, scalar).
+END
+ }
+ in_arg {
+ name: "maxval"
+ description: <<END
+Maximum value (exclusive, scalar).
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Random values with specified shape.
+END
+ }
+ attr {
+ name: "dtype"
+ description: <<END
+The type of the output.
+END
+ }
+ summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
+ description: <<END
+The generated values follow a uniform distribution in the range `[minval, maxval)`.
+
+The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 5246090ab3..fe0fcc9508 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -18,6 +18,16 @@ END
Scalar defining the number of characters to include in each substring
END
}
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is used to create the substring. One of: `"BYTE"` (for
+defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8
+encoded Unicode code points). The default is `"BYTE"`. Results are undefined if
+`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid
+UTF-8.
+END
+ }
out_arg {
name: "output"
description: <<END
diff --git a/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
new file mode 100644
index 0000000000..3d937c745c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBackBatch"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
new file mode 100644
index 0000000000..44f25b5d93
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EmptyTensorList"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
index 4778d7927c..4fb9ee56e9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "Substr"
- endpoint {
- name: "strings.substr"
- }
- endpoint {
- name: "substr"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
new file mode 100644
index 0000000000..45fc55e71e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListConcatLists"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
new file mode 100644
index 0000000000..e1ad713e7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListElementShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
new file mode 100644
index 0000000000..4aaefba3c5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListFromTensor"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..aaf607d70e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGather"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
new file mode 100644
index 0000000000..3bb5f39cbc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
new file mode 100644
index 0000000000..a04c20bb8a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListLength"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
new file mode 100644
index 0000000000..9287162f22
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPopBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
new file mode 100644
index 0000000000..da2bc11721
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
new file mode 100644
index 0000000000..77e63747d5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListReserve"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..0015189d7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListScatter"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
new file mode 100644
index 0000000000..4999ee7ad9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListSetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
new file mode 100644
index 0000000000..2dc7b2784b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListStack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 3b2dc6a050..7cb90de3c7 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -522,7 +522,6 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
InitInstanceSharedParams(
gr, cp, ir,
[this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
- DCHECK(!ir->out_mu.try_lock());
DCHECK(ir->out_mu_available);
ir->status.Update(s);
ir->out_mu.unlock();
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 419867ff58..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -473,16 +473,16 @@ bool ReplaceTensorWithConstant(
// 1) Do not replace another constant.
// 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 3) If the size of the constant in bytes is too large (>
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
+ // constraint, do not replace it.
+ // 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
- // 4) If the constant op created does not have a kernel implementation
+ // 5) If the constant op created does not have a kernel implementation
// for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) If the constant op for the device has different output memory type
- // from the original op output memory type, do not replace it.
if (tensor.first->IsConstant()) {
return false;
}
@@ -497,7 +497,8 @@ bool ReplaceTensorWithConstant(
return false;
}
bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32;
- if (memory_type == HOST_MEMORY && !is_int32) {
+ if ((memory_type == HOST_MEMORY && !is_int32) ||
+ (memory_type == DEVICE_MEMORY && is_int32)) {
return false;
}
}
@@ -535,23 +536,6 @@ bool ReplaceTensorWithConstant(
if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
return false;
}
- if (partition_device && device_type != DEVICE_CPU) {
- MemoryType original_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second,
- &original_output_memory_type)
- .ok()) {
- return false;
- }
- MemoryType const_output_memory_type;
- if (!MemoryTypeForOutput(device_type, graph, constant_node, 0,
- &const_output_memory_type)
- .ok()) {
- return false;
- }
- if (original_output_memory_type != const_output_memory_type) {
- return false;
- }
- }
for (auto edge : edges_to_remove) {
graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
graph->RemoveEdge(edge);
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index cf1cd4134e..5c8369de87 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -136,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
m->insert(*it);
}
}
+ // For any attr-value pairs that exist in the op def (from op registry) but
+ // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
+ // specify all the default attr values (e.g. for matmul, the `transpose_a`
+ // attr defaults to false).
+ const OpDef* op_def = nullptr;
+ Status s = OpDefForOp(op_name_.c_str(), &op_def);
+ // This is expected, if this op is a custom function, and is therefore not
+ // present in the op registry.
+ if (!s.ok()) return;
+
+ DCHECK(op_def);
+ for (const auto& attr_def : op_def->attr()) {
+ if (attr_def.has_default_value() && !m->count(attr_def.name())) {
+ SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
+ }
+ }
}
const NodeDef& AttrBuilder::BuildNodeDef() {
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index cbe6a1cb50..c114ea4ba0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -110,6 +110,12 @@ class AttrBuilder {
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
+ // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
+ // well as any default attr-value pairs from the associated op_def, if there
+ // is one.
+ //
+ // If `include_those_in_node_def` is true, also include any attr-value pairs
+ // from `node_def_`.
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
template <class T>
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 18420b60fd..f23cefb33d 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -70,7 +70,9 @@ EagerContext::EagerContext(const SessionOptions& opts,
async_default_(async),
log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
- use_send_tensor_rpc_(false) {
+ use_send_tensor_rpc_(false),
+ pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
+ "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", true)) {
if (device_mgr_owned) {
local_device_manager_.reset(device_mgr);
local_unowned_device_manager_ = nullptr;
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 5ed6057ec6..15eeaa8066 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -202,6 +202,7 @@ class EagerContext {
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in-turn use WorkerService.RecvTensor RPCs).
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
+ bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
private:
void InitDeviceMapAndAsync();
@@ -293,6 +294,7 @@ class EagerContext {
#endif
bool use_send_tensor_rpc_;
+ const bool pin_small_ops_to_cpu_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1bc63616d0..a52f933d75 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -579,19 +579,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
return Status::OK();
#endif
}
-} // namespace
-Status EagerExecute(EagerOperation* op,
- gtl::InlinedVector<TensorHandle*, 2>* retvals,
- int* num_retvals) {
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
+// The Op device may be updated if:
+// - A resource touching input is specified: all resource-touching ops run in
+// the device the resource is, regardless of anything else that has been
+// specified. This is identical to the graph mode behavior.
+//
+// - All op inputs are on the CPU, small (<64 elements) and integers
+// (int32/int64). This can be disabled by setting the environment variable
+// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybeUpdateOpDevice(EagerOperation* op) {
EagerContext* ctx = op->EagerContext();
+ bool device_set_for_resource_variable = false;
+ bool all_inputs_eligible_for_cpu_pinning = ctx->PinSmallOpsToCPU();
+
for (int i = 0; i < op->Inputs().size(); ++i) {
Device* input_op_device = nullptr;
- auto status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
+ TF_RETURN_IF_ERROR(op->Inputs()[i]->OpDevice(&input_op_device));
VLOG(2) << "for op " << op->Name() << " input " << i << " "
<< DataTypeString(op->Inputs()[i]->dtype) << " "
<< (input_op_device == nullptr ? "cpu" : input_op_device->name())
@@ -603,8 +607,53 @@ Status EagerExecute(EagerOperation* op,
<< d->name() << " because input #" << i
<< " is a resource in this device.";
op->SetDevice(d);
+
+ device_set_for_resource_variable = true;
+ all_inputs_eligible_for_cpu_pinning = false;
+ } else if (all_inputs_eligible_for_cpu_pinning) {
+ TensorHandle* handle = op->Inputs()[i];
+
+ // Input is on CPU.
+ if (input_op_device != nullptr && input_op_device != ctx->HostCPU()) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ if (handle->dtype != DataType::DT_INT32 &&
+ handle->dtype != DataType::DT_INT64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ int64 num_elements;
+ TF_RETURN_IF_ERROR(handle->NumElements(&num_elements));
+ if (num_elements > 64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ }
}
}
+
+ // Ops without inputs are usually ops that generate a tensor in some way and
+ // usually require being present on whatever device they are scheduled on
+ // - for e.g. VarHandleOp or _Recv).
+ // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+ // an op, but there is a GPU kernel?
+ if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
+ VLOG(1) << "Forcing op " << op->Name()
+ << " to be on the CPU since all input tensors have an "
+ "int32/int64 dtype, and are small (less than 64 elements).";
+ op->SetDevice(ctx->HostCPU());
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status EagerExecute(EagerOperation* op,
+ gtl::InlinedVector<TensorHandle*, 2>* retvals,
+ int* num_retvals) {
+ TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
+
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc
index c1542f1f57..87749da7af 100644
--- a/tensorflow/core/common_runtime/eval_const_tensor.cc
+++ b/tensorflow/core/common_runtime/eval_const_tensor.cc
@@ -113,6 +113,13 @@ Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
return Status::OK();
}
+// Returns true if 'node' has a registered CPU kernel.
+bool HasCpuKernel(const Node& node) {
+ return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
+ /*kernel_class_name=*/nullptr)
+ .ok();
+}
+
// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
// will be set to true.
@@ -136,6 +143,12 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ // Since constant-folding runs on the CPU, do not attempt to constant-fold
+ // operators that have no CPU kernel.
+ if (!HasCpuKernel(target_node)) {
+ return Status::OK();
+ }
+
// TODO(skyewm): should more of the filtering applied in input nodes below be
// applied to target_node here?
@@ -201,6 +214,11 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ if (!HasCpuKernel(*current_node)) {
+ *is_constant_graph = false;
+ return Status::OK();
+ }
+
// If there is nothing more to recurse down, see if
// the generator node is a constant.
if (current_node->num_inputs() == 0) {
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 2c48084cab..40ec1502da 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -1240,6 +1241,7 @@ class ExecutorState {
StepStatsCollectorInterface* const stats_collector_;
const tracing::TraceCollector* const trace_collector_;
const tracing::EventCollector* const event_collector_;
+ Context context_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
trace_collector_(tracing::GetTraceCollector()),
event_collector_(
tracing::GetEventCollector(tracing::EventCategory::kCompute)),
+ context_(ContextKind::kThread),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
@@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item,
}
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
+ WithContext wc(context_);
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
TaggedNodeReadyQueue inline_ready;
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index a02084f223..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -107,6 +107,8 @@ CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -117,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -140,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -178,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -218,7 +225,7 @@ Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}
diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index fa4d1eda62..9488a44778 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -288,6 +288,11 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port,
"output_port '", output_port, "' is out of range, ", "node '",
node->name(), "' has ", node->num_outputs(), " outputs");
}
+ // Note: it's possible, if the node's been updated, that the shape inference
+ // context doesn't have the right number of outputs.
+ if (node->num_outputs() > c->num_outputs()) {
+ TF_RETURN_IF_ERROR(c->ExpandOutputs(node->num_outputs()));
+ }
// Check compatibility, and merge the shapes.
ShapeHandle existing_shape = c->output(output_port);
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 3e77028a5f..4dcc80680f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -239,6 +239,15 @@ void InferenceContext::PreInputInit(
output_handle_shapes_and_types_.resize(num_outputs);
}
+Status InferenceContext::ExpandOutputs(int new_output_size) {
+ if (new_output_size < outputs_.size()) {
+ return errors::InvalidArgument("Trying to reduce number of outputs of op.");
+ }
+ outputs_.resize(new_output_size, nullptr);
+ output_handle_shapes_and_types_.resize(new_output_size);
+ return Status::OK();
+}
+
void InferenceContext::PostInputInit(
std::vector<std::unique_ptr<std::vector<ShapeAndType>>> input_handle_data) {
int num_inputs_from_node_def = 0;
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 81258b55b3..e3885b7d9e 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -323,13 +323,13 @@ class InferenceContext {
return input_tensors_as_shapes_;
}
- ShapeHandle output(int64 idx) const { return outputs_[idx]; }
- void set_output(int idx, ShapeHandle shape) { outputs_[idx] = shape; }
+ ShapeHandle output(int64 idx) const { return outputs_.at(idx); }
+ void set_output(int idx, ShapeHandle shape) { outputs_.at(idx) = shape; }
Status set_output(StringPiece output_name,
const std::vector<ShapeHandle>& shapes);
int num_outputs() const { return outputs_.size(); }
- ShapeHandle output(int idx) const { return outputs_[idx]; }
+ ShapeHandle output(int idx) const { return outputs_.at(idx); }
Status output(StringPiece output_name,
std::vector<ShapeHandle>* output) const;
@@ -645,6 +645,9 @@ class InferenceContext {
return merged_dims_;
}
+ // Adds new outputs; useful when mutating the graph.
+ Status ExpandOutputs(int new_output_size);
+
private:
// Creates and stores shapes for use in InferenceContext.
class ShapeManager {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 7a4a0096fa..6f068546d2 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -142,6 +142,19 @@ void Node::Clear() {
assigned_device_name_index_ = 0;
}
+void Node::UpdateProperties() {
+ DataTypeVector inputs;
+ DataTypeVector outputs;
+ Status status =
+ InOutTypesForNode(props_->node_def, *(props_->op_def), &inputs, &outputs);
+ if (!status.ok()) {
+ LOG(ERROR) << "Failed at updating node: " << status;
+ return;
+ }
+ props_ = std::make_shared<NodeProperties>(props_->op_def, props_->node_def,
+ inputs, outputs);
+}
+
const string& Node::name() const { return props_->node_def.name(); }
const string& Node::type_string() const { return props_->node_def.op(); }
const NodeDef& Node::def() const { return props_->node_def; }
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 2944951f82..228b1331d9 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -171,6 +171,7 @@ class Node {
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
+ UpdateProperties();
}
void ClearAttr(const string& name);
@@ -211,6 +212,10 @@ class Node {
// e.g. in AddAttr.
void MaybeCopyOnWrite();
+ // Called after an attr has changed. Decides whether we need to update some
+ // property of the node (stored in props_).
+ void UpdateProperties();
+
AttrValue* AddAttrHelper(const string& name);
// A set of mutually exclusive classes for different kinds of nodes,
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index d92874909f..68a20fcc5f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -140,10 +140,10 @@ void NodeBuilder::AddIndexError(const Node* node, int i) {
strings::StrCat("Attempt to add nullptr Node to node with type ",
def_builder_.op_def().name()));
} else {
- errors_.emplace_back(
- strings::StrCat("Attempt to add output ", i, " of ", node->name(),
- " not in range [0, ", node->num_outputs(),
- ") to node with type ", def_builder_.op_def().name()));
+ errors_.emplace_back(strings::StrCat(
+ "Attempt to add output ", i, " of ", node->name(), " not in range [0, ",
+ node->num_outputs(), ") to node with type ",
+ def_builder_.op_def().name(), ". Node: ", node->DebugString()));
}
}
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 362092a6cf..db10f586bc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
Output g = ops::Shape(s.WithOpName("g"), c);
Output h = ops::Fill(s.WithOpName("h"), g, zero);
+ Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
+ Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
ASSERT_EQ(2, shape_f.dim_size());
EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
+
+ const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
+ ASSERT_EQ(1, shape_j.dim_size());
+ EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 1b5a215987..cbf5c8e038 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -102,15 +102,19 @@ bool IsConjugateTranspose(const NodeDef& node) {
}
bool IsControlFlow(const NodeDef& node) {
- // clang-format off
- return node.op() == "ControlTrigger" ||
- node.op() == "Enter" ||
- node.op() == "Exit" ||
- node.op() == "LoopCond" ||
- node.op() == "Merge" ||
- node.op() == "NextIteration" ||
- node.op() == "Switch";
- // clang-format on
+ // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
+ // string comparison.
+ static const gtl::FlatSet<string>* const kControFlowOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "ControlTrigger",
+ "Enter",
+ "Exit",
+ "LoopCond",
+ "Merge",
+ "NextIteration",
+ "Switch",
+ }));
+ return kControFlowOps->count(node.op()) > 0;
}
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ca5d3a6dfd..3d0d95bba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices(
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
- const int rank = input_prop.shape().dim_size();
- if (rank == 0) {
+ const int input_rank = input_prop.shape().dim_size();
+ if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
+ const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
+ DataType dtype = reduction_indices_prop.dtype();
+ if (dtype != DT_INT32 && dtype != DT_INT64) {
+ return Status::OK();
+ }
+ PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
+ const int num_reduction_indices = reduction_indices_shape.num_elements();
+
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
- const bool keep_dims =
- node->attr().count("keep_dims") && node->attr().at("keep_dims").b();
const OpInfo::TensorProperties& output_prop = output_props[0];
- PartialTensorShape output_shape(output_prop.shape());
- if (output_shape.num_elements() != 1) {
- bool full_reduction = false;
+ const int output_rank =
+ output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
+
+ bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
+ if (!full_reduction) {
+ // A full reduction will generate a tensor of one of the shapes
+ // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
+ // elements in the output of the reduction, we may deduce it from reshape
+ // nodes following it.
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
- if (!IsReshape(*fanout) && !keep_dims) {
- // Depending on how it's setup, a full reduction will generate a tensor
- // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we
- // rely on the existence of a reshape node following the reduction to
- // ensure that the fanout is fed a scalar of the right shape.
+ full_reduction = false;
+ if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
@@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices(
}
}
- const OpInfo::TensorProperties& reduction_prop = input_props[1];
- DataType dtype = reduction_prop.dtype();
- if (dtype != DT_INT32 && dtype != DT_INT64) {
- return Status::OK();
- }
- // We know it's a full reduction. We can generate the set of indices to
- // reduce.
+ // We know it's a full reduction. We can generate the full set of indices to
+ // reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
- Tensor value(dtype, TensorShape({rank}));
- for (int i = 0; i < rank; ++i) {
+ Tensor value(dtype, TensorShape({input_rank}));
+ for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
@@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices(
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
+
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b09360a2c2..fab01edfed 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output input =
- ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
- Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
- Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
- Output size = ops::Const(s.WithOpName("size"), 1, {1});
- Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ for (bool use_reshape : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ // If use_reshape is false, we need to now the number of indices to apply
+ // the rewrite.
+ Output indices = ops::Placeholder(
+ s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+ if (use_reshape) {
+ Output size = ops::Const(s.WithOpName("size"), 1, {1});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch.push_back("reshape");
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back(use_reshape ? "reshape" : "sum");
- auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
- Tensor indices_t(DT_INT32, TensorShape({2}));
- indices_t.flat<int>()(0) = 0;
- indices_t.flat<int>()(1) = 1;
- auto tensors_expected = EvaluateNodes(
- item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors_expected.size());
+ auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ Tensor indices_t(DT_INT32, TensorShape({2}));
+ indices_t.flat<int>()(0) = 0;
+ indices_t.flat<int>()(1) = 1;
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors_expected.size());
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- // Run a second time to make sure the optimization is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "ConstantFolding/sum-reduction_indices") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^indices", node.input(0));
- EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
- .num_elements());
- } else if (node.name() == "sum") {
- ++found;
- EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
- } else if (node.name() == "indices") {
- ++found;
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "ConstantFolding/sum-reduction_indices") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^indices", node.input(0));
+ EXPECT_EQ(2,
+ TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "sum") {
+ ++found;
+ EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
+ } else if (node.name() == "indices") {
+ ++found;
+ }
}
+ EXPECT_EQ(3, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch,
+ {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
- EXPECT_EQ(3, found);
+}
- auto tensors = EvaluateNodes(output, item.fetch,
- {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
+TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
+ for (bool input_rank_known : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(
+ PartialTensorShape({-1, -1})))
+ : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
+ Output indices =
+ ops::Placeholder(s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(
+ PartialTensorShape({input_rank_known ? 1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("sum");
+
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ CompareGraphs(item.graph, output);
+ }
}
TEST_F(ConstantFoldingTest, LargeConstant) {
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 755af3361e..ee7c14e3ab 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -524,6 +524,7 @@ cc_library(
deps = [
":function_utils",
":graph_utils",
+ "//tensorflow/cc:ops",
"@com_google_absl//absl/strings",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 9328a7ca99..a9254ed58b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -44,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
// Function inputs and outputs are the same as original, just
// with different shapes.
*vectorized_func->mutable_signature() = orig_func.signature();
- graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
+ graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library,
vectorized_func);
// Add MapDefun node
@@ -67,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 37aa24b947..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -13,9 +13,19 @@ VECTORIZER_DEPS = [
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index 3af6bab409..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -19,13 +19,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
@@ -35,15 +35,17 @@ class CastVectorizer : public Vectorizer {
auto new_cast_node = outer_scope->AddNode(node.def(), &s);
TF_RETURN_IF_ERROR(s);
- // Add input and output mappings
- input_ports->push_back({new_cast_node, 0});
- output_ports->push_back({new_cast_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 74ce520ce1..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -19,15 +19,15 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
Status s;
- if (node.num_inputs() != 1) {
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
@@ -39,13 +39,13 @@ class UnpackVectorizer : public Vectorizer {
int new_axis = node.def().attr().at("axis").i() + 1;
new_unpack_node->AddAttr("axis", new_axis);
- // Add the input mappings
- input_ports->push_back({new_unpack_node, 0});
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
// Add the output mappings
int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- output_ports->push_back({new_unpack_node, i});
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -54,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index 56eb88c95e..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -18,15 +18,12 @@ limitations under the License.
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
-
-// Describes a tensor with its operation Node and output position
-typedef std::pair<Node*, int> Port;
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -36,17 +33,17 @@ class Vectorizer {
// Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs. The new Node(s) collectively have the
+ // on elements of `inputs`. The new Node(s) collectively have the
// same number of input and output ports as the node being converted.
- // Adds mappings for the new nodes' input and output ports to `inputs` and
- // `outputs` respectively, where the i'th Port in inputs/outputs
- // corresponds to the i'th input/output port of the node to be converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and adds
+ // mappings to the new nodes' output ports to `outputs`, where the i'th
+ // value in `outputs` corresponds to the i'th output port of the node
+ // to be converted.
virtual Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* input_ports,
- std::vector<Port>* output_ports) = 0;
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 663ceba027..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
Status Vectorize(const Node& node, Graph* outer_scope,
- std::vector<Port>* inputs,
- std::vector<Port>* outputs) override {
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -43,10 +42,10 @@ TEST(TestVectorizer, TestTestVectorizer) {
NodeDef node_def;
Status s;
Node* node = g.AddNode(node_def, &s);
- std::vector<Port> inputs, outputs;
- EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index 2d6cf562b1..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,10 +14,10 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
-#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
@@ -28,13 +28,13 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -132,7 +132,8 @@ class Vectorization {
const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Converts FunctionDefs to Graphs.
+ // Converts FunctionDefs to Graphs and adds mappings from
+ // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
Status Initialize(const FunctionDef& outer_scope,
const NodeDef& map_defun_node);
@@ -162,9 +163,30 @@ class Vectorization {
// the conversion map.
Status AddConversionMapping(Node* op_node);
- // Maps a tensor to the corresponding vectorized tensor. For example,
- // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
- std::map<TensorDesc, TensorDesc> conversion_map_;
+ // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
+ // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
+ // inputs to `map_defun_node_`. This stacked tensor will be compatible with
+ // the expected output shape of `map_defun_node_`.
+ // This is equivalent to the _stack function in python Pfor.
+ Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
+
+ // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
+ // doing a depth-first search from the ret nodes. Lifts nodes that are
+ // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly
+ // and add mappings to `conversion_map_`.
+ Status AddUnstackedNodeMappings();
+
+ // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor
+ // is unstacked.
+ bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);
+
+ // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input
+ // nodes to `conversion_map_`.
+ Status AddArgNodeMappings();
+
+ // Maps a tensor to the corresponding WrappedTensor. For example,
+ // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
+ std::map<TensorDesc, WrappedTensor> conversion_map_;
// Unconvertible ret nodes
std::set<Node*> unconvertible_;
@@ -180,6 +202,10 @@ class Vectorization {
std::unique_ptr<Graph> outer_scope_;
std::unique_ptr<FunctionBody> map_defun_fn_;
Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+
+ // Caches the loop_len_node_ needed for tiling unstacked output. This
+ // corresponds to a vector with one element.
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope`
Status status_;
};
@@ -197,34 +223,48 @@ Status Vectorization::AddConversionMapping(Node* op_node) {
return errors::Unimplemented("No vectorizer registered for op: ",
op_node->type_string());
}
- std::vector<Port> input_ports, output_ports;
- input_ports.reserve(op_node->num_inputs());
- output_ports.reserve(op_node->num_outputs());
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
- &input_ports, &output_ports));
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
std::vector<const Edge*> input_edges;
TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
- if (op_node->num_outputs() != output_ports.size() ||
- op_node->num_inputs() != input_ports.size() ||
- input_edges.size() != input_ports.size()) {
- return errors::Internal("Vectorizer inputs/outputs don't match.");
- }
-
- // Promote the inputs of the op to MapDefun outputs and connect the edges
- // accordingly.
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
for (size_t i = 0; i < op_node->num_inputs(); ++i) {
auto edge = input_edges[i];
- TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
- {edge->src(), edge->src_output()}));
- outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
- input_ports[i].first, input_ports[i].second);
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
+
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
+
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
// Add output mappings.
for (size_t i = 0; i < op_node->num_outputs(); ++i) {
- conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
return Status::OK();
@@ -239,13 +279,22 @@ Status Vectorization::ConvertOutput(int output_position) {
TensorDesc output({ret_edge->src(), ret_edge->src_output()});
TensorDesc converted_output;
- if (auto found = gtl::FindOrNull(conversion_map_, output)) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- converted_output = *found;
- } else {
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
- converted_output = conversion_map_.at(output);
+ found = &conversion_map_.at(output);
+ }
+
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile it in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
@@ -297,6 +346,7 @@ void Vectorization::VectorizeHelper() {
map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
}
+
Status Vectorization::Initialize(const FunctionDef& outer_scope,
const NodeDef& map_defun_node) {
// Convert outer_scope and map_defun_fn to FunctionBodys so we can
@@ -337,16 +387,183 @@ Status Vectorization::Initialize(const FunctionDef& outer_scope,
}
map_defun_node_ = outer_scope_->FindNodeId(node_id);
- // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to
- // the conversion map
+ TF_RETURN_IF_ERROR(AddArgNodeMappings());
+
+ TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
+ loop_len_node_ = nullptr;
+
+ return Status::OK();
+}
+
+// TODO(rachelim): It might be profitable to use the C++ API for this instead of
+// NodeBuilder
+Status Vectorization::StackTensor(WrappedTensor* unstacked,
+ TensorDesc* result) {
+ // Note that all these nodes are necessary as the size of the batch may not be
+ // constant.
+ if (unstacked->stacked) {
+ return errors::Internal("Can only stack unstacked tensor.");
+ }
+
+ Graph* g = outer_scope_.get();
+ auto node_builder = [](StringPiece op) {
+ return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
+ };
+
+ auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
+ Node** result) {
+ TF_RETURN_IF_ERROR(val.status);
+ return node_builder("Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype())
+ .Finalize(graph, result);
+ };
+
+ // If loop_len_node_ hasn't been created yet, add the node and cache it.
+ if (loop_len_node_ == nullptr) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
+
+ Node* shape_node;
+ TF_RETURN_IF_ERROR(
+ node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
+
+ Node* const_vec_0;
+ TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
+ Node* const_vec_1;
+ TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
+
+ Node* strided_slice_node;
+ TF_RETURN_IF_ERROR(node_builder("StridedSlice")
+ .Input(shape_node) // input
+ .Input(const_vec_0) // begin
+ .Input(const_vec_1) // end
+ .Input(const_vec_1) // strides
+ .Finalize(g, &strided_slice_node));
+
+ // Produces a vector of length 1
+ TF_RETURN_IF_ERROR(node_builder("Reshape")
+ .Input(strided_slice_node) // tensor
+ .Input(const_vec_1) // shape
+ .Finalize(g, &loop_len_node_));
+ }
+
+ Node* ones_shape;
+ TF_RETURN_IF_ERROR(node_builder("Shape")
+ .Input(unstacked->node) // input
+ .Finalize(g, &ones_shape));
+
+ Node* ones;
+ TF_RETURN_IF_ERROR(
+ node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
+
+ Node* const_0;
+ TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
+
+ Node* multiples;
+ TF_RETURN_IF_ERROR(node_builder("Concat")
+ .Input(const_0) // concat_dim
+ .Input({{loop_len_node_, 0}, {ones, 0}}) // values
+ .Finalize(g, &multiples));
+
+ Node* expand_dims;
+ TF_RETURN_IF_ERROR(node_builder("ExpandDims")
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
+ .Finalize(g, &expand_dims));
+
+ TF_RETURN_IF_ERROR(node_builder("Tile")
+ .Input(expand_dims) // input
+ .Input(multiples) // multiples
+ .Finalize(g, &result->first));
+ result->second = 0;
+ return Status::OK();
+}
+
+Status Vectorization::AddArgNodeMappings() {
for (auto arg_node : map_defun_fn_->arg_nodes) {
Node* input_node;
TF_RETURN_IF_ERROR(map_defun_node_->input_node(
arg_node->attrs().Find("index")->i(), &input_node));
- conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
+
+ // Control inputs
+ conversion_map_.insert({{arg_node, Graph::kControlSlot},
+ {input_node, Graph::kControlSlot, true}});
+ }
+ return Status::OK();
+}
+
+bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
+ Status* status) {
+ if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
+ return !found->stacked;
+ }
+
+ if (tensor.first->op_def().is_stateful()) {
+ // We don't lift stateful nodes directly out of the MapDefun, since they may
+ // have to be executed N times.
+ return false;
}
+ bool is_unstacked = true;
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ // A node is unstacked if all of its inputs are unstacked
+ is_unstacked &= AddUnstackedNodeMappingsHelper(
+ {edge->src(), edge->src_output()}, status);
+ }
+
+ if (!is_unstacked) {
+ return false;
+ }
+
+ // If the node is unstacked, we copy it into outer_scope_ and
+ // add it to the map. Note that we don't clean up the nodes that are copied
+ // in map_defun_fn_, and rely on them being pruned out later.
+ Node* node = outer_scope_->AddNode(tensor.first->def(), status);
+ if (!status->ok()) return true;
+
+ // Add input edges to nodes that should already have been lifted.
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ outer_scope_->AddEdge(found->node, found->output_index, node,
+ edge->dst_input());
+ } else {
+ status->Update(errors::Internal(
+ "Could not find input conversion even though we did depth first "
+ "conversion."));
+ }
+ }
+
+ // Add output mappings
+ for (int i = 0; i < tensor.first->num_outputs(); ++i) {
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
+ }
+ conversion_map_.insert({{tensor.first, Graph::kControlSlot},
+ WrappedTensor(node, Graph::kControlSlot, false)});
+
+ return true;
+}
+
+Status Vectorization::AddUnstackedNodeMappings() {
+ SetVector<Node*> unstacked_nodes;
+ Status s;
+ for (const auto& ret_node : map_defun_fn_->ret_nodes) {
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
+ AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
+ TF_RETURN_IF_ERROR(s);
+ }
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index 1ff62217dd..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -55,6 +55,7 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
@@ -142,6 +143,8 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
*lib.add_function() = outer;
*lib.add_function() = inner;
FunctionDef* vectorized;
+ Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized);
+ LOG(ERROR) << s;
EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
EXPECT_TRUE(
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
@@ -670,6 +673,257 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
cast_node.input(1) == control_input);
}
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeConst) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Const:output:0"}});
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Const", *vectorized));
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node.name());
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | +------+ +------+ | |
+// | | |Const | |Const | | |
+// | | +---+--+ +---+--+ | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | |
+// | +------+ |
+// | +------+ |Const | |
+// | |Const | +---+--+ |
+// | +---+--+ | |
+// | : +---v--+ |
+// | ::::::> Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +Stack*+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ FunctionDefHelper::Const("ConstDep", 3)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64,
+ false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto find_const = [vectorized](int val) -> const NodeDef* {
+ for (const auto& n : vectorized->node_def()) {
+ if (n.attr().at("value").tensor().int_val(0) == val) {
+ return &n;
+ }
+ }
+ return nullptr;
+ };
+
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = find_const(2);
+ auto const_dep_node = find_const(3);
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node->name());
+ EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
+}
+
// TODO(rachelim): More test cases when we get around to implementing them:
// [] A badly defined converter, e.g. doesn't produce nodes that have the
// same number of outputs/inputs as the nodes to be converted
diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h
index 765dd13263..bd6bf9f860 100644
--- a/tensorflow/core/grappler/optimizers/graph_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h
@@ -16,8 +16,11 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
+#include <atomic>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
namespace tensorflow {
namespace grappler {
@@ -29,6 +32,7 @@ struct GrapplerItem;
// optimization of a GrapplerItem for running on a cluster.
class GraphOptimizer {
public:
+ GraphOptimizer() : is_cancelled_(false) {}
virtual ~GraphOptimizer() {}
virtual string name() const = 0;
@@ -45,8 +49,25 @@ class GraphOptimizer {
// call to Optimize) performed. Lower "result" scores are better.
virtual void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) = 0;
+
+ // Best effort cancellation. Sets is_cancelled to true and requests that the
+ // optimizer returns as soon as possible from active calls to Optimize() or
+ // FeedBack().
+ void Cancel() { is_cancelled_ = true; }
+
+ bool is_cancelled() const { return is_cancelled_; }
+
+ private:
+ std::atomic<bool> is_cancelled_;
};
+#define GRAPPLER_RETURN_IF_CANCELLED() \
+ do { \
+ if (is_cancelled()) { \
+ return errors::DeadlineExceeded(this->name(), " was cancelled."); \
+ } \
+ } while (0)
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c3d70a1fdf..7488cedec5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+#include <memory>
+
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
@@ -37,7 +40,11 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -107,13 +114,29 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
- MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
+ MK_OPT("pin_to_host",
+ new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
#undef MK_OPT
+MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
+ : cpu_device_(cpu_device), cfg_(cfg) {
+ // TODO(rmlarsen): Increase kNumThreads to, say, port::NumSchedulableCPUs()
+ // if we want to the threadpool for parallelizing Grappler
+ const int kNumThreads = 1;
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(
+ Env::Default(), "MetaOptimizerThreadPool", kNumThreads);
+}
+
+MetaOptimizer::~MetaOptimizer() {
+ // The ThreadPool destructor waits for threads to finish, so we don't
+ // pull the rug out from under them.
+ thread_pool_.reset();
+}
+
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
if (cfg_.disable_meta_optimizer()) {
@@ -139,7 +162,7 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
- if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
@@ -309,6 +332,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
VLOG(4) << "Starting optimization iteration " << iteration;
for (const auto& optimizer : optimizers) {
+ GRAPPLER_RETURN_IF_CANCELLED();
// Some optimizers can run only once.
if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
// Some must run only on the last iteration.
@@ -367,6 +391,7 @@ Status MetaOptimizer::RunOptimizer(
// resets optimized_graph to an empty graph.
optimized_graph->Swap(&optimized_item->graph);
*optimized_graph = GraphDef();
+ // TODO(rmlarsen): Add timeout for individual optimizers.
Status status =
optimizer->Optimize(cluster, *optimized_item, optimized_graph);
uint64 end_us = Env::Default()->NowMicros();
@@ -388,14 +413,15 @@ Status MetaOptimizer::RunOptimizer(
return status;
}
-Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* optimized_graph) {
+Status MetaOptimizer::OptimizeMainGraphAndFunctionLibrary(
+ Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) {
VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
VLOG(1) << "Optimized main graph.";
+ GRAPPLER_RETURN_IF_CANCELLED();
// Skip optimizing functions if this is a TPU graph. Currently, Grappler
// passes do not handle TPU functions correctly in a variety of ways (Note
@@ -431,6 +457,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimize_function_library = false;
for (const FunctionDef& func : optimized_graph->library().function()) {
+ GRAPPLER_RETURN_IF_CANCELLED();
+
const string& func_name = func.signature().name();
// Skip already optimized functions.
@@ -505,6 +533,43 @@ void MetaOptimizer::PrintResult() {
}
}
+Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ const int64 kFiveMinutesInUsec = 5 * 60 * 1000 * 1000;
+ const int64 timeout_usec = (cfg_.meta_optimizer_timeout_ms() == 0
+ ? kFiveMinutesInUsec
+ : cfg_.meta_optimizer_timeout_ms() * 1000);
+ if (timeout_usec < 0) {
+ return OptimizeMainGraphAndFunctionLibrary(cluster, item, optimized_graph);
+ }
+
+ GraphDef optimized_with_timeout;
+ Status status;
+ Notification done;
+ thread_pool_->Schedule(
+ [this, cluster, &done, &optimized_with_timeout, &item, &status]() {
+ status = this->OptimizeMainGraphAndFunctionLibrary(
+ cluster, item, &optimized_with_timeout);
+ done.Notify();
+ });
+
+ const bool notified = WaitForNotificationWithTimeout(&done, timeout_usec);
+ if (notified && status.ok()) {
+ optimized_graph->Swap(&optimized_with_timeout);
+ } else {
+ *optimized_graph = item.graph;
+ if (!notified) {
+ this->Cancel();
+ done.WaitForNotification();
+ status = errors::DeadlineExceeded(
+ "Grappler MetaOptimizer timed out after ",
+ static_cast<float>(timeout_usec) / (1000 * 1000), " seconds");
+ LOG(WARNING) << status.error_message();
+ }
+ }
+ return status;
+}
+
void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& pruned_graph, double result) {
// Nothing to do for MetaOptimizer.
@@ -527,7 +592,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
- cfg.pin_to_host_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 99a0a33ffa..35d6a4559b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -28,9 +29,8 @@ namespace grappler {
// Run the other grappler optimizers based on the specified rewriter config.
class MetaOptimizer : public GraphOptimizer {
public:
- MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg)
- : cpu_device_(cpu_device), cfg_(cfg) {}
- ~MetaOptimizer() override = default;
+ MetaOptimizer(DeviceBase* cpu_device, const RewriterConfig& cfg);
+ ~MetaOptimizer();
string name() const override { return "meta_optimizer"; };
@@ -65,9 +65,18 @@ class MetaOptimizer : public GraphOptimizer {
Status OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph);
+ // Run optimization passes over the main graph and for functions in the
+ // function library.
+ Status OptimizeMainGraphAndFunctionLibrary(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph);
+
DeviceBase* const cpu_device_; // may be NULL
RewriterConfig cfg_;
+ // Thread pool used for launching optimizers asynchronously.
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+
struct OptimizerResult {
string optimizer_name;
string result;
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 3f3f43382f..7f1dd91f09 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -461,6 +461,68 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
}
+class SleepingOptimizer : public CustomGraphOptimizer {
+ public:
+ SleepingOptimizer() {}
+ string name() const override { return "test_optimizer"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ optimized_graph->add_node();
+ sleep(1);
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+};
+
+REGISTER_GRAPH_OPTIMIZER(SleepingOptimizer);
+
+TEST_F(MetaOptimizerTest, OptimizerTimesOut) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("SleepingOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+ rewriter_config.set_meta_optimizer_timeout_ms(1500);
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ EXPECT_EQ(status.error_message(),
+ "Grappler MetaOptimizer timed out after 1.5 seconds");
+ // Make sure the graph was reverted to the original regardless of when the
+ // optimizer timed out.
+ CompareGraphs(item.graph, output);
+}
+
+TEST_F(MetaOptimizerTest, OptimizerDoesNotTimeOut) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("SleepingOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+ rewriter_config.set_meta_optimizer_timeout_ms(1500);
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_EQ(item.graph.node_size() + 1, output.node_size());
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 8ed4271fa4..29a3b2b74c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,16 +25,29 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
+
namespace internal {
+namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
+struct OpDevicePortHasher {
+ std::size_t operator()(const std::tuple<string, string, int>& x) const {
+ uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
+
+ return Hash64Combine(code, hash<int>()(std::get<2>(x)));
+ }
+};
+using OpDevicePortOnHostMap =
+ gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
+
// All the nodes that should be blacklisted and not swapped.
bool IsBlacklisted(const NodeDef& node) {
return
@@ -82,10 +95,10 @@ Status TryFindKernelDef(const std::vector<DeviceType>& devices,
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
-Status IsNodeOutputPortHostFriendly(const GraphView& graph,
- GraphProperties* properties,
- const NodeDef& node, int port_id,
- bool* is_candidate) {
+Status IsNodeOutputPortHostFriendly(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
// Make sure we are not a blacklisted op.
@@ -117,7 +130,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
for (const auto& fanin : graph.GetFanins(node, false)) {
bool fanin_candidate = false;
TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
if (!fanin_candidate) {
return Status::OK();
}
@@ -132,11 +146,22 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
return Status::OK();
}
+ // Check `op_device_outport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ const std::tuple<string, string, int> cache_key(node.op(), node.device(),
+ port_id);
+ auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_outport_pinned_to_host_cache->end()) {
+ *is_candidate = it->second;
+ return Status::OK();
+ }
+
// Check if op's output port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -146,6 +171,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
LOG(WARNING) << "Invalid port: " << port_id << "!\n"
<< node.DebugString() << "\n"
<< op->DebugString();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -155,6 +181,7 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
&kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
return Status::OK();
}
@@ -166,22 +193,35 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
}
}
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
+
return Status::OK();
}
// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
-bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+bool IsNodeInputPortHostFriendly(
+ const NodeDef& node, int port_id,
+ OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
// If node is on Host, assume its inputs are Host friendly.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
return true;
}
+ // Check `op_device_inport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
+ auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_inport_pinned_to_host_cache->end()) {
+ return it->second;
+ }
+
// Check if op's input port is pinned to HostMemory.
const OpDef* op = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
if (!s.ok()) {
LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
@@ -192,16 +232,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
{node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
if (!s.ok()) {
LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
// Check if the input_arg is pinned to Host.
for (const string& host_memory_arg : kernel->host_memory_arg()) {
if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
return true;
}
}
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+
return false;
}
@@ -211,18 +255,20 @@ bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
// ints, and pinned to Host).
-Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
- const NodeDef& node, bool* is_candidate) {
+Status IsNodeHostCandidate(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
*is_candidate = false;
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- *is_candidate = true;
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
return Status::OK();
}
- // Skip these node types.
- if (IsBlacklisted(node)) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
return Status::OK();
}
@@ -232,17 +278,6 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
return Status::OK();
}
- // Check all inputs are Host friendly.
- for (const GraphView::OutputPort& fanin :
- graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
- bool fanin_candidate = false;
- TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
- graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
- if (!fanin_candidate) {
- return Status::OK();
- }
- }
-
// Check all outputs are Host friendly.
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
@@ -255,16 +290,42 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
}
}
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+
*is_candidate = true;
return Status::OK();
}
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device) {
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace
+
+// Tries to swap `device` to a Host device from `devices`. Returns true iff
+// there was a swap.
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device) {
// Force this node onto the CPU.
- if (device.empty() && has_device_cpu) {
- return "/device:CPU:0";
- } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ if (device->empty() && has_device_cpu) {
+ *device = "/device:CPU:0";
+ return true;
+ } else if (str_util::StrContains(*device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -272,27 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ strings::StrCat(device->substr(0, device->rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- return device_host;
+ *device = device_host;
+ return true;
}
}
}
- // We couldn't find an appropriate Host device, return original device.
- return device;
-}
-
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
+ // We couldn't find an appropriate Host device, return false.
return false;
}
+
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -324,20 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
+ // Cache to map {op, device, port} -> bool on whether it is pinned to host.
+ internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
+ internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
+
for (auto& node : *optimized_graph->mutable_node()) {
bool is_candidate = false;
- TF_RETURN_IF_ERROR(
- internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+ TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
+ graph, &properties, node, &op_device_outport_pinned_to_host_cache,
+ &is_candidate));
if (!is_candidate) {
continue;
}
- if (IsConstant(node)) {
- const_nodes.emplace_back(&node, node.device());
+ const string original_device = node.device();
+ const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
+ node.mutable_device());
+ // Keep track of all Const nodes that we swapped.
+ if (swapped && IsConstant(node)) {
+ const_nodes.emplace_back(&node, original_device);
}
- // Try and swap the device to Host.
- node.set_device(
- internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -349,8 +408,9 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
// The consumer is not Host friendly, swap it back to the original device.
- if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
- fanout.port_id)) {
+ if (!internal::IsNodeInputPortHostFriendly(
+ *fanout.node, fanout.port_id,
+ &op_device_inport_pinned_to_host_cache)) {
node->set_device(device);
break;
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index d557a03463..bed4a9ef95 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device);
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 7c64529441..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 9439ab332c..3a920f26f3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4458,7 +4458,12 @@ cc_library(
name = "string_util",
srcs = ["string_util.cc"],
hdrs = ["string_util.h"],
- deps = ["//tensorflow/core:lib"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@icu//:common",
+ ],
)
STRING_DEPS = [
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 451f8c1a6c..37c1c54786 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -45,6 +45,16 @@ cc_library(
],
)
+tf_cc_test(
+ name = "dataset_utils_test",
+ srcs = ["dataset_utils_test.cc"],
+ deps = [
+ ":dataset_utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "captured_function",
srcs = ["captured_function.cc"],
@@ -205,6 +215,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -232,6 +243,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -245,6 +257,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -285,6 +298,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
":parallel_map_iterator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index a04f150e71..9607e9444c 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
- return output_tensorshape;
+ return PartialTensorShape();
+ PartialTensorShape output_tensorshape({});
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
+ output_tensorshape.AddDim(dims1[d]);
else
- output_tensorshape.Concatenate(-1);
+ output_tensorshape.AddDim(-1);
}
return output_tensorshape;
}
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e10833f525..a40f7f2146 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -15,10 +15,57 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace data {
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices) {
+ FunctionLibraryRuntime::Handle fn_handle;
+ TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
+ func.name(), AttrSlice(&func.attr()), &fn_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
+ Status s = ctx->function_library()->ReleaseHandle(fn_handle);
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to release handle: " << s.error_message();
+ }
+ });
+
+ const FunctionBody* fn_body =
+ ctx->function_library()->GetFunctionBody(fn_handle);
+ indices->resize(fn_body->ret_nodes.size());
+ for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
+ Node* ret_node = fn_body->ret_nodes[i];
+ Node* ret_input_node;
+ TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
+ if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
+ } else {
+ indices->clear();
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
+ std::map<int, int> last_use;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ last_use[indices[i]] = i;
+ }
+ std::vector<bool> can_move;
+ can_move.resize(indices.size());
+ for (size_t i = 0; i < indices.size(); ++i) {
+ can_move[i] = last_use[indices[i]] == i;
+ }
+ return can_move;
+}
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6ec1350cd4..d777062293 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -22,6 +22,26 @@ limitations under the License.
namespace tensorflow {
namespace data {
+// This method is used to determine whether we can short-circuit the evaluation
+// of the user-defined function `func`. Short-circuting is possible if every
+// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) =
+// (y,x)`, or `f(x) = (x,x)`).
+//
+// If short-circuiting is possible, the method stores the mapping from output
+// indices to input indices in `indices`. Otherwise, `indices` will be empty.
+//
+// Returns non-ok status if analysis of the function fails.
+//
+// TODO(jsimsa): Extend this to support constants as well.
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices);
+
+// Given a vector that maps output indices to input indices, return a vector
+// that identifies for which output indices can we move the input (assuming
+// output indices are processed left to right).
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices);
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
new file mode 100644
index 0000000000..43295b8ebb
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_utils_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+TEST(DatasetUtils, ComputeMoveVector) {
+ struct TestCase {
+ std::vector<int> indices;
+ std::vector<bool> expected;
+ };
+
+ TestCase test_cases[] = {
+ TestCase{{}, {}},
+ TestCase{{1}, {true}},
+ TestCase{{1, 1}, {false, true}},
+ TestCase{{1, 2}, {true, true}},
+ TestCase{{1, 1, 2}, {false, true, true}},
+ TestCase{{1, 2, 2}, {true, false, true}},
+ };
+
+ for (auto& test_case : test_cases) {
+ EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices));
+ }
+}
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 00884314a9..be7d182a1f 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -31,67 +33,84 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
+ using FilterIteratorPredicate =
+ std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>;
+
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- FunctionLibraryRuntime::Handle pred_handle;
- OP_REQUIRES_OK(ctx,
- ctx->function_library()->Instantiate(
- func_.name(), AttrSlice(&func_.attr()), &pred_handle));
- auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
- OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
- });
-
- const FunctionBody* pred_body =
- ctx->function_library()->GetFunctionBody(pred_handle);
- OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
- errors::InvalidArgument(
- "predicate function must have a single return value."));
- Node* ret_node = pred_body->ret_nodes[0];
- Node* ret_input_node;
- OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
-
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- if (ret_input_node->def().op() == "_Arg") {
- int32 index = -1;
- OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
- *output = new FilterTensorDataset(ctx, input, func_,
- std::move(captured_func), index);
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+ OP_REQUIRES(ctx, indices.size() <= 1,
+ errors::InvalidArgument(
+ "predicate function has more than one return value."));
+
+ FilterIteratorPredicate filter_pred;
+ if (indices.empty()) {
+ CapturedFunction* raw_captured_func = captured_func.get();
+ filter_pred = [raw_captured_func](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ };
} else {
- *output = new FilterFunctionDataset(ctx, input, func_,
- std::move(captured_func));
+ filter_pred = [indices](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ const Tensor& predicate = args[indices[0]];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ };
}
+
+ *output = new Dataset(ctx, input, func_, std::move(captured_func),
+ std::move(filter_pred));
}
private:
- const int graph_def_version_;
-
- class FilterDatasetBase : public DatasetBase {
+ class Dataset : public DatasetBase {
public:
- FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ FilterIteratorPredicate filter_pred)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ filter_pred_(std::move(filter_pred)) {
input_->Ref();
}
- ~FilterDatasetBase() override { input_->Unref(); }
+ ~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Filter")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
+ filter_pred_);
}
const DataTypeVector& output_dtypes() const override {
@@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- virtual Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const = 0;
-
private:
- class Iterator : public DatasetIterator<FilterDatasetBase> {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
+ explicit Iterator(const Params& params,
+ FilterIteratorPredicate filter_pred)
+ : DatasetIterator<Dataset>(params),
filtered_elements_(0),
- dropped_elements_(0) {
+ dropped_elements_(0),
+ filter_pred_(std::move(filter_pred)) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(
- dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
+ const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
-
- protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- };
-
- class FilterFunctionDataset : public FilterDatasetBase {
- public:
- using FilterDatasetBase::FilterDatasetBase;
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- captured_func_->RunWithBorrowedArgs(ctx, element, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- }
- };
-
- class FilterTensorDataset : public FilterDatasetBase {
- public:
- FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- int32 index)
- : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
- index_(index) {}
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- const Tensor& predicate = element[index_];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- }
-
- private:
- const int32 index_;
+ const FilterIteratorPredicate filter_pred_;
};
private:
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index bf08970560..f45a239793 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -41,6 +43,10 @@ namespace {
// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapAndBatchIteratorFunction =
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
+ std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;
+
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
@@ -91,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
- drop_remainder, output_types_, output_shapes_, func_,
- std::move(captured_func), &ctx->eigen_cpu_device());
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapAndBatchIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
+ std::move(done), prefix);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
+ *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_,
+ std::move(captured_func), &ctx->eigen_cpu_device(),
+ std::move(map_func));
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
- const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
- const Eigen::ThreadPoolDevice* device)
+ const Eigen::ThreadPoolDevice* device,
+ MapAndBatchIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ func_(func),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
- map_fn_(func),
captured_func_(std::move(captured_func)),
- device_(device) {
+ device_(device),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -123,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")},
+ map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -143,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -165,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
- b->BuildAttrValue(map_fn_, &f);
+ b->BuildAttrValue(func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
@@ -185,12 +234,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
+ explicit Iterator(const Params& params,
+ MapAndBatchIteratorFunction map_func)
: DatasetIterator<Dataset>(params),
mu_(std::make_shared<mutex>()),
cond_var_(std::make_shared<condition_variable>()),
num_parallel_calls_(std::make_shared<model::SharedState>(
- params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
+ map_func_(std::move(map_func)) {}
~Iterator() override {
mutex_lock l(*mu_);
@@ -297,44 +348,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls; // access guarded by owner's mutex
};
- void Callback(const std::shared_ptr<IteratorContext>& ctx,
- const std::shared_ptr<BatchResult>& result,
- const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
- result->UpdateStatus(status);
- if (status.ok()) {
- EnsureOutputAllocated(ctx, result, return_values);
- for (size_t i = 0; i < return_values->size(); ++i) {
- const Tensor& tensor = return_values->at(i);
- Tensor* batch = &(result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does not "
- "match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that allows us to
- // move `tensor` where possible, to speed up string tensor batching.
- Status copy_status = ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
- break;
- }
- }
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
- }
- CallCompleted(result);
- }
-
void CallCompleted(const std::shared_ptr<BatchResult>& result)
LOCKS_EXCLUDED(*mu_) {
mutex_lock l(*mu_);
@@ -363,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- // Call `captured_func_(input_element)`, using `Callback` to store the
- // result in `result`.
- (*ctx->runner())(std::bind(
- [this, result, offset](std::shared_ptr<IteratorContext> ctx,
- std::vector<Tensor> input_element) {
- std::shared_ptr<std::vector<Tensor>> return_values(
- new std::vector<Tensor>());
- dataset()->captured_func_->RunAsync(
- ctx.get(), std::move(input_element), return_values.get(),
- [this, ctx, result, return_values, offset](Status status) {
- Callback(ctx, result, return_values, offset, status);
- },
- prefix());
- },
- ctx, std::move(input_element)));
+ std::shared_ptr<std::vector<Tensor>> return_values =
+ std::make_shared<std::vector<Tensor>>();
+ auto done = [this, ctx, result, return_values, offset](Status status) {
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor
+ // batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ }
+ CallCompleted(result);
+ };
+
+ // Apply the map function on `input_element`, storing the result in
+ // `return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ std::move(return_values), std::move(done));
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
@@ -404,7 +444,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
@@ -509,8 +549,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.emplace_back(
- new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
}
int64 offset = call_counter_++ % dataset()->batch_size_;
new_calls.emplace_back(batch_results_.back(), offset);
@@ -527,7 +567,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
- batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -653,6 +694,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
const std::shared_ptr<model::SharedState> num_parallel_calls_;
+ const MapAndBatchIteratorFunction map_func_;
+
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
@@ -671,9 +714,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
- const NameAttrList map_fn_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
+ const MapAndBatchIteratorFunction map_func_;
};
const int op_version_;
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index f112e1dc43..6b6ffabf4f 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -28,6 +30,9 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapIteratorFunction = std::function<Status(
+ IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
+
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ return raw_captured_func->Run(ctx, std::move(args), out_tensors);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ return Status::OK();
+ };
+ }
+
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_);
+ output_types_, output_shapes_, std::move(map_func));
}
private:
@@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
+ const std::vector<PartialTensorShape>& output_shapes,
+ MapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes) {
+ output_shapes_(output_shapes),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Map")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ explicit Iterator(const Params& params, MapIteratorFunction map_func)
+ : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- Status s =
- dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
+ Status s = map_func_(ctx, args, out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
+ const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 6657f2b2b3..705b0393de 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
- Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
- // Validates inputs and gets the size of their leading dimension.
- *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
- return errors::InvalidArgument(
- "All inputs must have rank at least 1. Input ", i,
- " has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != *batch_size) {
- return errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size);
- }
- }
- return Status::OK();
- }
-
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ComputeOptions* compute_opts = nullptr;
@@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel {
// all calls to the function are complete. This struct also encapsulates
// all the components that need to be passed to each MapFunctionCallFrame.
- const std::vector<Tensor> args;
+ OpInputList args;
const std::vector<TensorShape> arg_shapes;
+ OpInputList captured_inputs;
const int64 batch_size;
// Output of a compute call
@@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel {
// Create a copy of output_shapes because every `Compute` may expect a
// different output shape.
- ComputeOptions(std::vector<Tensor> args,
+ ComputeOptions(OpInputList args, OpInputList captured_inputs,
std::vector<TensorShape> arg_shapes, int64 batch_size,
const std::vector<PartialTensorShape>& output_shapes_attr)
- : args(std::move(args)),
+ : args(args),
arg_shapes(std::move(arg_shapes)),
+ captured_inputs(captured_inputs),
batch_size(batch_size),
output_shapes(output_shapes_attr) {}
};
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
- int64 batch_size =
- ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ OpInputList arguments;
+ TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
+ OpInputList captured_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
+
+ int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ if (arguments[i].dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != batch_size) {
+ } else if (arguments[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
@@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel {
}
}
- std::vector<Tensor> args;
std::vector<TensorShape> arg_shapes;
- args.reserve(ctx->num_inputs());
- arg_shapes.reserve(ctx->num_inputs());
+ arg_shapes.reserve(arguments.size());
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args.push_back(ctx->input(i));
- arg_shapes.push_back(ctx->input(i).shape());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ arg_shapes.push_back(arguments[i].shape());
arg_shapes.at(i).RemoveDim(0);
}
- *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
- batch_size, output_shapes_);
+ *compute_opts =
+ new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
+ batch_size, output_shapes_);
return Status::OK();
}
@@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel {
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= compute_opts_->args.size()) {
+ if (index < 0 || index >= compute_opts_->args.size() +
+ compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
+
+ if (index >= compute_opts_->args.size()) {
+ // The function is calling for a captured input
+ *val =
+ compute_opts_->captured_inputs[index - compute_opts_->args.size()];
+ return Status::OK();
+ }
+
bool result =
- val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
@@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
-
return Status::OK();
}
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 6abe6c8338..3a14924fba 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ ParallelMapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors,
+ std::move(done), prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args),
+ out_tensors, std::move(done)));
+ };
+ }
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args, std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
- std::move(captured_func));
+ std::move(captured_func), std::move(map_func));
}
private:
@@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction> captured_func)
+ std::unique_ptr<CapturedFunction> captured_func,
+ ParallelMapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
@@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes),
use_inter_op_parallelism_(use_inter_op_parallelism),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
- ParallelMapIteratorFunction map_func =
- [this, new_prefix](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done), new_prefix);
- };
- if (!use_inter_op_parallelism_) {
- map_func = [map_func](
- IteratorContext* ctx, std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
- result, std::move(done)));
- };
- }
-
- return NewParallelMapIterator({this, new_prefix}, input_,
- std::move(init_func), std::move(map_func),
- num_parallel_calls_);
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
+ std::move(init_func), map_func_, num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
+ const ParallelMapIteratorFunction map_func_;
};
DataTypeVector output_types_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 13bd4b6036..ebf41925c9 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -179,7 +180,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
void EnsureRunnerThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
@@ -208,15 +209,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in `result->return_values`,
- // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- map_func_(ctx.get(), std::move(input_element), &result->return_values,
- std::move(done));
+ // Apply the map function on `input_element`, storing the result in
+ // `result->return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ &result->return_values, std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -349,9 +350,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(
- new ParallelMapIterator(params, input_dataset, std::move(init_func),
- std::move(map_func), num_parallel_calls));
+ return MakeUnique<ParallelMapIterator>(
+ params, input_dataset, std::move(init_func), std::move(map_func),
+ num_parallel_calls);
}
} // namespace data
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index dc26c5cf25..813f13c9e4 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -30,7 +30,7 @@ namespace data {
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
using ParallelMapIteratorFunction =
- std::function<void(IteratorContext*, std::vector<Tensor>,
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index 1d1a717062..7de5ea8860 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_fn = [this](IteratorContext* ctx,
+ auto map_fn = [this](IteratorContext* ctx, const string& prefix,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
(*ctx->runner())([this, ctx, input_element, result, done]() {
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 42fbf95cd3..28940e0849 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -96,8 +96,6 @@ class DequantizeOp : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and
- // QuantizeAndDequantizeV3 to match this SCALED mode again.
const float scale_factor =
std::numeric_limits<T>::min() == 0
? (max_range / std::numeric_limits<T>::max())
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index fdb4c84c46..3979e4b53a 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -97,6 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ OP_REQUIRES_ASYNC(
+ ctx, args.size() == fbody->arg_nodes.size(),
+ errors::InvalidArgument(
+ "Wrong number of arguments to the op; function expects ",
+ fbody->arg_nodes.size(), " but PartitionedCall received ",
+ args.size()),
+ done);
// We need to pass global op_registry as default_registry when creating
// graph. So that graph optimization passes can lookup all possible ops
// by name.
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 04a53697c0..3810d817ca 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<int32>("T") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
TF_CALL_half(REGISTER);
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index 173fea37ed..e67695d54a 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -33,19 +33,25 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
-#define REGISTER_RELU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- ReluGradOp<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6Op<CPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
- Relu6GradOp<CPUDevice, type>)
+#define REGISTER_RELU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ ReluGradOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6Op<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<CPUDevice, type>) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_RELU_KERNELS);
#undef REGISTER_RELU_KERNELS
@@ -99,6 +105,19 @@ namespace functor {
extern template struct Relu6Grad<GPUDevice, T>; \
\
template <> \
+ void LeakyRelu<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor activations); \
+ extern template struct LeakyRelu<GPUDevice, T>; \
+ \
+ template <> \
+ void LeakyReluGrad<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T>::ConstTensor gradients, \
+ typename TTypes<T>::ConstTensor features, T alpha, \
+ typename TTypes<T>::Tensor backprops); \
+ extern template struct LeakyReluGrad<GPUDevice, T>; \
+ \
+ template <> \
void Elu<GPUDevice, T>::operator()(const GPUDevice& d, \
typename TTypes<T>::ConstTensor features, \
typename TTypes<T>::Tensor activations); \
@@ -134,30 +153,36 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
// Registration of the GPU implementations.
-#define REGISTER_GPU_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- ReluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6Op<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- Relu6GradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- EluGradOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
- SeluOp<GPUDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ ReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6Op<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ Relu6GradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EluGradOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ SeluOp<GPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
SeluGradOp<GPUDevice, type>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
@@ -188,30 +213,36 @@ REGISTER_KERNEL_BUILDER(
#ifdef TENSORFLOW_USE_SYCL
// Registration of the GPU implementations.
-#define REGISTER_SYCL_KERNELS(type) \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- ReluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6Op<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- Relu6GradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- EluGradOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
- SeluOp<SYCLDevice, type>); \
- REGISTER_KERNEL_BUILDER( \
- Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+#define REGISTER_SYCL_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ ReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6Op<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ Relu6GradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyRelu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("LeakyReluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ LeakyReluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Elu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EluGradOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Selu").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ SeluOp<SYCLDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SeluGrad").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
SeluGradOp<SYCLDevice, type>)
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNELS);
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index 4775deeb61..a4638c70c2 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -132,6 +132,67 @@ void Relu6GradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
}
template <typename Device, typename T>
+class LeakyReluOp : public UnaryElementWiseOp<T, LeakyReluOp<Device, T>> {
+ public:
+ explicit LeakyReluOp(OpKernelConstruction* context)
+ : UnaryElementWiseOp<T, LeakyReluOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ functor::LeakyRelu<Device, T> functor;
+ functor(context->eigen_device<Device>(), input.flat<T>(), alpha_,
+ output->flat<T>());
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+class LeakyReluGradOp
+ : public BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>> {
+ public:
+ explicit LeakyReluGradOp(OpKernelConstruction* context)
+ : BinaryElementWiseOp<T, LeakyReluGradOp<Device, T>>(context) {
+ float alpha_tmp;
+ OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_tmp));
+ alpha_ = T(alpha_tmp);
+ }
+
+ void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
+ const Tensor& a, T alpha, Tensor* output);
+
+ // INPUTS:
+ // g (gradients): backpropagated gradients
+ // a (inputs): either the inputs that were passed to LeakyReluOp(), or its
+ // outputs (using either one yields the same result here).
+ // OUTPUT:
+ // gradients to backprop
+ template <int NDIMS>
+ void Operate(OpKernelContext* context, const Tensor& g, const Tensor& a,
+ Tensor* output) {
+ OperateNoTemplate(context, g, a, alpha_, output);
+ }
+
+ private:
+ T alpha_;
+};
+
+template <typename Device, typename T>
+void LeakyReluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
+ const Tensor& g,
+ const Tensor& a, T alpha,
+ Tensor* output) {
+ if (!ReluHelpers::ValidateSameSize(context, g, a)) return;
+ functor::LeakyReluGrad<Device, T> functor;
+ functor(context->eigen_device<Device>(), g.flat<T>(), a.flat<T>(), alpha,
+ output->flat<T>());
+};
+
+template <typename Device, typename T>
class EluOp : public UnaryElementWiseOp<T, EluOp<Device, T>> {
public:
using UnaryElementWiseOp<T, EluOp<Device, T>>::UnaryElementWiseOp;
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index e564da335a..f917142a12 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -91,6 +91,36 @@ struct Relu6Grad {
}
};
+// Functor used by LeakyReluOp to do the computations.
+template <typename Device, typename T>
+struct LeakyRelu {
+ // Computes LeakyRelu activation.
+ //
+ // features: any shape.
+ // activations: same shape as "features".
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor features,
+ T alpha, typename TTypes<T>::Tensor activations) {
+ activations.device(d) = features.cwiseMax(features * alpha);
+ }
+};
+
+// Functor used by LeakyReluGradOp to do the computations.
+template <typename Device, typename T>
+struct LeakyReluGrad {
+ // Computes LeakyReluGrad backprops.
+ //
+ // gradients: gradients backpropagated to the LeakyRelu op.
+ // features: either the inputs that were passed to the LeakyRelu or, or its
+ // outputs (using either one yields the same result here).
+ // backprops: gradients to backpropagate to the LeakyRelu inputs.
+ void operator()(const Device& d, typename TTypes<T>::ConstTensor gradients,
+ typename TTypes<T>::ConstTensor features, T alpha,
+ typename TTypes<T>::Tensor backprops) {
+ backprops.device(d) =
+ (features > static_cast<T>(0)).select(gradients, gradients * alpha);
+ }
+};
+
// Functor used by EluOp to do the computations.
template <typename Device, typename T>
struct Elu {
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index b9391517c1..dd5f9495e2 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -145,14 +145,16 @@ struct Relu<Device, qint8> {
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
-#define DEFINE_GPU_KERNELS(T) \
- template struct functor::Relu<GPUDevice, T>; \
- template struct functor::ReluGrad<GPUDevice, T>; \
- template struct functor::Relu6<GPUDevice, T>; \
- template struct functor::Relu6Grad<GPUDevice, T>; \
- template struct functor::Elu<GPUDevice, T>; \
- template struct functor::EluGrad<GPUDevice, T>; \
- template struct functor::Selu<GPUDevice, T>; \
+#define DEFINE_GPU_KERNELS(T) \
+ template struct functor::Relu<GPUDevice, T>; \
+ template struct functor::ReluGrad<GPUDevice, T>; \
+ template struct functor::Relu6<GPUDevice, T>; \
+ template struct functor::Relu6Grad<GPUDevice, T>; \
+ template struct functor::LeakyRelu<GPUDevice, T>; \
+ template struct functor::LeakyReluGrad<GPUDevice, T>; \
+ template struct functor::Elu<GPUDevice, T>; \
+ template struct functor::EluGrad<GPUDevice, T>; \
+ template struct functor::Selu<GPUDevice, T>; \
template struct functor::SeluGrad<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 23d76986bf..678d675c4a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -426,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel {
// ADD if value's refcount was 1.
mutex_lock ml(*variable->mu());
Tensor* var_tensor = variable->tensor();
+ OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
+ errors::InvalidArgument("Cannot update variable with shape ",
+ var_tensor->shape().DebugString(),
+ " using a Tensor with shape ",
+ value.shape().DebugString(),
+ ", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, var_tensor));
functor::DenseUpdate<Device, T, Op> update_functor;
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index eab176c7fb..925f5291a6 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- CPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+template <typename Device, typename IntType>
+class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
+ public:
+ using StatelessRandomOpBase::StatelessRandomOpBase;
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+ void Fill(OpKernelContext* context, random::PhiloxRandom random,
+ Tensor* output) override {
+ const Tensor& minval = context->input(2);
+ const Tensor& maxval = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval.shape().DebugString()));
+
+ // Verify that minval < maxval. Note that we'll never reach this point for
+ // empty output. Zero impossible things are fine.
+ const auto lo = minval.scalar<IntType>()();
+ const auto hi = maxval.scalar<IntType>()();
+ OP_REQUIRES(
+ context, lo < hi,
+ errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+ // Build distribution
+ typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+ Distribution;
+ Distribution dist(lo, hi);
+
+ auto flat = output->flat<IntType>();
+ // Reuse the compute kernels from the stateful random ops
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ context, context->eigen_device<Device>(), random, flat.data(),
+ flat.size(), dist);
+ }
+};
-#undef REGISTER
+#define REGISTER(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomUniform") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessTruncatedNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp< \
+ DEVICE##Device, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+#define REGISTER_INT(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
+#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
+#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
+
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_int32(REGISTER_INT_CPU);
+TF_CALL_int64(REGISTER_INT_CPU);
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- GPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_INT_GPU);
+TF_CALL_int64(REGISTER_INT_GPU);
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+#endif // GOOGLE_CUDA
#undef REGISTER
-
-#endif // GOOGLE_CUDA
+#undef REGISTER_INT
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+#undef REGISTER_INT_CPU
+#undef REGISTER_INT_GPU
} // namespace
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
index 3a9803a052..92c73220d8 100644
--- a/tensorflow/core/kernels/string_util.cc
+++ b/tensorflow/core/kernels/string_util.cc
@@ -16,10 +16,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
-namespace {
-inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
-} // namespace
-
namespace tensorflow {
// Sets unit value based on str.
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
index 390cf57702..d40e93ea33 100644
--- a/tensorflow/core/kernels/string_util.h
+++ b/tensorflow/core/kernels/string_util.h
@@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 };
// TODO(edloper): Add support for: UTF32_CHAR, etc.
enum class CharUnit { BYTE, UTF8_CHAR };
+// Whether or not the given byte is the trailing byte of a UTF-8/16/32 char.
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+
// Sets `encoding` based on `str`.
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
@@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit);
// Result may be incorrect if the input string is not valid UTF-8.
int32 UTF8StrLen(const string& string);
+// Get the next UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset, and
+// should never be `null`. The function return true if successful. However, if
+// the end of the string is reached before the requested characters, then the
+// position will point to the end of string and this function will return false.
+template <typename T>
+bool ForwardNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t size = in.size();
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) {
+ // move forward one utf-8 character
+ do {
+ ++*pos;
+ } while (IsTrailByte(in[*pos]) && *pos < size);
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
+// Get the previous UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset with a
+// positive value, relative to the beginning of the string, and should never be
+// `null`. The function return true if successful. However, if the beginning of
+// the string is reached before the requested character, then the position will
+// point to the beginning of the string and this function will return false.
+template <typename T>
+bool BackNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t start = 0;
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) {
+ // move back one utf-8 character
+ do {
+ --*pos;
+ } while (IsTrailByte(in[*pos]) && *pos > start);
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 07f1d6e767..93c427039d 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/string_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
@@ -37,7 +38,11 @@ namespace tensorflow {
template <typename T>
class SubstrOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
// Get inputs
@@ -69,11 +74,23 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
StringPiece in(input(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
} else {
@@ -84,11 +101,23 @@ class SubstrOp : public OpKernel {
StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
}
@@ -151,12 +180,24 @@ class SubstrOp : public OpKernel {
StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
- OP_REQUIRES(
- context,
- FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
break;
@@ -205,12 +246,24 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for ",
- "string b'", in, "' at index (", i,
- ", ", j, ")"));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (",
+ i, ", ", j, ")"));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i, j).assign(sub_in.data(), sub_in.size());
}
}
@@ -227,12 +280,73 @@ class SubstrOp : public OpKernel {
private:
// This adjusts the requested position. Note it does not perform any bound
// checks.
- T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
if (pos_requested < 0) {
return s.size() + pos_requested;
}
return pos_requested;
}
+
+ // Return true if successful; otherwise, return false if the `pos` argument
+ // is out of range in the string.
+ static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
+ T* len) {
+ if (*pos >= 0) {
+ return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
+ } else {
+ return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
+ }
+ }
+
+ static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+ *char_pos = 0;
+ // Determine byte position of the substring start.
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
+ return false;
+ }
+ // Determine position of the end of the substring.
+ // The length will be capped at the end of the string, and we ignore whether
+ // the string had enough characters to handle it or not.
+ *char_len = *char_pos;
+ ForwardNUTF8CharPositions(in, len, char_len);
+ // The length in bytes is the position end of the substring less the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ // This function expects a negative position relative to the end of the
+ // string, but will update the character position to a positive number
+ // relative to the beginning of the string.
+ static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+ // Initially treat the length as position of the end of the substring.
+ *char_len = in.size();
+ // This is the number of character to skip from the end of the string to
+ // arrive at the position where the substring should end.
+ T utf8_chars_to_skip = -pos - len;
+ if (utf8_chars_to_skip < 0) {
+ utf8_chars_to_skip = 0;
+ }
+ // Find the byte position where the substring should end using the computed
+ // number of characters to skip.
+ if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
+ return false;
+ }
+ // Next, determine where the substring should begin. The number of chars to
+ // skip is the requested position minus the chars we've previously skipped.
+ *char_pos = *char_len;
+ if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
+ return false;
+ }
+ // The length in bytes is the position end of the substring less the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ CharUnit unit_ = CharUnit::BYTE;
};
#define REGISTER_SUBSTR(type) \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
index 2e07050260..ea6b1ed500 100644
--- a/tensorflow/core/kernels/substr_op_test.cc
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -42,7 +42,7 @@ limitations under the License.
namespace tensorflow {
// Test data from the TensorFlow README.md.
-const char* lines[] = {
+const char* ascii_lines[] = {
"**TensorFlow** is an open source software library for numerical "
"computation using data flow graphs.",
"The graph nodes represent mathematical operations, while the graph edges "
@@ -64,17 +64,76 @@ const char* lines[] = {
"backwards compatibility guarantee like C++, Go, Java, JavaScript and "
"Swift."};
+const char* unicode_lines[] = {
+ "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6"
+ "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6"
+ "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6"
+ "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82",
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c"
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81"
+ "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae"
+ "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89"
+ "\xe3\x80\x82",
+ "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93"
+ "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf"
+ "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2"
+ "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1"
+ "\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87"
+ "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a"
+ "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9"
+ "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82",
+ "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88"
+ "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef"
+ "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6"
+ "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5"
+ "\x8c\x85\xe3\x80\x82",
+ "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7"
+ "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5"
+ "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd"
+ "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain"
+ "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c"
+ "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba"
+ "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6"
+ "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6"
+ "\xe3\x80\x82",
+ "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82"
+ "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96"
+ "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4"
+ "\xe3\x80\x82",
+ "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84"
+ "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2"
+ "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6"
+ "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c"
+ "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82",
+};
+
+const char* const kByteUnit = "BYTE";
+const char* const kUTF8Unit = "UTF8_CHAR";
+
Tensor GetTestTensor(int batch) {
- const int sz = TF_ARRAYSIZE(lines);
+ const int sz = TF_ARRAYSIZE(ascii_lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = ascii_lines[i % sz];
+ }
+ return t;
+}
+
+Tensor GetTestUTF8Tensor(int batch) {
+ const int sz = TF_ARRAYSIZE(unicode_lines);
Tensor t(DT_STRING, {batch});
auto s = t.flat<string>();
for (int i = 0; i < batch; ++i) {
- s(i) = lines[i % sz];
+ s(i) = unicode_lines[i % sz];
}
return t;
}
-Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len,
+ const char* const unit) {
Graph* g = new Graph(OpRegistry::Global());
Tensor position(DT_INT32, TensorShape({}));
position.flat<int32>().setConstant(pos);
@@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
.Input(test::graph::Constant(g, input))
.Input(test::graph::Constant(g, position))
.Input(test::graph::Constant(g, length))
+ .Attr("unit", unit)
.Finalize(g, nullptr /* node */));
return g;
}
-void BM_Substr(int iters, int batch_size) {
+void BM_SubstrByte(int iters, int batch_size) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters));
testing::UseRealTime();
Tensor input = GetTestTensor(batch_size);
- Graph* g = SetupSubstrGraph(input, 3, 30);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+void BM_SubstrUTF8(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestUTF8Tensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
-BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
- 256);
+BENCHMARK(BM_SubstrByte)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+BENCHMARK(BM_SubstrUTF8)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 3559baa18e..3bdcfc90b8 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -108,7 +108,7 @@ class UniqueOp : public OpKernel {
std::unordered_map<T, TIndex> uniq;
uniq.reserve(2 * N);
- for (int64 i = 0, j = 0; i < N; ++i) {
+ for (Eigen::Index i = 0, j = 0; i < N; ++i) {
auto it = uniq.insert(std::make_pair(Tin(i), j));
idx_vec(i) = it.first->second;
if (it.second) {
@@ -131,19 +131,20 @@ class UniqueOp : public OpKernel {
// General implementation when unique is run over multiple elements.
auto Tin = input.shaped<T, 3>(new_sizes);
- auto hash_fn = [&Tin](const int64& key) {
+ auto hash_fn = [&Tin](const Eigen::Index& key) {
size_t h = 0;
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
}
}
return h;
};
- auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ auto equal_to_fn = [&Tin](const Eigen::Index& lhs,
+ const Eigen::Index& rhs) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
return false;
}
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 33f18ae13f..9df0ece69b 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -28981,6 +28981,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
@@ -30567,6 +30635,52 @@ op {
}
}
op {
+ name: "MapDefun"
+ input_arg {
+ name: "arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+}
+op {
name: "MapIncompleteSize"
output_arg {
name: "size"
@@ -70851,6 +70965,62 @@ op {
}
}
op {
+ name: "StatelessRandomNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessRandomUniform"
input_arg {
name: "shape"
@@ -70948,6 +71118,118 @@ op {
}
}
op {
+ name: "StatelessRandomUniform"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -71045,6 +71327,62 @@ op {
}
}
op {
+ name: "StatelessTruncatedNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessWhile"
input_arg {
name: "input"
@@ -71844,6 +72182,48 @@ op {
}
}
op {
+ name: "Substr"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "pos"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "len"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "Sum"
input_arg {
name: "input"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 889a6a4640..ec22eee874 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -903,14 +903,18 @@ REGISTER_OP("ModelDataset")
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
+ .Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
+ .Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ DataTypeVector t_args;
+ TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@@ -918,10 +922,11 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
- for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ for (size_t i = 0; i < t_args.size(); ++i) {
if (c->Rank(c->input(i)) == 0) {
return errors::InvalidArgument(
- "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ "Arguments must have rank at least 1. Input ", i,
+ " has rank of 0.");
}
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
@@ -929,7 +934,7 @@ REGISTER_OP("MapDefun")
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
- "Inputs must have the same dimension 0.");
+ "Arguments must have the same dimension 0.");
}
}
}
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 3eff728f03..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1437,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index d1d81b27cc..a9ca69ad86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -983,6 +983,21 @@ REGISTER_OP("Relu6Grad")
.Attr("T: realnumbertype")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+REGISTER_OP("LeakyRelu")
+ .Input("features: T")
+ .Output("activations: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
+REGISTER_OP("LeakyReluGrad")
+ .Input("gradients: T")
+ .Input("features: T")
+ .Output("backprops: T")
+ .Attr("alpha: float = 0.2")
+ .Attr("T: {half, float, double} = DT_FLOAT")
+ .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
+
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 0e58a9475d..2048ad26ac 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -14296,6 +14296,74 @@ op {
}
}
op {
+ name: "LeakyRelu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
+ name: "LeakyReluGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "alpha"
+ type: "float"
+ default_value {
+ f: 0.2
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "LearnedUnigramCandidateSampler"
input_arg {
name: "true_classes"
@@ -15262,6 +15330,10 @@ op {
name: "arguments"
type_list_attr: "Targuments"
}
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
output_arg {
name: "output"
type_list_attr: "output_types"
@@ -15273,6 +15345,15 @@ op {
minimum: 1
}
attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
name: "output_types"
type: "list(type)"
has_minimum: true
@@ -32965,6 +33046,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33020,6 +33102,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33053,6 +33136,62 @@ op {
}
}
op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -33075,6 +33214,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33748,6 +33888,19 @@ op {
}
}
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "Sum"
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index adc9cd1486..65bdde375b 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -216,7 +216,8 @@ REGISTER_OP("VarIsInitializedOp")
Status VariableShapeShapeFn(InferenceContext* c) {
auto* handle_data = c->input_handle_shapes_and_types(0);
if (handle_data == nullptr || handle_data->empty()) {
- return errors::InvalidArgument("Handle doesn't have shape information.");
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
}
ShapeHandle var_shape = (*handle_data)[0].shape;
int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc
new file mode 100644
index 0000000000..331e1d0152
--- /dev/null
+++ b/tensorflow/core/ops/stateless_random_grad.cc
@@ -0,0 +1,23 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+
+namespace tensorflow {
+REGISTER_OP_NO_GRADIENT("StatelessRandomUniform");
+REGISTER_OP_NO_GRADIENT("StatelessRandomNormal");
+REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal");
+REGISTER_OP_NO_GRADIENT("StatelessMultinomial");
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
- TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused;
- TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shape
ShapeHandle out;
- TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
- context->set_output(0, out);
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
return Status::OK();
}
-#define REGISTER_STATELESS_OP(name) \
- REGISTER_OP(name) \
- .Input("shape: T") \
- .Input("seed: Tseed") \
- .Output("output: dtype") \
- .Attr("dtype: {half,float,double} = DT_FLOAT") \
- .Attr("T: {int32, int64} = DT_INT32") \
- .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+ REGISTER_OP(name) \
+ .Input("shape: T") \
+ .Input("seed: Tseed") \
+ .Output("output: dtype") \
+ .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+ .Attr("T: {int32, int64} = DT_INT32") \
+ .Attr("Tseed: {int32, int64} = DT_INT64") \
.SetShapeFn(StatelessShape)
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+ .Input("shape: T")
+ .Input("seed: Tseed")
+ .Input("minval: dtype")
+ .Input("maxval: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: {int32, int64}")
+ .Attr("T: {int32, int64}")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return StatelessShape(c);
+ });
+
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
return Status::OK();
});
-#undef REGISTER_STATELESS_OP
-
} // namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index b4fbde54d9..94d71a4113 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -223,6 +223,7 @@ REGISTER_OP("Substr")
.Input("len: T")
.Output("output: string")
.Attr("T: {int32, int64}")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle pos_shape = c->input(1);
ShapeHandle len_shape = c->input(2);
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 8c31468ff5..7ccd54b818 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -83,6 +83,10 @@ message RewriterConfig {
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
NumIterationsType meta_optimizer_iterations = 12;
+ // Maximum number of milliseconds to spend optimizing a single graph before
+ // timing out. If equal to 0 the system picks a default (currently 5 minutes).
+ // If less than 0 the optimizer will never time out.
+ int64 meta_optimizer_timeout_ms = 20;
// The minimum number of nodes in a graph to optimizer. For smaller graphs,
// optimization is skipped.
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
deleted file mode 100644
index 96d269bec4..0000000000
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ /dev/null
@@ -1,2426 +0,0 @@
-# Operation Semantics
-
-The following describes the semantics of operations defined in the
-[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-interface. Typically, these operations map one-to-one to operations defined in
-the RPC interface in
-[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).
-
-A note on nomenclature: the generalized data type XLA deals with is an
-N-dimensional array holding elements of some uniform type (such as 32-bit
-float). Throughout the documentation, *array* is used to denote an
-arbitrary-dimensional array. For convenience, special cases have more specific
-and familiar names; for example a *vector* is a 1-dimensional array and a
-*matrix* is a 2-dimensional array.
-
-## AllToAll
-
-See also
-[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Alltoall is a collective operation that sends data from all cores to all cores.
-It has two phases:
-
-1. the scatter phase. On each core, the operand is split into `split_count`
- number of blocks along the `split_dimensions`, and the blocks are scattered
- to all cores, e.g., the ith block is send to the ith core.
-2. the gather phase. Each core concatenates the received blocks along the
- `concat_dimension`.
-
-The participating cores can be configured by:
-
-- `replica_groups`: each ReplicaGroup contains a list of replica id. If empty,
- all replicas belong to one group in the order of 0 - (n-1). Alltoall will be
- applied within subgroups in the specified order. For example, replica
- groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica
- 1, 2, 3, and in the gather phase, the received blocks will be concatenated
- in the order of 1, 2, 3; another Alltoall will be applied within replica 4,
- 5, 0, and the concatenation order is 4, 5, 0.
-
-Prerequisites:
-
-- The dimension size of the operand on the split_dimension is divisible by
- split_count.
-- The operand's shape is not tuple.
-
-<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
-replica_groups)` </b>
-
-
-| Arguments | Type | Semantics |
-| ------------------ | --------------------- | ------------------------------- |
-| `operand` | `XlaOp` | n dimensional input array |
-| `split_dimension` | `int64` | A value in the interval `[0, |
-: : : n)` that names the dimension :
-: : : along which the operand is :
-: : : split :
-| `concat_dimension` | `int64` | a value in the interval `[0, |
-: : : n)` that names the dimension :
-: : : along which the split blocks :
-: : : are concatenated :
-| `split_count` | `int64` | the number of cores that |
-: : : participate this operation. If :
-: : : `replica_groups` is empty, this :
-: : : should be the number of :
-: : : replicas; otherwise, this :
-: : : should be equal to the number :
-: : : of replicas in each group. :
-| `replica_groups` | `ReplicaGroup` vector | each group contains a list of |
-: : : replica id. :
-
-Below shows an example of Alltoall.
-
-```
-XlaBuilder b("alltoall");
-auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
-AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
-```
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/xla/ops_alltoall.png">
-</div>
-
-In this example, there are 4 cores participating the Alltoall. On each core, the
-operand is split into 4 parts along dimension 0, so each part has shape
-f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
-the received parts along dimension 1, in the order or core 0-4. So the output on
-each core has shape f32[16,4].
-
-## BatchNormGrad
-
-See also
-[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
-for a detailed description of the algorithm.
-
-Calculates gradients of batch norm.
-
-<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ----------------------- | -------------------------------- |
-| `operand` | `XlaOp` | n dimensional array to be |
-: : : normalized (x) :
-| `scale` | `XlaOp` | 1 dimensional array |
-: : : (\\(\gamma\\)) :
-| `mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
-| `variance` | `XlaOp` | 1 dimensional array |
-: : : (\\(\sigma^2\\)) :
-| `grad_output` | `XlaOp` | Gradients passed to |
-: : : `BatchNormTraining` :
-: : : (\\( \nabla y\\)) :
-| `epsilon` | `float` | Epsilon value (\\(\epsilon\\)) |
-| `feature_index` | `int64` | Index to feature dimension in |
-: : : `operand` :
-
-For each feature in the feature dimension (`feature_index` is the index for the
-feature dimension in `operand`), the operation calculates the gradients with
-respect to `operand`, `offset` and `scale` across all the other dimensions. The
-`feature_index` must be a valid index for the feature dimension in `operand`.
-
-The three gradients are defined by the following formulas (assuming a
-4-dimensional tensor as `operand` and with feature dimension index \\(l\\),
-batch size `m` and spatial sizes `w` and `h`):
-
-\\[ \begin{split} c_l&=
-\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
-\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
-\\\\
-\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
-\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
-\right)
-\\\\
-\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
-\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
-\\\\\
-\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
-\end{split} \\]
-
-The inputs `mean` and `variance` represent moments value
-across batch and spatial dimensions.
-
-The output type is a tuple of three handles:
-
-| Outputs | Type | Semantics |
-| ------------- | ----------------------- | --------------------------------- |
-| `grad_operand` | `XlaOp` | gradient with respect to input |
-: : : `operand` (\\( \nabla x\\)) :
-| `grad_scale` | `XlaOp` | gradient with respect to input |
-: : : `scale` (\\( \nabla \gamma\\)) :
-| `grad_offset` | `XlaOp` | gradient with respect to input |
-: : : `offset`(\\( \nabla \beta\\)) :
-
-## BatchNormInference
-
-See also
-[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
-for a detailed description of the algorithm.
-
-Normalizes an array across batch and spatial dimensions.
-
-<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>
-
-Arguments | Type | Semantics
---------------- | ------- | ---------------------------------------
-`operand` | `XlaOp` | n dimensional array to be normalized
-`scale` | `XlaOp` | 1 dimensional array
-`offset` | `XlaOp` | 1 dimensional array
-`mean` | `XlaOp` | 1 dimensional array
-`variance` | `XlaOp` | 1 dimensional array
-`epsilon` | `float` | Epsilon value
-`feature_index` | `int64` | Index to feature dimension in `operand`
-
-For each feature in the feature dimension (`feature_index` is the index for the
-feature dimension in `operand`), the operation calculates the mean and variance
-across all the other dimensions and uses the mean and variance to normalize each
-element in `operand`. The `feature_index` must be a valid index for the feature
-dimension in `operand`.
-
-`BatchNormInference` is equivalent to calling `BatchNormTraining` without
-computing `mean` and `variance` for each batch. It uses the input `mean` and
-`variance` instead as estimated values. The purpose of this op is to reduce
-latency in inference, hence the name `BatchNormInference`.
-
-The output is an n-dimensional, normalized array with the same shape as input
-`operand`.
-
-## BatchNormTraining
-
-See also
-[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
-for a detailed description of the algorithm.
-
-Normalizes an array across batch and spatial dimensions.
-
-<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>
-
-Arguments | Type | Semantics
---------------- | ------- | ----------------------------------------
-`operand` | `XlaOp` | n dimensional array to be normalized (x)
-`scale` | `XlaOp` | 1 dimensional array (\\(\gamma\\))
-`offset` | `XlaOp` | 1 dimensional array (\\(\beta\\))
-`epsilon` | `float` | Epsilon value (\\(\epsilon\\))
-`feature_index` | `int64` | Index to feature dimension in `operand`
-
-For each feature in the feature dimension (`feature_index` is the index for the
-feature dimension in `operand`), the operation calculates the mean and variance
-across all the other dimensions and uses the mean and variance to normalize each
-element in `operand`. The `feature_index` must be a valid index for the feature
-dimension in `operand`.
-
-The algorithm goes as follows for each batch in `operand` \\(x\\) that
-contains `m` elements with `w` and `h` as the size of spatial dimensions
-(assuming `operand` is an 4 dimensional array):
-
-- Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension:
-\\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)
-
-- Calculates batch variance \\(\sigma^2_l\\):
-\\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)
-
-- Normalizes, scales and shifts:
-\\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon}}+\beta_l\\)
-
-The epsilon value, usually a small number, is added to avoid divide-by-zero errors.
-
-The output type is a tuple of three `XlaOp`s:
-
-| Outputs | Type | Semantics |
-| ------------ | ----------------------- | -------------------------------------|
-| `output` | `XlaOp` | n dimensional array with the same |
-: : : shape as input `operand` (y) :
-| `batch_mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
-| `batch_var` | `XlaOp` | 1 dimensional array (\\(\sigma^2\\)) |
-
-The `batch_mean` and `batch_var` are moments calculated across the batch and
-spatial dimensions using the formulas above.
-
-## BitcastConvertType
-
-See also
-[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
-operation from a data shape to a target shape. The dimensions must match, and
-the conversion is an element-wise one; e.g. `s32` elements become `f32` elements
-via bitcast routine. Bitcast is implemented as a low-level cast, so machines
-with different floating-point representations will give different results.
-
-<b> `BitcastConvertType(operand, new_element_type)` </b>
-
-Arguments | Type | Semantics
------------------- | --------------- | ---------------------------
-`operand` | `XlaOp` | array of type T with dims D
-`new_element_type` | `PrimitiveType` | type U
-
-The dimensions of the operand and the target shape must match. The bit-width of
-the source and destination element types must be equal. The source
-and destination element types must not be tuples.
-
-## Broadcast
-
-See also
-[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Adds dimensions to an array by duplicating the data in the array.
-
-<b> `Broadcast(operand, broadcast_sizes)` </b>
-
-Arguments | Type | Semantics
------------------ | ------------------- | -------------------------------
-`operand` | `XlaOp` | The array to duplicate
-`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions
-
-The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
-values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
-the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.
-
-The new dimensions index into copies of the operand, i.e.
-
-```
-output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
-```
-
-For example, if `operand` is a scalar `f32` with value `2.0f`, and
-`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
-`f32[2, 3]` and all the values in the result will be `2.0f`.
-
-## Call
-
-See also
-[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Invokes a computation with the given arguments.
-
-<b> `Call(computation, args...)` </b>
-
-| Arguments | Type | Semantics |
-| ------------- | ---------------------- | ----------------------------------- |
-| `computation` | `XlaComputation` | computation of type `T_0, T_1, ..., |
-: : : T_N -> S` with N parameters of :
-: : : arbitrary type :
-| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type |
-
-The arity and types of the `args` must match the parameters of the
-`computation`. It is allowed to have no `args`.
-
-## Clamp
-
-See also
-[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Clamps an operand to within the range between a minimum and maximum value.
-
-<b> `Clamp(min, operand, max)` </b>
-
-Arguments | Type | Semantics
---------- | ------- | ---------------
-`min` | `XlaOp` | array of type T
-`operand` | `XlaOp` | array of type T
-`max` | `XlaOp` | array of type T
-
-Given an operand and minimum and maximum values, returns the operand if it is in
-the range between the minimum and maximum, else returns the minimum value if the
-operand is below this range or the maximum value if the operand is above this
-range. That is, `clamp(a, x, b) = min(max(a, x), b)`.
-
-All three arrays must be the same shape. Alternatively, as a restricted form of
-[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.
-
-Example with scalar `min` and `max`:
-
-```
-let operand: s32[3] = {-1, 5, 9};
-let min: s32 = 0;
-let max: s32 = 6;
-==>
-Clamp(min, operand, max) = s32[3]{0, 5, 6};
-```
-
-## Collapse
-
-See also
-[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and the `tf.reshape` operation.
-
-Collapses dimensions of an array into one dimension.
-
-<b> `Collapse(operand, dimensions)` </b>
-
-Arguments | Type | Semantics
------------- | -------------- | -----------------------------------------------
-`operand` | `XlaOp` | array of type T
-`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.
-
-Collapse replaces the given subset of the operand's dimensions by a single
-dimension. The input arguments are an arbitrary array of type T and a
-compile-time-constant vector of dimension indices. The dimension indices must be
-an in-order (low to high dimension numbers), consecutive subset of T's
-dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
-{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
-same position in the dimension sequence as those they replace, with the new
-dimension size equal to the product of original dimension sizes. The lowest
-dimension number in `dimensions` is the slowest varying dimension (most major)
-in the loop nest which collapses these dimension, and the highest dimension
-number is fastest varying (most minor). See the `tf.reshape` operator
-if more general collapse ordering is needed.
-
-For example, let v be an array of 24 elements:
-
-```
-let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
- {{20, 21, 22}, {25, 26, 27}},
- {{30, 31, 32}, {35, 36, 37}},
- {{40, 41, 42}, {45, 46, 47}}};
-
-// Collapse to a single dimension, leaving one dimension.
-let v012 = Collapse(v, {0,1,2});
-then v012 == f32[24] {10, 11, 12, 15, 16, 17,
- 20, 21, 22, 25, 26, 27,
- 30, 31, 32, 35, 36, 37,
- 40, 41, 42, 45, 46, 47};
-
-// Collapse the two lower dimensions, leaving two dimensions.
-let v01 = Collapse(v, {0,1});
-then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17},
- {20, 21, 22, 25, 26, 27},
- {30, 31, 32, 35, 36, 37},
- {40, 41, 42, 45, 46, 47}};
-
-// Collapse the two higher dimensions, leaving two dimensions.
-let v12 = Collapse(v, {1,2});
-then v12 == f32[8x3] {{10, 11, 12},
- {15, 16, 17},
- {20, 21, 22},
- {25, 26, 27},
- {30, 31, 32},
- {35, 36, 37},
- {40, 41, 42},
- {45, 46, 47}};
-
-```
-
-## Concatenate
-
-See also
-[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Concatenate composes an array from multiple array operands. The array is of the
-same rank as each of the input array operands (which must be of the same rank as
-each other) and contains the arguments in the order that they were specified.
-
-<b> `Concatenate(operands..., dimension)` </b>
-
-| Arguments | Type | Semantics |
-| ----------- | --------------------- | -------------------------------------- |
-| `operands` | sequence of N `XlaOp` | N arrays of type T with dimensions |
-: : : [L0, L1, ...]. Requires N >= 1. :
-| `dimension` | `int64` | A value in the interval `[0, N)` that |
-: : : names the dimension to be concatenated :
-: : : between the `operands`. :
-
-With the exception of `dimension` all dimensions must be the same. This is
-because XLA does not support "ragged" arrays. Also note that rank-0 values
-cannot be concatenated (as it's impossible to name the dimension along which the
-concatenation occurs).
-
-1-dimensional example:
-
-```
-Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
->>> {2, 3, 4, 5, 6, 7}
-```
-
-2-dimensional example:
-
-```
-let a = {
- {1, 2},
- {3, 4},
- {5, 6},
-};
-let b = {
- {7, 8},
-};
-Concat({a, b}, 0)
->>> {
- {1, 2},
- {3, 4},
- {5, 6},
- {7, 8},
-}
-```
-
-Diagram:
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/ops_concatenate.png">
-</div>
-
-## Conditional
-
-See also
-[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Conditional(pred, true_operand, true_computation, false_operand,
-false_computation)` </b>
-
-Arguments | Type | Semantics
-------------------- | ---------------- | ---------------------------------
-`pred` | `XlaOp` | Scalar of type `PRED`
-`true_operand` | `XlaOp` | Argument of type `T_0`
-`true_computation` | `XlaComputation` | XlaComputation of type `T_0 -> S`
-`false_operand` | `XlaOp` | Argument of type `T_1`
-`false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S`
-
-Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
-is `false`, and returns the result.
-
-The `true_computation` must take in a single argument of type `T_0` and will be
-invoked with `true_operand` which must be of the same type. The
-`false_computation` must take in a single argument of type `T_1` and will be
-invoked with `false_operand` which must be of the same type. The type of the
-returned value of `true_computation` and `false_computation` must be the same.
-
-Note that only one of `true_computation` and `false_computation` will be
-executed depending on the value of `pred`.
-
-## Conv (convolution)
-
-See also
-[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
-either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
-the output has the same shape as the input when not taking striding into
-account. VALID padding simply means no padding.
-
-## ConvWithGeneralPadding (convolution)
-
-See also
-[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Computes a convolution of the kind used in neural networks. Here, a convolution
-can be thought of as a n-dimensional window moving across a n-dimensional base
-area and a computation is performed for each possible position of the window.
-
-| Arguments | Type | Semantics |
-| --------------------- | -------------------- | ----------------------------- |
-| `lhs` | `XlaOp` | rank n+2 array of inputs |
-| `rhs` | `XlaOp` | rank n+2 array of kernel |
-: : : weights :
-| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
-| `padding` | `ArraySlice< | n-d array of (low, high) |
-: : pair<int64, int64>>` : padding :
-| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
-| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
-| `feature_group_count` | int64 | the number of feature groups |
-
-Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
-array describing the base area. This is called the input, even though of course
-the rhs is also an input. In a neural network, these are the input activations.
-The n+2 dimensions are, in this order:
-
-* `batch`: Each coordinate in this dimension represents an independent input
- for which convolution is carried out.
-* `z/depth/features`: Each (y,x) position in the base area has a vector
- associated to it, which goes into this dimension.
-* `spatial_dims`: Describes the `n` spatial dimensions that define the base
- area that the window moves across.
-
-The `rhs` argument is a rank n+2 array describing the convolutional
-filter/kernel/window. The dimensions are, in this order:
-
-* `output-z`: The `z` dimension of the output.
-* `input-z`: The size of this dimension times `feature_group_count` should
- equal the size of the `z` dimension in lhs.
-* `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
- window that moves across the base area.
-
-The `window_strides` argument specifies the stride of the convolutional window
-in the spatial dimensions. For example, if the stride in the first spatial
-dimension is 3, then the window can only be placed at coordinates where the
-first spatial index is divisible by 3.
-
-The `padding` argument specifies the amount of zero padding to be applied to the
-base area. The amount of padding can be negative -- the absolute value of
-negative padding indicates the number of elements to remove from the specified
-dimension before doing the convolution. `padding[0]` specifies the padding for
-dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
-pair has the low padding as the first element and the high padding as the second
-element. The low padding is applied in the direction of lower indices while the
-high padding is applied in the direction of higher indices. For example, if
-`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
-by 3 zeroes on the right in the second spatial dimension. Using padding is
-equivalent to inserting those same zero values into the input (`lhs`) before
-doing the convolution.
-
-The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
-be applied to the lhs and rhs, respectively, in each spatial dimension. If the
-dilation factor in a spatial dimension is d, then d-1 holes are implicitly
-placed between each of the entries in that dimension, increasing the size of the
-array. The holes are filled with a no-op value, which for convolution means
-zeroes.
-
-Dilation of the rhs is also called atrous convolution. For more details, see
-`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
-convolution. For more details, see `tf.nn.conv2d_transpose`.
-
-The `feature_group_count` argument (default value 1) can be used for grouped
-convolutions. `feature_group_count` needs to be a divisor of both the input and
-the output feature dimension. If `feature_group_count` is greater than 1, it
-means that conceptually the input and output feature dimension and the `rhs`
-output feature dimension are split evenly into `feature_group_count` many
-groups, each group consisting of a consecutive subsequence of features. The
-input feature dimension of `rhs` needs to be equal to the `lhs` input feature
-dimension divided by `feature_group_count` (so it already has the size of a
-group of input features). The i-th groups are used together to compute
-`feature_group_count` many separate convolutions. The results of these
-convolutions are concatenated together in the output feature dimension.
-
-For depthwise convolution the `feature_group_count` argument would be set to the
-input feature dimension, and the filter would be reshaped from
-`[filter_height, filter_width, in_channels, channel_multiplier]` to
-`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
-details, see `tf.nn.depthwise_conv2d`.
-
-The output shape has these dimensions, in this order:
-
-* `batch`: Same size as `batch` on the input (`lhs`).
-* `z`: Same size as `output-z` on the kernel (`rhs`).
-* `spatial_dims`: One value for each valid placement of the convolutional
- window.
-
-The valid placements of the convolutional window are determined by the strides
-and the size of the base area after padding.
-
-To describe what a convolution does, consider a 2d convolution, and pick some
-fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
-position of a corner of the window within the base area (e.g. the upper left
-corner, depending on how you interpret the spatial dimensions). We now have a 2d
-window, taken from the base area, where each 2d point is associated to a 1d
-vector, so we get a 3d box. From the convolutional kernel, since we fixed the
-output coordinate `z`, we also have a 3d box. The two boxes have the same
-dimensions, so we can take the sum of the element-wise products between the two
-boxes (similar to a dot product). That is the output value.
-
-Note that if `output-z` is e.g., 5, then each position of the window produces 5
-values in the output into the `z` dimension of the output. These values differ
-in what part of the convolutional kernel is used - there is a separate 3d box of
-values used for each `output-z` coordinate. So you could think of it as 5
-separate convolutions with a different filter for each of them.
-
-Here is pseudo-code for a 2d convolution with padding and striding:
-
-```
-for (b, oz, oy, ox) { // output coordinates
- value = 0;
- for (iz, ky, kx) { // kernel coordinates and input z
- iy = oy*stride_y + ky - pad_low_y;
- ix = ox*stride_x + kx - pad_low_x;
- if ((iy, ix) inside the base area considered without padding) {
- value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
- }
- }
- output(b, oz, oy, ox) = value;
-}
-```
-
-## ConvertElementType
-
-See also
-[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Similar to an element-wise `static_cast` in C++, performs an element-wise
-conversion operation from a data shape to a target shape. The dimensions must
-match, and the conversion is an element-wise one; e.g. `s32` elements become
-`f32` elements via an `s32`-to-`f32` conversion routine.
-
-<b> `ConvertElementType(operand, new_element_type)` </b>
-
-Arguments | Type | Semantics
------------------- | --------------- | ---------------------------
-`operand` | `XlaOp` | array of type T with dims D
-`new_element_type` | `PrimitiveType` | type U
-
-The dimensions of the operand and the target shape must match. The source and
-destination element types must not be tuples.
-
-A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
-conversion routine such as round-to-nearest-even.
-
-> Note: The precise float-to-int and visa-versa conversions are currently
-> unspecified, but may become additional arguments to the convert operation in
-> the future. Not all possible conversions have been implemented for all
->targets.
-
-```
-let a: s32[3] = {0, 1, 2};
-let b: f32[3] = convert(a, f32);
-then b == f32[3]{0.0, 1.0, 2.0}
-```
-
-## CrossReplicaSum
-
-See also
-[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Computes a sum across replicas.
-
-<b> `CrossReplicaSum(operand)` </b>
-
-Arguments | Type | Semantics
---------- | ------- | -----------------------------
-`operand` | `XlaOp` | Array to sum across replicas.
-| `replica_group_ids` | `int64` vector | Group ID for each replica. |
-
-The output shape is the same as the input shape. For example, if there are two
-replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)`
-respectively on the two replicas, then the output value from this op will be
-`(4.0, 7.75)` on both replicas.
-
-`replica_group_ids` identifies the group ID of each replica. The group ID must
-either be empty (all replicas belong to a single group), or contain the same
-number of elements as the number of replicas. For example, if
-`replica_group_ids` = {0, 1, 2, 3, 0, 1, 2, 3} has eight replicas, there are
-four subgroups of replica IDs: {0, 4}, {1, 5}, {2, 6}, and {3, 7}. The size of
-each subgroup *must* be identical, so, for example, using:
-`replica_group_ids` = {0, 1, 2, 0} for four replicas is invalid.
-
-Computing the result of CrossReplicaSum requires having one input from each
-replica, so if one replica executes a CrossReplicaSum node more times than
-another, then the former replica will wait forever. Since the replicas are all
-running the same program, there are not a lot of ways for that to happen, but it
-is possible when a while loop's condition depends on data from infeed and the
-data that is infed causes the while loop to iterate more times on one replica
-than another.
-
-## CustomCall
-
-See also
-[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Call a user-provided function within a computation.
-
-<b> `CustomCall(target_name, args..., shape)` </b>
-
-| Arguments | Type | Semantics |
-| ------------- | ---------------------- | --------------------------------- |
-| `target_name` | `string` | Name of the function. A call |
-: : : instruction will be emitted which :
-: : : targets this symbol name. :
-| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type, |
-: : : which will be passed to the :
-: : : function. :
-| `shape` | `Shape` | Output shape of the function |
-
-The function signature is the same, regardless of the arity or type of args:
-
-```
-extern "C" void target_name(void* out, void** in);
-```
-
-For example, if CustomCall is used as follows:
-
-```
-let x = f32[2] {1,2};
-let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};
-
-CustomCall("myfunc", {x, y}, f32[3x3])
-```
-
-Here is an example of an implementation of `myfunc`:
-
-```
-extern "C" void myfunc(void* out, void** in) {
- float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
- float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
- EXPECT_EQ(1, x[0]);
- EXPECT_EQ(2, x[1]);
- EXPECT_EQ(10, y[0][0]);
- EXPECT_EQ(20, y[0][1]);
- EXPECT_EQ(30, y[0][2]);
- EXPECT_EQ(40, y[1][0]);
- EXPECT_EQ(50, y[1][1]);
- EXPECT_EQ(60, y[1][2]);
- float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
- z[0][0] = x[1] + y[1][0];
- // ...
-}
-```
-
-The user-provided function must not have side-effects and its execution must be
-idempotent.
-
-> Note: The opaque nature of the user-provided function restricts optimization
-> opportunities for the compiler. Try to express your computation in terms of
-> native XLA ops whenever possible; only use CustomCall as a last resort.
-
-## Dot
-
-See also
-[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Dot(lhs, rhs)` </b>
-
-Arguments | Type | Semantics
---------- | ------- | ---------------
-`lhs` | `XlaOp` | array of type T
-`rhs` | `XlaOp` | array of type T
-
-The exact semantics of this operation depend on the ranks of the operands:
-
-| Input | Output | Semantics |
-| ----------------------- | --------------------- | ----------------------- |
-| vector [n] `dot` vector | scalar | vector dot product |
-: [n] : : :
-| matrix [m x k] `dot` | vector [m] | matrix-vector |
-: vector [k] : : multiplication :
-| matrix [m x k] `dot` | matrix [m x n] | matrix-matrix |
-: matrix [k x n] : : multiplication :
-
-The operation performs sum of products over the last dimension of `lhs` and the
-one-before-last dimension of `rhs`. These are the "contracted" dimensions. The
-contracted dimensions of `lhs` and `rhs` must be of the same size. In practice,
-it can be used to perform dot products between vectors, vector/matrix
-multiplications or matrix/matrix multiplications.
-
-## DotGeneral
-
-See also
-[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
-
-Arguments | Type | Semantics
-------------------- | --------------------- | ---------------
-`lhs` | `XlaOp` | array of type T
-`rhs` | `XlaOp` | array of type T
-`dimension_numbers` | `DotDimensionNumbers` | array of type T
-
-As Dot, but allows contracting and batch dimension numbers to be specified for
-both the 'lhs' and 'rhs'.
-
-| DotDimensionNumbers Fields | Type | Semantics
-| --------- | ----------------------- | ---------------
-| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers |
-| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers |
-| 'lhs_batch_dimensions' | repeated int64 | 'lhs' batch dimension numbers |
-| 'rhs_batch_dimensions' | repeated int64 | 'rhs' batch dimension numbers |
-
-DotGeneral performs the sum of products over contracting dimensions specified
-in 'dimension_numbers'.
-
-Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
-to be the same, but must be listed in the same order in both
-'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes.
-There must be exactly one contracting dimension on both 'lhs' and 'rhs'.
-
-Example with contracting dimension numbers:
-
-```
-lhs = { {1.0, 2.0, 3.0},
- {4.0, 5.0, 6.0} }
-
-rhs = { {1.0, 1.0, 1.0},
- {2.0, 2.0, 2.0} }
-
-DotDimensionNumbers dnums;
-dnums.add_lhs_contracting_dimensions(1);
-dnums.add_rhs_contracting_dimensions(1);
-
-DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
- {15.0, 30.0} }
-```
-
-Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same
-dimension number, must be listed in the same order in both arrays, must
-have the same dimension sizes, and must be ordered before contracting and
-non-contracting/non-batch dimension numbers.
-
-Example with batch dimension numbers (batch size 2, 2x2 matrices):
-
-```
-lhs = { { {1.0, 2.0},
- {3.0, 4.0} },
- { {5.0, 6.0},
- {7.0, 8.0} } }
-
-rhs = { { {1.0, 0.0},
- {0.0, 1.0} },
- { {1.0, 0.0},
- {0.0, 1.0} } }
-
-DotDimensionNumbers dnums;
-dnums.add_lhs_contracting_dimensions(2);
-dnums.add_rhs_contracting_dimensions(1);
-dnums.add_lhs_batch_dimensions(0);
-dnums.add_rhs_batch_dimensions(0);
-
-DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
- {3.0, 4.0} },
- { {5.0, 6.0},
- {7.0, 8.0} } }
-```
-
-| Input | Output | Semantics |
-| ----------------------------------- | ----------------- | ---------------- |
-| [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul |
-| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |
-
-It follows that the resulting dimension number starts with the batch dimension,
-then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs'
-non-contracting/non-batch dimension.
-
-## DynamicSlice
-
-See also
-[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-DynamicSlice extracts a sub-array from the input array at dynamic
-`start_indices`. The size of the slice in each dimension is passed in
-`size_indices`, which specify the end point of exclusive slice intervals in each
-dimension: [start, start + size). The shape of `start_indices` must be rank ==
-1, with dimension size equal to the rank of `operand`.
-
-<b> `DynamicSlice(operand, start_indices, size_indices)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ------------------- | ----------------------------------- |
-| `operand` | `XlaOp` | N dimensional array of type T |
-| `start_indices` | `XlaOp` | Rank 1 array of N integers |
-: : : containing the starting indices of :
-: : : the slice for each dimension. Value :
-: : : must be greater than or equal to :
-: : : zero. :
-| `size_indices` | `ArraySlice<int64>` | List of N integers containing the |
-: : : slice size for each dimension. Each :
-: : : value must be strictly greater than :
-: : : zero, and start + size must be less :
-: : : than or equal to the size of the :
-: : : dimension to avoid wrapping modulo :
-: : : dimension size. :
-
-The effective slice indices are computed by applying the following
-transformation for each index `i` in `[1, N)` before performing the slice:
-
-```
-start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
-```
-
-This ensures that the extracted slice is always in-bounds with respect to the
-operand array. If the slice is in-bounds before the transformation is applied,
-the transformation has no effect.
-
-1-dimensional example:
-
-```
-let a = {0.0, 1.0, 2.0, 3.0, 4.0}
-let s = {2}
-
-DynamicSlice(a, s, {2}) produces:
- {2.0, 3.0}
-```
-
-2-dimensional example:
-
-```
-let b =
- { {0.0, 1.0, 2.0},
- {3.0, 4.0, 5.0},
- {6.0, 7.0, 8.0},
- {9.0, 10.0, 11.0} }
-let s = {2, 1}
-
-DynamicSlice(b, s, {2, 2}) produces:
- { { 7.0, 8.0},
- {10.0, 11.0} }
-```
-## DynamicUpdateSlice
-
-See also
-[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-DynamicUpdateSlice generates a result which is the value of the input array
-`operand`, with a slice `update` overwritten at `start_indices`.
-The shape of `update` determines the shape of the sub-array of the result which
-is updated.
-The shape of `start_indices` must be rank == 1, with dimension size equal to
-the rank of `operand`.
-
-<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ------- | ------------------------------------------------ |
-| `operand` | `XlaOp` | N dimensional array of type T |
-| `update` | `XlaOp` | N dimensional array of type T containing the |
-: : : slice update. Each dimension of update shape :
-: : : must be strictly greater than zero, and start + :
-: : : update must be less than or equal to the operand :
-: : : size for each dimension to avoid generating :
-: : : out-of-bounds update indices. :
-| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the |
-: : : starting indices of the slice for each :
-: : : dimension. Value must be greater than or equal :
-: : : to zero. :
-
-The effective slice indices are computed by applying the following
-transformation for each index `i` in `[1, N)` before performing the slice:
-
-```
-start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
-```
-
-This ensures that the updated slice is always in-bounds with respect to the
-operand array. If the slice is in-bounds before the transformation is applied,
-the transformation has no effect.
-
-1-dimensional example:
-
-```
-let a = {0.0, 1.0, 2.0, 3.0, 4.0}
-let u = {5.0, 6.0}
-let s = {2}
-
-DynamicUpdateSlice(a, u, s) produces:
- {0.0, 1.0, 5.0, 6.0, 4.0}
-```
-
-2-dimensional example:
-
-```
-let b =
- { {0.0, 1.0, 2.0},
- {3.0, 4.0, 5.0},
- {6.0, 7.0, 8.0},
- {9.0, 10.0, 11.0} }
-let u =
- { {12.0, 13.0},
- {14.0, 15.0},
- {16.0, 17.0} }
-
-let s = {1, 1}
-
-DynamicUpdateSlice(b, u, s) produces:
- { {0.0, 1.0, 2.0},
- {3.0, 12.0, 13.0},
- {6.0, 14.0, 15.0},
- {9.0, 16.0, 17.0} }
-```
-
-## Element-wise binary arithmetic operations
-
-See also
-[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-A set of element-wise binary arithmetic operations is supported.
-
-<b> `Op(lhs, rhs)` </b>
-
-Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
-(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
-(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).
-
-Arguments | Type | Semantics
---------- | ------- | ----------------------------------------
-`lhs` | `XlaOp` | left-hand-side operand: array of type T
-`rhs` | `XlaOp` | right-hand-side operand: array of type T
-
-The arguments' shapes have to be either similar or compatible. See the
-[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to
-be compatible. The result of an operation has a shape which is the result of
-broadcasting the two input arrays. In this variant, operations between arrays of
-different ranks are *not* supported, unless one of the operands is a scalar.
-
-When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
-absolute value of the result is always less than the divisor's absolute value.
-
-Integer division overflow (signed/unsigned division/remainder by zero or signed
-divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined
-value.
-
-An alternative variant with different-rank broadcasting support exists for these
-operations:
-
-<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
-
-Where `Op` is the same as above. This variant of the operation should be used
-for arithmetic operations between arrays of different ranks (such as adding a
-matrix to a vector).
-
-The additional `broadcast_dimensions` operand is a slice of integers used to
-expand the rank of the lower-rank operand up to the rank of the higher-rank
-operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
-the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
-shape are filled with dimensions of size one. Degenerate-dimension broadcasting
-then broadcasts the shapes along these degenerate dimensions to equalize the
-shapes of both operands. The semantics are described in detail on the
-[broadcasting page](../../performance/xla/broadcasting.md).
-
-## Element-wise comparison operations
-
-See also
-[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-A set of standard element-wise binary comparison operations is supported. Note
-that standard IEEE 754 floating-point comparison semantics apply when comparing
-floating-point types.
-
-<b> `Op(lhs, rhs)` </b>
-
-Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
-(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
-(less-than).
-
-Arguments | Type | Semantics
---------- | ------- | ----------------------------------------
-`lhs` | `XlaOp` | left-hand-side operand: array of type T
-`rhs` | `XlaOp` | right-hand-side operand: array of type T
-
-The arguments' shapes have to be either similar or compatible. See the
-[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to
-be compatible. The result of an operation has a shape which is the result of
-broadcasting the two input arrays with the element type `PRED`. In this variant,
-operations between arrays of different ranks are *not* supported, unless one of
-the operands is a scalar.
-
-An alternative variant with different-rank broadcasting support exists for these
-operations:
-
-<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
-
-Where `Op` is the same as above. This variant of the operation should be used
-for comparison operations between arrays of different ranks (such as adding a
-matrix to a vector).
-
-The additional `broadcast_dimensions` operand is a slice of integers specifying
-the dimensions to use for broadcasting the operands. The semantics are described
-in detail on the [broadcasting page](../../performance/xla/broadcasting.md).
-
-## Element-wise unary functions
-
-XlaBuilder supports these element-wise unary functions:
-
-<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
-
-<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
-
-<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
-
-<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
-
-<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
-
-<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
-i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
-of `PRED` values with the same shape as the input, where each element is `true`
-if and only if the corresponding input element is finite.
-
-<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
-
-<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
-
-<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
-
-<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
-
-$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$
-
-using the comparison operator of the element type of `operand`.
-
-<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
-
-
-Arguments | Type | Semantics
---------- | ------- | ---------------------------
-`operand` | `XlaOp` | The operand to the function
-
-The function is applied to each element in the `operand` array, resulting in an
-array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
-
-## Gather
-
-The XLA gather operation stitches together several slices (each slice at a
-potentially different runtime offset) of an input array.
-
-### General Semantics
-
-See also
-[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-For a more intuitive description, see the "Informal Description" section below.
-
-<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
-
-|Arguments | Type | Semantics |
-|----------------- | ----------------------- | --------------------------------|
-|`operand` | `XlaOp` | The array we’re gathering |
-: : : from. :
-|`start_indices` | `XlaOp` | Array containing the starting |
-: : : indices of the slices we gather.:
-|`index_vector_dim` | `int64` | The dimension in |
-: : : `start_indices` that "contains" :
-: : : the starting indices. See :
-: : : below for a detailed :
-: : : description. :
-|`offset_dims` | `ArraySlice<int64>` | The set of dimensions in the :
-: : : output shape that offset into a :
-: : : array sliced from operand. :
-|`slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the bounds |
-: : : for the slice on dimension `i`.:
-|`collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each :
-| : | slice that are collapsed away. :
-| : | These dimensions must have size:
-| : | 1. |
-|`start_index_map` | `ArraySlice<int64>` | A map that describes how to map|
-: : : indices in `start_indices` to :
-: : : to legal indices into operand. :
-
-For convenience, we label dimensions in the output array not in `offset_dims`
-as `batch_dims`.
-
-The output is an array of rank `batch_dims.size` + `operand.rank` -
-`collapsed_slice_dims`.size.
-
-If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
-`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
-shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
-shape of `start_indices` to be `[6,7,1]`).
-
-The bounds for the output array along dimension `i` is computed as follows:
-
- 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
- some `k`) then we pick the corresponding dimension bounds out of
- `start_indices.shape`, skipping `index_vector_dim` (i.e. pick
- `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
- `start_indices.shape.dims`[`k`+`1`] otherwise).
-
- 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
- some `k`) then we pick the corresponding bound out of `slice_sizes` after
- accounting for `collapsed_slice_dims` (i.e. we pick
- `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
- with the bounds at indices `collapsed_slice_dims` removed).
-
-Formally, the operand index `In` corresponding to an output index `Out` is
-computed as follows:
-
- 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
- vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
- Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
- this is well defined even if `G` is empty -- if `G` is empty then `S` =
- `start_indices`.
-
- 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
- scattering `S` using `start_index_map`. More precisely:
- 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
- `start_index_map.size`.
- 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
-
- 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
- at the offset dimensions in `Out` according to the `collapsed_slice_dims`
- set. More precisely:
- 1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
- `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
- (`expand_offset_dims` is defined below).
- 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
- addition.
-
-`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`)
-and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
-`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
-`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
-
-### Informal Description and Examples
-
-Informally, every index `Out` in the output array corresponds to an element `E`
-in the operand array, computed as follows:
-
- - We use the batch dimensions in `Out` to look up a starting index from
- `start_indices`.
-
- - We use `start_index_map` to map the starting index (which may have size less
- than operand.rank) to a "full" starting index into operand.
-
- - We dynamic-slice out a slice with size `slice_sizes` using the full starting
- index.
-
- - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
- Since all collapsed slice dimensions have to have bound 1 this reshape is
- always legal.
-
- - We use the offset dimensions in `Out` to index into this slice to get the
- input element, `E`, corresponding to output index `Out`.
-
-`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
-examples that follow. More interesting values for `index_vector_dim` does not
-change the operation fundamentally, but makes the visual representation more
-cumbersome.
-
-To get an intuition on how all of the above fits together, let's look at an
-example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
-position of a slice into the `[16,11]` array can be represented as an index
-vector of shape `S64[2]`, so the set of 5 positions can be represented as a
-`S64[5,2]` array.
-
-The behavior of the gather operation can then be depicted as an index
-transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
-the output shape, and maps it to an element in the input array in the following
-way:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/ops_xla_gather_0.svg">
-</div>
-
-We first select an (`X`,`Y`) vector from the gather indices array using `G`.
-The element in the output array at index
-[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
-array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
-
-`slice_sizes` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
-W<sub>`1`</sub>, and this in turn decides the bounds of the slice.
-
-This gather operation acts as a batch dynamic slice with `G` as the batch
-dimension.
-
-The gather indices may be multidimensional. For instance, a more general
-version of the example above using a "gather indices" array of shape `[4,5,2]`
-would translate indices like this:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/ops_xla_gather_1.svg">
-</div>
-
-Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and
-`G`<sub>`1`</sub> as the batch dimensions. The slice size is still `[8,6]`.
-
-The gather operation in XLA generalizes the informal semantics outlined above in
-the following ways:
-
- 1. We can configure which dimensions in the output shape are the offset
- dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
- the last example). The output batch dimensions (dimensions containing
- `G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
- the output dimensions that are not offset dimensions.
-
- 2. The number of output offset dimensions explicitly present in the output
- shape may be smaller than the input rank. These "missing" dimensions, which
- are listed explicitly as `collapsed_slice_dims`, must have a slice size of
- `1`. Since they have a slice size of `1` the only valid index for them is
- `0` and eliding them does not introduce ambiguity.
-
- 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
- example) may have fewer elements than the input array rank, and an explicit
- mapping dictates how the index should be expanded to have the same rank as
- the input.
-
-As a final example, we use (2) and (3) to implement `tf.gather_nd`:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/ops_xla_gather_2.svg">
-</div>
-
-`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
-from the gather indices array as usual, except the starting index has only one
-element, `X`. Similarly, there is only one output offset index with the value
-`O`<sub>`0`</sub>. However, before being used as indices into the input array,
-these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
-the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal
-description) into [`0`,`O`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up
-to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
-[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
-[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
-the semantics for `tf.gather_nd`.
-
-`slice_sizes` for this case is `[1,11]`. Intuitively this means that every
-index `X` in the gather indices array picks an entire row and the result is the
-concatenation of all these rows.
-
-## GetTupleElement
-
-See also
-[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Indexes into a tuple with a compile-time-constant value.
-
-The value must be a compile-time-constant so that shape inference can determine
-the type of the resulting value.
-
-This is analogous to `std::get<int N>(t)` in C++. Conceptually:
-
-```
-let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-let s: s32 = 5;
-let t: (f32[10], s32) = tuple(v, s);
-let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
-```
-
-See also `tf.tuple`.
-
-## Infeed
-
-See also
-[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Infeed(shape)` </b>
-
-| Argument | Type | Semantics |
-| -------- | ------- | ----------------------------------------------------- |
-| `shape` | `Shape` | Shape of the data read from the Infeed interface. The |
-: : : layout field of the shape must be set to match the :
-: : : layout of the data sent to the device; otherwise its :
-: : : behavior is undefined. :
-
-Reads a single data item from the implicit Infeed streaming interface of the
-device, interpreting the data as the given shape and its layout, and returns a
-`XlaOp` of the data. Multiple Infeed operations are allowed in a
-computation, but there must be a total order among the Infeed operations. For
-example, two Infeeds in the code below have a total order since there is a
-dependency between the while loops.
-
-```
-result1 = while (condition, init = init_value) {
- Infeed(shape)
-}
-
-result2 = while (condition, init = result1) {
- Infeed(shape)
-}
-```
-
-Nested tuple shapes are not supported. For an empty tuple shape, the Infeed
-operation is effectively a no-op and proceeds without reading any data from the
-Infeed of the device.
-
-> Note: We plan to allow multiple Infeed operations without a total order, in
-> which case the compiler will provide information about how the Infeed
-> operations are serialized in the compiled program.
-
-## Iota
-
-<b> `Iota()` </b>
-
-Builds a constant literal on device rather than a potentially large host
-transfer. Creates a rank 1 tensor of values starting at zero and incrementing
-by one.
-
-Arguments | Type | Semantics
------------------- | --------------- | ---------------------------
-`type` | `PrimitiveType` | type U
-`size` | `int64` | The number of elements in the tensor.
-
-## Map
-
-See also
-[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Map(operands..., computation)` </b>
-
-| Arguments | Type | Semantics |
-| ----------------- | ---------------------- | ------------------------------ |
-| `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} |
-| `computation` | `XlaComputation` | computation of type `T_0, T_1, |
-: : : ..., T_{N + M -1} -> S` with N :
-: : : parameters of type T and M of :
-: : : arbitrary type :
-| `dimensions` | `int64` array | array of map dimensions |
-
-Applies a scalar function over the given `operands` arrays, producing an array
-of the same dimensions where each element is the result of the mapped function
-applied to the corresponding elements in the input arrays.
-
-The mapped function is an arbitrary computation with the restriction that it has
-N inputs of scalar type `T` and a single output with type `S`. The output has
-the same dimensions as the operands except that the element type T is replaced
-with S.
-
-For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <-
-computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the
-input arrays to produce the output array.
-
-## Pad
-
-See also
-[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Pad(operand, padding_value, padding_config)` </b>
-
-| Arguments | Type | Semantics |
-| ---------------- | --------------- | --------------------------------------- |
-| `operand` | `XlaOp` | array of type `T` |
-| `padding_value` | `XlaOp` | scalar of type `T` to fill in the added |
-: : : padding :
-| `padding_config` | `PaddingConfig` | padding amount on both edges (low, |
-: : : high) and between the elements of each :
-: : : dimension :
-
-Expands the given `operand` array by padding around the array as well as between
-the elements of the array with the given `padding_value`. `padding_config`
-specifies the amount of edge padding and the interior padding for each
-dimension.
-
-`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains
-three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and
-`interior_padding`. `edge_padding_low` and `edge_padding_high` specify the
-amount of padding added at the low-end (next to index 0) and the high-end (next
-to the highest index) of each dimension respectively. The amount of edge padding
-can be negative -- the absolute value of negative padding indicates the number
-of elements to remove from the specified dimension. `interior_padding` specifies
-the amount of padding added between any two elements in each dimension. Interior
-padding occurs logically before edge padding, so in the case of negative edge
-padding elements are removed from the interior-padded operand. This operation is
-a no-op if the edge padding pairs are all (0, 0) and the interior padding values
-are all 0. The figure below shows examples of different `edge_padding` and
-`interior_padding` values for a two-dimensional array.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
-</div>
-
-## Recv
-
-See also
-[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Recv(shape, channel_handle)` </b>
-
-| Arguments | Type | Semantics |
-| ---------------- | --------------- | ------------------------------------ |
-| `shape` | `Shape` | shape of the data to receive |
-| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
-
-Receives data of the given shape from a `Send` instruction in another
-computation that shares the same channel handle. Returns a
-XlaOp for the received data.
-
-The client API of `Recv` operation represents synchronous communication.
-However, the instruction is internally decomposed into 2 HLO instructions
-(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
-[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
-
-<b>`Recv(const Shape& shape, int64 channel_id)`</b>
-
-Allocates resources required to receive data from a `Send` instruction with the
-same channel_id. Returns a context for the allocated resources, which is used
-by a following `RecvDone` instruction to wait for the completion of the data
-transfer. The context is a tuple of {receive buffer (shape), request identifier
-(U32)} and it can only be used by a `RecvDone` instruction.
-
-<b> `RecvDone(HloInstruction context)` </b>
-
-Given a context created by a `Recv` instruction, waits for the data transfer to
-complete and returns the received data.
-
-## Reduce
-
-See also
-[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Applies a reduction function to one or more arrays in parallel.
-
-<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
-
-Arguments | Type | Semantics
-------------- | --------------------- | ---------------------------------------
-`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`.
-`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`.
-`computation` | `XlaComputation` | computation of type
- : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`
-`dimensions` | `int64` array | unordered array of dimensions to reduce
-
-Where:
-* N is required to be greater or equal to 1.
-* All input arrays must have the same dimensions.
-* If `N = 1`, `Collate(T)` is `T`.
-* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements of type `T`.
-
-The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type
-`T_i`, the dimensions of which are described below.
-
-This operation reduces one or more dimensions of each input array into scalars.
-The rank of each returned array is `rank(operand) - len(dimensions)`.
-`init_value` is the initial value used for every reduction and may be inserted
-anywhere during computation by the back-end. In most cases, `init_value` is an
-identity of the reduction function (for example, 0 for addition). The applied
-`computation` is always passed the `init_value` on the left-hand side.
-
-The evaluation order of the reduction function is arbitrary and may be
-non-deterministic. Therefore, the reduction function should not be overly
-sensitive to reassociation.
-
-Some reduction functions like addition are not strictly associative for floats.
-However, if the range of the data is limited, floating-point addition is close
-enough to being associative for most practical uses. It is possible to conceive
-of some completely non-associative reductions, however, and these will produce
-incorrect or unpredictable results in XLA reductions.
-
-As an example, when reducing across one dimension in a single 1D array with
-values [10, 11, 12, 13], with reduction function `f` (this is `computation`)
-then that could be computed as
-
-`f(10, f(11, f(12, f(init_value, 13)))`
-
-but there are also many other possibilities, e.g.
-
-`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`
-
-The following is a rough pseudo-code example of how reduction could be
-implemented, using summation as the reduction computation with an initial value
-of 0.
-
-```python
-result_shape <- remove all dims in dimensions from operand_shape
-
-# Iterate over all elements in result_shape. The number of r's here is equal
-# to the rank of the result
-for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
- # Initialize this result element
- result[r0, r1...] <- 0
-
- # Iterate over all the reduction dimensions
- for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
- # Increment the result element with the value of the operand's element.
- # The index of the operand's element is constructed from all ri's and di's
- # in the right order (by construction ri's and di's together index over the
- # whole operand shape).
- result[r0, r1...] += operand[ri... di]
-```
-
-Here's an example of reducing a 2D array (matrix). The shape has rank 2,
-dimension 0 of size 2 and dimension 1 of size 3:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_2d_matrix.png">
-</div>
-
-Results of reducing dimensions 0 or 1 with an "add" function:
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_from_2d_matrix.png">
-</div>
-
-Note that both reduction results are 1D arrays. The diagram shows one as column
-and another as row just for visual convenience.
-
-For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of
-size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the
-values 1 to 6 are replicated across dimension 0.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_from_3d_matrix.png">
-</div>
-
-Similarly to the 2D example, we can reduce just one dimension. If we reduce
-dimension 0, for example, we get a rank-2 array where all values across
-dimension 0 were folded into a scalar:
-
-```text
-| 4 8 12 |
-| 16 20 24 |
-```
-
-If we reduce dimension 2, we also get a rank-2 array where all values across
-dimension 2 were folded into a scalar:
-
-```text
-| 6 15 |
-| 6 15 |
-| 6 15 |
-| 6 15 |
-```
-
-Note that the relative order between the remaining dimensions in the input is
-preserved in the output, but some dimensions may get assigned new numbers (since
-the rank changes).
-
-We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
-the 1D array `| 20 28 36 |`.
-
-Reducing the 3D array over all its dimensions produces the scalar `84`.
-
-When `N > 1`, reduce function application is slightly more complex, as it is
-applied simultaneously to all inputs. For example, consider the following
-reduction function, which can be used to compute the max and the argmax of a
-a 1-D tensor in parallel:
-
-```
-f: (Float, Int, Float, Int) -> Float, Int
-f(max, argmax, value, index):
- if value >= argmax:
- return (value, index)
- else:
- return (max, argmax)
-```
-
-For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
-`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
-input dimension is equivalent to the following recursive application:
-```
-f_0 = f(I_V, I_K, V_0, K_0)
-f_1 = f(f_0.first, f_0.second, V_1, K_1)
-...
-f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
-```
-
-Applying this reduction to an array of values, and an array of sequential
-indices (i.e. iota), will co-iterate over the arrays, and return a tuple
-containing the maximal value and the matching index.
-
-## ReducePrecision
-
-See also
-[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Models the effect of converting floating-point values to a lower-precision
-format (such as IEEE-FP16) and back to the original format. The number of
-exponent and mantissa bits in the lower-precision format can be specified
-arbitrarily, although all bit sizes may not be supported on all hardware
-implementations.
-
-<b> `ReducePrecision(operand, mantissa_bits, exponent_bits)` </b>
-
-Arguments | Type | Semantics
---------------- | ------- | -------------------------------------------------
-`operand` | `XlaOp` | array of floating-point type `T`.
-`exponent_bits` | `int32` | number of exponent bits in lower-precision format
-`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format
-
-The result is an array of type `T`. The input values are rounded to the nearest
-value representable with the given number of mantissa bits (using "ties to even"
-semantics), and any values that exceed the range specified by the number of
-exponent bits are clamped to positive or negative infinity. `NaN` values are
-retained, although they may be converted to canonical `NaN` values.
-
-The lower-precision format must have at least one exponent bit (in order to
-distinguish a zero value from an infinity, since both have a zero mantissa), and
-must have a non-negative number of mantissa bits. The number of exponent or
-mantissa bits may exceed the corresponding value for type `T`; the corresponding
-portion of the conversion is then simply a no-op.
-
-## ReduceWindow
-
-See also
-[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Applies a reduction function to all elements in each window of the input
-multi-dimensional array, producing an output multi-dimensional array with the
-same number of elements as the number of valid positions of the window. A
-pooling layer can be expressed as a `ReduceWindow`. Similar to
-[`Reduce`](#reduce), the applied `computation` is always passed the `init_value`
-on the left-hand side.
-
-<b> `ReduceWindow(operand, init_value, computation, window_dimensions,
-window_strides, padding)` </b>
-
-| Arguments | Type | Semantics |
-| ------------------- | ------------------- | -------------------------------- |
-| `operand` | `XlaOp` | N dimensional array containing |
-: : : elements of type T. This is the :
-: : : base area on which the window is :
-: : : placed. :
-| `init_value` | `XlaOp` | Starting value for the |
-: : : reduction. See [Reduce](#reduce) :
-: : : for details. :
-| `computation` | `XlaComputation` | Reduction function of type `T, T |
-: : : -> T`, to apply to all elements :
-: : : in each window :
-| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
-: : : dimension values :
-| `window_strides` | `ArraySlice<int64>` | array of integers for window |
-: : : stride values :
-| `padding` | `Padding` | padding type for window |
-: : : (Padding\:\:kSame or :
-: : : Padding\:\:kValid) :
-
-Below code and figure shows an example of using `ReduceWindow`. Input is a
-matrix of size [4x6] and both window_dimensions and window_stride_dimensions are
-[2x3].
-
-```
-// Create a computation for the reduction (maximum).
-XlaComputation max;
-{
- XlaBuilder builder(client_, "max");
- auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
- auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
- builder.Max(y, x);
- max = builder.Build().ConsumeValueOrDie();
-}
-
-// Create a ReduceWindow computation with the max reduction computation.
-XlaBuilder builder(client_, "reduce_window_2x3");
-auto shape = ShapeUtil::MakeShape(F32, {4, 6});
-auto input = builder.Parameter(0, shape, "input");
-builder.ReduceWindow(
- input, *max,
- /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
- /*window_dimensions=*/{2, 3},
- /*window_stride_dimensions=*/{2, 3},
- Padding::kValid);
-```
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_window.png">
-</div>
-
-Stride of 1 in a dimension specifies that the position of a window in the
-dimension is 1 element away from its adjacent window. In order to specify that
-no windows overlap with each other, window_stride_dimensions should be equal to
-window_dimensions. The figure below illustrates the use of two different stride
-values. Padding is applied to each dimension of the input and the calculations
-are the same as though the input came in with the dimensions it has after
-padding.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:75%" src="https://www.tensorflow.org/images/ops_reduce_window_stride.png">
-</div>
-
-The evaluation order of the reduction function is arbitrary and may be
-non-deterministic. Therefore, the reduction function should not be overly
-sensitive to reassociation. See the discussion about associativity in the
-context of [`Reduce`](#reduce) for more details.
-
-## Reshape
-
-See also
-[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
-and the [`Collapse`](#collapse) operation.
-
-Reshapes the dimensions of an array into a new configuration.
-
-<b> `Reshape(operand, new_sizes)` </b>
-<b> `Reshape(operand, dimensions, new_sizes)` </b>
-
-Arguments | Type | Semantics
------------- | -------------- | ---------------------------------------
-`operand` | `XlaOp` | array of type T
-`dimensions` | `int64` vector | order in which dimensions are collapsed
-`new_sizes` | `int64` vector | vector of sizes of new dimensions
-
-Conceptually, reshape first flattens an array into a one-dimensional vector of
-data values, and then refines this vector into a new shape. The input arguments
-are an arbitrary array of type T, a compile-time-constant vector of dimension
-indices, and a compile-time-constant vector of dimension sizes for the result.
-The values in the `dimension` vector, if given, must be a permutation of all of
-T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of
-the dimensions in `dimensions` is from slowest-varying dimension (most major) to
-fastest-varying dimension (most minor) in the loop nest which collapses the
-input array into a single dimension. The `new_sizes` vector determines the size
-of the output array. The value at index 0 in `new_sizes` is the size of
-dimension 0, the value at index 1 is the size of dimension 1, and so on. The
-product of the `new_size` dimensions must equal the product of the operand's
-dimension sizes. When refining the collapsed array into the multidimensional
-array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from
-slowest varying (most major) and to fastest varying (most minor).
-
-For example, let v be an array of 24 elements:
-
-```
-let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
- {{20, 21, 22}, {25, 26, 27}},
- {{30, 31, 32}, {35, 36, 37}},
- {{40, 41, 42}, {45, 46, 47}}};
-
-In-order collapse:
-let v012_24 = Reshape(v, {0,1,2}, {24});
-then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
- 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
-
-let v012_83 = Reshape(v, {0,1,2}, {8,3});
-then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
- {20, 21, 22}, {25, 26, 27},
- {30, 31, 32}, {35, 36, 37},
- {40, 41, 42}, {45, 46, 47}};
-
-Out-of-order collapse:
-let v021_24 = Reshape(v, {1,2,0}, {24});
-then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
- 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
-
-let v021_83 = Reshape(v, {1,2,0}, {8,3});
-then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
- {31, 41, 12}, {22, 32, 42},
- {15, 25, 35}, {45, 16, 26},
- {36, 46, 17}, {27, 37, 47}};
-
-
-let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
-then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
- {11, 21}, {31, 41},
- {12, 22}, {32, 42}},
- {{15, 25}, {35, 45},
- {16, 26}, {36, 46},
- {17, 27}, {37, 47}}};
-```
-
-As a special case, reshape can transform a single-element array to a scalar and
-vice versa. For example,
-
-```
-Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
-Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
-```
-
-## Rev (reverse)
-
-See also
-[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b>`Rev(operand, dimensions)`</b>
-
-Arguments | Type | Semantics
------------- | ------------------- | ---------------------
-`operand` | `XlaOp` | array of type T
-`dimensions` | `ArraySlice<int64>` | dimensions to reverse
-
-Reverses the order of elements in the `operand` array along the specified
-`dimensions`, generating an output array of the same shape. Each element of the
-operand array at a multidimensional index is stored into the output array at a
-transformed index. The multidimensional index is transformed by reversing the
-index in each dimension to be reversed (i.e., if a dimension of size N is one of
-the reversing dimensions, its index i is transformed into N - 1 - i).
-
-One use for the `Rev` operation is to reverse the convolution weight array along
-the two window dimensions during the gradient computation in neural networks.
-
-## RngNormal
-
-See also
-[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Constructs an output of a given shape with random numbers generated following
-the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
-$$\sigma$$, and output shape have to have a floating point elemental type. The
-parameters furthermore have to be scalar valued.
-
-<b>`RngNormal(mu, sigma, shape)`</b>
-
-| Arguments | Type | Semantics |
-| --------- | ------- | --------------------------------------------------- |
-| `mu` | `XlaOp` | Scalar of type T specifying mean of generated |
-: : : numbers :
-| `sigma` | `XlaOp` | Scalar of type T specifying standard deviation of |
-: : : generated numbers :
-| `shape` | `Shape` | Output shape of type T |
-
-## RngUniform
-
-See also
-[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Constructs an output of a given shape with random numbers generated following
-the uniform distribution over the interval $$[a,b)$$. The parameters and output
-element type have to be a boolean type, an integral type or a floating point
-types, and the types have to be consistent. The CPU and GPU backends currently
-only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the
-parameters need to be scalar valued. If $$b <= a$$ the result is
-implementation-defined.
-
-<b>`RngUniform(a, b, shape)`</b>
-
-| Arguments | Type | Semantics |
-| --------- | ----------------------- | --------------------------------- |
-| `a` | `XlaOp` | Scalar of type T specifying lower |
-: : : limit of interval :
-| `b` | `XlaOp` | Scalar of type T specifying upper |
-: : : limit of interval :
-| `shape` | `Shape` | Output shape of type T |
-
-## Scatter
-
-The XLA scatter operation generates a result which is the value of the input
-tensor `operand`, with several slices (at indices specified by
-`scatter_indices`) updated with the values in `updates` using
-`update_computation`.
-
-See also
-[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b>
-
-|Arguments | Type | Semantics |
-|------------------|------------------------|----------------------------------|
-|`operand` | `XlaOp` | Tensor to be scattered into. |
-|`scatter_indices` | `XlaOp` | Tensor containing the starting |
-: : : indices of the slices that must :
-: : : be scattered to. :
-|`updates` | `XlaOp` | Tensor containing the values that|
-: : : must be used for scattering. :
-|`update_computation`| `XlaComputation` | Computation to be used for |
-: : : combining the existing values in :
-: : : the input tensor and the updates :
-: : : during scatter. This computation :
-: : : should be of type `T, T -> T`. :
-|`index_vector_dim`| `int64` | The dimension in |
-: : : `scatter_indices` that contains :
-: : : the starting indices. :
-|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in |
-: : : `updates` shape that are _window :
-: : : dimensions_. :
-|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ |
-: : : that must be inserted into :
-: : : `updates` shape. :
-|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A dimensions map from |
-: : : the scatter indices to the :
-: : : operand index space. This array :
-: : : is interpreted as mapping `i` to :
-: : : `scatter_dims_to_operand_dims[i]`:
-: : : . It has to be one-to-one and :
-: : : total. :
-
-If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider
-`scatter_indices` to have a trailing `1` dimension.
-
-We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
-dimensions in `updates` shape that are not in `update_window_dims`, in ascending
-order.
-
-The arguments of scatter should follow these constraints:
-
- - `updates` tensor must be of rank `update_window_dims.size +
- scatter_indices.rank - 1`.
-
- - Bounds of dimension `i` in `updates` must conform to the following:
- - If `i` is present in `update_window_dims` (i.e. equal to
- `update_window_dims`[`k`] for some `k`), then the bound of dimension
- `i` in `updates` must not exceed the corresponding bound of `operand`
- after accounting for the `inserted_window_dims` (i.e.
- `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
- the bounds of `operand` with the bounds at indices
- `inserted_window_dims` removed).
- - If `i` is present in `update_scatter_dims` (i.e. equal to
- `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
- `i` in `updates` must be equal to the corresponding bound of
- `scatter_indices`, skipping `index_vector_dim` (i.e.
- `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
- `scatter_indices.shape.dims`[`k+1`] otherwise).
-
- - `update_window_dims` must be in ascending order, not have any repeating
- dimension numbers, and be in the range `[0, updates.rank)`.
-
- - `inserted_window_dims` must be in ascending order, not have any
- repeating dimension numbers, and be in the range `[0, operand.rank)`.
-
- - `scatter_dims_to_operand_dims.size` must be equal to
- `scatter_indices`[`index_vector_dim`], and its values must be in the range
- `[0, operand.rank)`.
-
-For a given index `U` in the `updates` tensor, the corresponding index `I` in
-the `operand` tensor into which this update has to be applied is computed as
-follows:
-
- 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
- an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] =
- `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
- positions `index_vector_dim` into A.
- 2. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
- `S` using the `scatter_dims_to_operand_dims` map. More formally:
- 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
- `k` < `scatter_dims_to_operand_dims.size`.
- 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
- 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
- at `update_window_dims` in `U` according to `inserted_window_dims`.
- More formally:
- 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if
- `k` < `update_window_dims.size`, where `window_dims_to_operand_dims`
- is the monotonic function with domain [`0`, `update_window_dims.size`)
- and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For
- example, if `update_window_dims.size` is `4`, `operand.rank` is `6`,
- and `inserted_window_dims` is {`0`, `2`} then
- `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`,
- `3`→`5`}).
- 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
- 4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
- addition.
-
-In summary, the scatter operation can be defined as follows.
-
- - Initialize `output` with `operand`, i.e. for all indices `O` in the
- `operand` tensor:\
- `output`[`O`] = `operand`[`O`]
- - For every index `U` in the `updates` tensor and the corresponding index `O`
- in the `operand` tensor:\
- `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
-
-The order in which updates are applied is non-deterministic. So, when multiple
-indices in `updates` refer to the same index in `operand`, the corresponding
-value in `output` will be non-deterministic.
-
-Note that the first parameter that is passed into the `update_computation` will
-always be the current value from the `output` tensor and the second parameter
-will always be the value from the `updates` tensor. This is important
-specifically for cases when the `update_computation` is _not commutative_.
-
-Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
-the scatter op updates the elements in the input that are extracted by the
-corresponding gather op.
-
-For a detailed informal description and examples, refer to the
-"Informal Description" section under `Gather`.
-
-## Select
-
-See also
-[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Constructs an output array from elements of two input arrays, based on the
-values of a predicate array.
-
-<b> `Select(pred, on_true, on_false)` </b>
-
-Arguments | Type | Semantics
----------- | ------- | ------------------
-`pred` | `XlaOp` | array of type PRED
-`on_true` | `XlaOp` | array of type T
-`on_false` | `XlaOp` | array of type T
-
-The arrays `on_true` and `on_false` must have the same shape. This is also the
-shape of the output array. The array `pred` must have the same dimensionality as
-`on_true` and `on_false`, with the `PRED` element type.
-
-For each element `P` of `pred`, the corresponding element of the output array is
-taken from `on_true` if the value of `P` is `true`, and from `on_false` if the
-value of `P` is `false`. As a restricted form of [broadcasting]
-(broadcasting.md), `pred` can be a scalar of type `PRED`. In this case, the
-output array is taken wholly from `on_true` if `pred` is `true`, and from
-`on_false` if `pred` is `false`.
-
-Example with non-scalar `pred`:
-
-```
-let pred: PRED[4] = {true, false, false, true};
-let v1: s32[4] = {1, 2, 3, 4};
-let v2: s32[4] = {100, 200, 300, 400};
-==>
-Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
-```
-
-Example with scalar `pred`:
-
-```
-let pred: PRED = true;
-let v1: s32[4] = {1, 2, 3, 4};
-let v2: s32[4] = {100, 200, 300, 400};
-==>
-Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
-```
-
-Selections between tuples are supported. Tuples are considered to be scalar
-types for this purpose. If `on_true` and `on_false` are tuples (which must have
-the same shape!) then `pred` has to be a scalar of type `PRED`.
-
-## SelectAndScatter
-
-See also
-[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-This operation can be considered as a composite operation that first computes
-`ReduceWindow` on the `operand` array to select an element from each window, and
-then scatters the `source` array to the indices of the selected elements to
-construct an output array with the same shape as the operand array. The binary
-`select` function is used to select an element from each window by applying it
-across each window, and it is called with the property that the first
-parameter's index vector is lexicographically less than the second parameter's
-index vector. The `select` function returns `true` if the first parameter is
-selected and returns `false` if the second parameter is selected, and the
-function must hold transitivity (i.e., if `select(a, b)` and `select(b, c)` are
-`true`, then `select(a, c)` is also `true`) so that the selected element does
-not depend on the order of the elements traversed for a given window.
-
-The function `scatter` is applied at each selected index in the output array. It
-takes two scalar parameters:
-
-1. Current value at the selected index in the output array
-2. The scatter value from `source` that applies to the selected index
-
-It combines the two parameters and returns a scalar value that's used to update
-the value at the selected index in the output array. Initially, all indices of
-the output array are set to `init_value`.
-
-The output array has the same shape as the `operand` array and the `source`
-array must have the same shape as the result of applying a `ReduceWindow`
-operation on the `operand` array. `SelectAndScatter` can be used to
-backpropagate the gradient values for a pooling layer in a neural network.
-
-<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
-padding, source, init_value, scatter)`</b>
-
-| Arguments | Type | Semantics |
-| ------------------- | ------------------- | -------------------------------- |
-| `operand` | `XlaOp` | array of type T over which the |
-: : : windows slide :
-| `select` | `XlaComputation` | binary computation of type `T, T |
-: : : -> PRED`, to apply to all :
-: : : elements in each window; returns :
-: : : `true` if the first parameter is :
-: : : selected and returns `false` if :
-: : : the second parameter is selected :
-| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
-: : : dimension values :
-| `window_strides` | `ArraySlice<int64>` | array of integers for window |
-: : : stride values :
-| `padding` | `Padding` | padding type for window |
-: : : (Padding\:\:kSame or :
-: : : Padding\:\:kValid) :
-| `source` | `XlaOp` | array of type T with the values |
-: : : to scatter :
-| `init_value` | `XlaOp` | scalar value of type T for the |
-: : : initial value of the output :
-: : : array :
-| `scatter` | `XlaComputation` | binary computation of type `T, T |
-: : : -> T`, to apply each scatter :
-: : : source element with its :
-: : : destination element :
-
-The figure below shows examples of using `SelectAndScatter`, with the `select`
-function computing the maximal value among its parameters. Note that when the
-windows overlap, as in the figure (2) below, an index of the `operand` array may
-be selected multiple times by different windows. In the figure, the element of
-value 9 is selected by both of the top windows (blue and red) and the binary
-addition `scatter` function produces the output element of value 8 (2 + 6).
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%"
- src="https://www.tensorflow.org/images/ops_scatter_to_selected_window_element.png">
-</div>
-
-The evaluation order of the `scatter` function is arbitrary and may be
-non-deterministic. Therefore, the `scatter` function should not be overly
-sensitive to reassociation. See the discussion about associativity in the
-context of [`Reduce`](#reduce) for more details.
-
-## Send
-
-See also
-[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `Send(operand, channel_handle)` </b>
-
-Arguments | Type | Semantics
----------------- | --------------- | -----------------------------------------
-`operand` | `XlaOp` | data to send (array of type T)
-`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair
-
-Sends the given operand data to a `Recv` instruction in another computation
-that shares the same channel handle. Does not return any data.
-
-Similar to the `Recv` operation, the client API of `Send` operation represents
-synchronous communication, and is internally decomposed into 2 HLO instructions
-(`Send` and `SendDone`) to enable asynchronous data transfers. See also
-[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
-
-<b>`Send(HloInstruction operand, int64 channel_id)`</b>
-
-Initiates an asynchronous transfer of the operand to the resources allocated by
-the `Recv` instruction with the same channel id. Returns a context, which is
-used by a following `SendDone` instruction to wait for the completion of the
-data transfer. The context is a tuple of {operand (shape), request identifier
-(U32)} and it can only be used by a `SendDone` instruction.
-
-<b> `SendDone(HloInstruction context)` </b>
-
-Given a context created by a `Send` instruction, waits for the data transfer to
-complete. The instruction does not return any data.
-
-<b> Scheduling of channel instructions </b>
-
-The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
-`Send`, `SendDone`) is as below.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:70%" src="../../images/send_recv_order.png">
-</div>
-
-* `Recv` happens before `Send`
-* `Send` happens before `RecvDone`
-* `Recv` happens before `RecvDone`
-* `Send` happens before `SendDone`
-
-When the backend compilers generate a linear schedule for each computation that
-communicates via channel instructions, there must not be cycles across the
-computations. For example, below schedules lead to deadlocks.
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="../../images/send_recv_schedule.png">
-</div>
-
-## Slice
-
-See also
-[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-Slicing extracts a sub-array from the input array. The sub-array is of the same
-rank as the input and contains the values inside a bounding box within the input
-array where the dimensions and indices of the bounding box are given as
-arguments to the slice operation.
-
-<b> `Slice(operand, start_indices, limit_indices)` </b>
-
-| Arguments | Type | Semantics |
-| --------------- | ------------------- | ------------------------------------ |
-| `operand` | `XlaOp` | N dimensional array of type T |
-| `start_indices` | `ArraySlice<int64>` | List of N integers containing the |
-: : : starting indices of the slice for :
-: : : each dimension. Values must be :
-: : : greater than or equal to zero. :
-| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the |
-: : : ending indices (exclusive) for the :
-: : : slice for each dimension. Each value :
-: : : must be greater than or equal to the :
-: : : respective `start_indices` value for :
-: : : the dimension and less than or equal :
-: : : to the size of the dimension. :
-
-1-dimensional example:
-
-```
-let a = {0.0, 1.0, 2.0, 3.0, 4.0}
-Slice(a, {2}, {4}) produces:
- {2.0, 3.0}
-```
-
-2-dimensional example:
-
-```
-let b =
- { {0.0, 1.0, 2.0},
- {3.0, 4.0, 5.0},
- {6.0, 7.0, 8.0},
- {9.0, 10.0, 11.0} }
-
-Slice(b, {2, 1}, {4, 3}) produces:
- { { 7.0, 8.0},
- {10.0, 11.0} }
-```
-
-## Sort
-
-See also
-[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-There are two versions of the Sort instruction: a single-operand and a
-two-operand version.
-
-<b>`Sort(operand)`</b>
-
-Arguments | Type | Semantics
------------ | ------- | --------------------
-`operand` | `XlaOp` | The operand to sort.
-`dimension` | `int64` | The dimension along which to sort.
-
-Sorts the elements in the operand in ascending order along the provided
-dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0
-will sort each column independently, and a `dimension` value of 1 will sort each
-row independently. If the operand's elements have floating point type, and the
-operand contains NaN elements, the order of elements in the output is
-implementation-defined.
-
-<b>`Sort(key, value)`</b>
-
-Sorts both the key and the value operands. The keys are sorted as in the
-single-operand version. The values are sorted according to the order of their
-corresponding keys. For example, if the inputs are `keys = [3, 1]` and
-`values = [42, 50]`, then the output of the sort is the tuple
-`{[1, 3], [50, 42]}`.
-
-The sort is not guaranteed to be stable, that is, if the keys array contains
-duplicates, the order of their corresponding values may not be preserved.
-
-Arguments | Type | Semantics
------------ | ------- | -------------------
-`keys` | `XlaOp` | The sort keys.
-`values` | `XlaOp` | The values to sort.
-`dimension` | `int64` | The dimension along which to sort.
-
-The `keys` and `values` must have the same dimensions, but may have different
-element types.
-
-## Transpose
-
-See also the `tf.reshape` operation.
-
-<b>`Transpose(operand)`</b>
-
-Arguments | Type | Semantics
-------------- | ------------------- | ------------------------------
-`operand` | `XlaOp` | The operand to transpose.
-`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
-
-
-Permutes the operand dimensions with the given permutation, so
-`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.
-
-This is the same as Reshape(operand, permutation,
- Permute(permutation, operand.shape.dimensions)).
-
-## Tuple
-
-See also
-[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-A tuple containing a variable number of data handles, each of which has its own
-shape.
-
-This is analogous to `std::tuple` in C++. Conceptually:
-
-```
-let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
-let s: s32 = 5;
-let t: (f32[10], s32) = tuple(v, s);
-```
-
-Tuples can be deconstructed (accessed) via the [`GetTupleElement`]
-(#gettupleelement) operation.
-
-## While
-
-See also
-[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
-
-<b> `While(condition, body, init)` </b>
-
-| Arguments | Type | Semantics |
-| ----------- | ---------------- | ---------------------------------------- |
-| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which |
-: : : defines the termination condition of the :
-: : : loop. :
-| `body` | `XlaComputation` | XlaComputation of type `T -> T` which |
-: : : defines the body of the loop. :
-| `init` | `T` | Initial value for the parameter of |
-: : : `condition` and `body`. :
-
-Sequentially executes the `body` until the `condition` fails. This is similar to
-a typical while loop in many other languages except for the differences and
-restrictions listed below.
-
-* A `While` node returns a value of type `T`, which is the result from the
- last execution of the `body`.
-* The shape of the type `T` is statically determined and must be the same
- across all iterations.
-
-The T parameters of the computations are initialized with the `init` value in
-the first iteration and are automatically updated to the new result from `body`
-in each subsequent iteration.
-
-One main use case of the `While` node is to implement the repeated execution of
-training in neural networks. Simplified pseudocode is shown below with a graph
-that represents the computation. The code can be found in
-[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
-The type `T` in this example is a `Tuple` consisting of an `int32` for the
-iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
-loop keeps adding a constant vector to the accumulator.
-
-```
-// Pseudocode for the computation.
-init = {0, zero_vector[10]} // Tuple of int32 and float[10].
-result = init;
-while (result(0) < 1000) {
- iteration = result(0) + 1;
- new_vector = result(1) + constant_vector[10];
- result = {iteration, new_vector};
-}
-```
-
-<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
- <img style="width:100%" src="https://www.tensorflow.org/images/ops_while.png">
-</div>
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index b4d4db3e4d..fe99915a6c 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -7221,6 +7221,45 @@ func MultiDeviceIteratorGetNextFromShard(scope *Scope, multi_device_iterator tf.
return components
}
+// Deprecated. Use TensorArrayGradV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3
+func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorArrayWriteV2",
+ Input: []tf.Input{
+ handle, index, value, flow_in,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Writes the given dataset to the given file using the TFRecord format.
+//
+// Arguments:
+// input_dataset: A variant tensor representing the dataset to write.
+// filename: A scalar string tensor representing the filename to use.
+// compression_type: A scalar string tensor containing either (i) the empty string (no
+// compression), (ii) "ZLIB", or (iii) "GZIP".
+//
+// Returns the created operation.
+func DatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "DatasetToTFRecord",
+ Input: []tf.Input{
+ input_dataset, filename, compression_type,
+ },
+ }
+ return scope.AddOperation(opspec)
+}
+
// Computes rectified linear 6: `min(max(features, 0), 6)`.
func Relu6(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
@@ -8251,44 +8290,6 @@ func BiasAddGrad(scope *Scope, out_backprop tf.Output, optional ...BiasAddGradAt
return op.Output(0)
}
-// Bucketizes 'input' based on 'boundaries'.
-//
-// For example, if the inputs are
-// boundaries = [0, 10, 100]
-// input = [[-5, 10000]
-// [150, 10]
-// [5, 100]]
-//
-// then the output will be
-// output = [[0, 3]
-// [3, 2]
-// [1, 3]]
-//
-// Arguments:
-// input: Any shape of Tensor contains with int or float type.
-// boundaries: A sorted list of floats gives the boundary of the buckets.
-//
-// Returns Same shape with 'input', each value of input replaced with bucket index.
-//
-// @compatibility(numpy)
-// Equivalent to np.digitize.
-// @end_compatibility
-func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"boundaries": boundaries}
- opspec := tf.OpSpec{
- Type: "Bucketize",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// FusedBatchNormV2Attr is an optional argument to FusedBatchNormV2.
type FusedBatchNormV2Attr func(optionalAttr)
@@ -9640,36 +9641,6 @@ func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...
return op.Output(0)
}
-// Returns the element-wise sum of a list of tensors.
-//
-// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
-// wait for all of its inputs to be ready before beginning to sum. This can
-// save memory if inputs are ready at different times, since minimum temporary
-// storage is proportional to the output size rather than the inputs size.
-//
-// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
-//
-// Returns a `Tensor` of same shape and type as the elements of `inputs`.
-//
-// Arguments:
-// inputs: A list of `Tensor` objects, each with same shape and type.
-// shape: Shape of elements of `inputs`.
-func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"shape": shape}
- opspec := tf.OpSpec{
- Type: "AccumulateNV2",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// RandomShuffleAttr is an optional argument to RandomShuffle.
type RandomShuffleAttr func(optionalAttr)
@@ -10383,206 +10354,65 @@ func ResourceApplyFtrl(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.
return scope.AddOperation(opspec)
}
-// Encode audio data using the WAV file format.
-//
-// This operation will generate a string suitable to be saved out to create a .wav
-// audio file. It will be encoded in the 16-bit PCM format. It takes in float
-// values in the range -1.0f to 1.0f, and any outside that value will be clamped to
-// that range.
-//
-// `audio` is a 2-D float Tensor of shape `[length, channels]`.
-// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
-//
-// Arguments:
-// audio: 2-D with shape `[length, channels]`.
-// sample_rate: Scalar containing the sample frequency.
-//
-// Returns 0-D. WAV-encoded file contents.
-func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "EncodeWav",
- Input: []tf.Input{
- audio, sample_rate,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Computes atan of x element-wise.
-func Atan(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Atan",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
-type ResourceApplyAdaMaxAttr func(optionalAttr)
-
-// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var, m, and v tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the AdaMax algorithm.
+// Locks a mutex resource. The output is the lock. So long as the lock tensor
//
-// m_t <- beta1 * m_{t-1} + (1 - beta1) * g
-// v_t <- max(beta2 * v_{t-1}, abs(g))
-// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
+// is alive, any other request to use `MutexLock` with this mutex will wait.
//
-// Arguments:
-// var_: Should be from a Variable().
-// m: Should be from a Variable().
-// v: Should be from a Variable().
-// beta1_power: Must be a scalar.
-// lr: Scaling factor. Must be a scalar.
-// beta1: Momentum factor. Must be a scalar.
-// beta2: Momentum factor. Must be a scalar.
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
+// This is particularly useful for creating a critical section when used in
+// conjunction with `MutexLockIdentity`:
//
-// Returns the created operation.
-func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyAdaMax",
- Input: []tf.Input{
- var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// AssertAttr is an optional argument to Assert.
-type AssertAttr func(optionalAttr)
-
-// AssertSummarize sets the optional summarize attribute to value.
+// ```python
//
-// value: Print this many entries of each tensor.
-// If not specified, defaults to 3
-func AssertSummarize(value int64) AssertAttr {
- return func(m optionalAttr) {
- m["summarize"] = value
- }
-}
-
-// Asserts that the given condition is true.
+// mutex = mutex_v2(
+// shared_name=handle_name, container=container, name=name)
//
-// If `condition` evaluates to false, print the list of tensors in `data`.
-// `summarize` determines how many entries of the tensors to print.
+// def execute_in_critical_section(fn, *args, **kwargs):
+// lock = gen_resource_variable_ops.mutex_lock(mutex)
//
-// Arguments:
-// condition: The condition to evaluate.
-// data: The tensors to print out when condition is false.
+// with ops.control_dependencies([lock]):
+// r = fn(*args, **kwargs)
//
-// Returns the created operation.
-func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Assert",
- Input: []tf.Input{
- condition, tf.OutputList(data),
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
-
-// Split a `SparseTensor` into `num_split` tensors along one dimension.
+// with ops.control_dependencies(nest.flatten(r)):
+// with ops.colocate_with(mutex):
+// ensure_lock_exists = mutex_lock_identity(lock)
//
-// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
-// `[0 : shape[split_dim] % num_split]` gets one extra dimension.
-// For example, if `split_dim = 1` and `num_split = 2` and the input is
+// # Make sure that if any element of r is accessed, all of
+// # them are executed together.
+// r = nest.map_structure(tf.identity, r)
//
-// input_tensor = shape = [2, 7]
-// [ a d e ]
-// [b c ]
+// with ops.control_dependencies([ensure_lock_exists]):
+// return nest.map_structure(tf.identity, r)
+// ```
//
-// Graphically the output tensors are:
+// While `fn` is running in the critical section, no other functions which wish to
+// use this critical section may run.
//
-// output_tensor[0] = shape = [2, 4]
-// [ a ]
-// [b c ]
+// Often the use case is that two executions of the same graph, in parallel,
+// wish to run `fn`; and we wish to ensure that only one of them executes
+// at a time. This is especially important if `fn` modifies one or more
+// variables at a time.
//
-// output_tensor[1] = shape = [2, 3]
-// [ d e ]
-// [ ]
+// It is also useful if two separate functions must share a resource, but we
+// wish to ensure the usage is exclusive.
//
// Arguments:
-// split_dim: 0-D. The dimension along which to split. Must be in the range
-// `[0, rank(shape))`.
-// indices: 2-D tensor represents the indices of the sparse tensor.
-// values: 1-D tensor represents the values of the sparse tensor.
-// shape: 1-D. tensor represents the shape of the sparse tensor.
-// output indices: A list of 1-D tensors represents the indices of the output
-// sparse tensors.
-// num_split: The number of ways to split.
+// mutex: The mutex resource to lock.
//
-// Returns A list of 1-D tensors represents the values of the output sparse
-// tensors.A list of 1-D tensors represents the shape of the output sparse
-// tensors.
-func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) {
+// Returns A tensor that keeps a shared pointer to a lock on the mutex;
+// when the Tensor is destroyed, the use count on the shared pointer is decreased
+// by 1. When it reaches 0, the lock is released.
+func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{"num_split": num_split}
opspec := tf.OpSpec{
- Type: "SparseSplit",
+ Type: "MutexLock",
Input: []tf.Input{
- split_dim, indices, values, shape,
+ mutex,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil {
- scope.UpdateErr("SparseSplit", err)
- return
- }
- if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil {
- scope.UpdateErr("SparseSplit", err)
- return
- }
- if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil {
- scope.UpdateErr("SparseSplit", err)
- return
- }
- return output_indices, output_values, output_shape
+ return op.Output(0)
}
// ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2.
@@ -11151,6 +10981,44 @@ func Tan(scope *Scope, x tf.Output) (y tf.Output) {
return op.Output(0)
}
+// Bucketizes 'input' based on 'boundaries'.
+//
+// For example, if the inputs are
+// boundaries = [0, 10, 100]
+// input = [[-5, 10000]
+// [150, 10]
+// [5, 100]]
+//
+// then the output will be
+// output = [[0, 3]
+// [3, 2]
+// [1, 3]]
+//
+// Arguments:
+// input: Any shape of Tensor contains with int or float type.
+// boundaries: A sorted list of floats gives the boundary of the buckets.
+//
+// Returns Same shape with 'input', each value of input replaced with bucket index.
+//
+// @compatibility(numpy)
+// Equivalent to np.digitize.
+// @end_compatibility
+func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"boundaries": boundaries}
+ opspec := tf.OpSpec{
+ Type: "Bucketize",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// EncodeJpegAttr is an optional argument to EncodeJpeg.
type EncodeJpegAttr func(optionalAttr)
@@ -11699,6 +11567,238 @@ func SparseTensorDenseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output,
return op.Output(0)
}
+// Encode audio data using the WAV file format.
+//
+// This operation will generate a string suitable to be saved out to create a .wav
+// audio file. It will be encoded in the 16-bit PCM format. It takes in float
+// values in the range -1.0f to 1.0f, and any outside that value will be clamped to
+// that range.
+//
+// `audio` is a 2-D float Tensor of shape `[length, channels]`.
+// `sample_rate` is a scalar Tensor holding the rate to use (e.g. 44100).
+//
+// Arguments:
+// audio: 2-D with shape `[length, channels]`.
+// sample_rate: Scalar containing the sample frequency.
+//
+// Returns 0-D. WAV-encoded file contents.
+func EncodeWav(scope *Scope, audio tf.Output, sample_rate tf.Output) (contents tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeWav",
+ Input: []tf.Input{
+ audio, sample_rate,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Computes atan of x element-wise.
+func Atan(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Atan",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// ResourceApplyAdaMaxAttr is an optional argument to ResourceApplyAdaMax.
+type ResourceApplyAdaMaxAttr func(optionalAttr)
+
+// ResourceApplyAdaMaxUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, m, and v tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyAdaMaxUseLocking(value bool) ResourceApplyAdaMaxAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the AdaMax algorithm.
+//
+// m_t <- beta1 * m_{t-1} + (1 - beta1) * g
+// v_t <- max(beta2 * v_{t-1}, abs(g))
+// variable <- variable - learning_rate / (1 - beta1^t) * m_t / (v_t + epsilon)
+//
+// Arguments:
+// var_: Should be from a Variable().
+// m: Should be from a Variable().
+// v: Should be from a Variable().
+// beta1_power: Must be a scalar.
+// lr: Scaling factor. Must be a scalar.
+// beta1: Momentum factor. Must be a scalar.
+// beta2: Momentum factor. Must be a scalar.
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+//
+// Returns the created operation.
+func ResourceApplyAdaMax(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdaMaxAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyAdaMax",
+ Input: []tf.Input{
+ var_, m, v, beta1_power, lr, beta1, beta2, epsilon, grad,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// AssertAttr is an optional argument to Assert.
+type AssertAttr func(optionalAttr)
+
+// AssertSummarize sets the optional summarize attribute to value.
+//
+// value: Print this many entries of each tensor.
+// If not specified, defaults to 3
+func AssertSummarize(value int64) AssertAttr {
+ return func(m optionalAttr) {
+ m["summarize"] = value
+ }
+}
+
+// Asserts that the given condition is true.
+//
+// If `condition` evaluates to false, print the list of tensors in `data`.
+// `summarize` determines how many entries of the tensors to print.
+//
+// Arguments:
+// condition: The condition to evaluate.
+// data: The tensors to print out when condition is false.
+//
+// Returns the created operation.
+func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...AssertAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Assert",
+ Input: []tf.Input{
+ condition, tf.OutputList(data),
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Split a `SparseTensor` into `num_split` tensors along one dimension.
+//
+// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
+// `[0 : shape[split_dim] % num_split]` gets one extra dimension.
+// For example, if `split_dim = 1` and `num_split = 2` and the input is
+//
+// input_tensor = shape = [2, 7]
+// [ a d e ]
+// [b c ]
+//
+// Graphically the output tensors are:
+//
+// output_tensor[0] = shape = [2, 4]
+// [ a ]
+// [b c ]
+//
+// output_tensor[1] = shape = [2, 3]
+// [ d e ]
+// [ ]
+//
+// Arguments:
+// split_dim: 0-D. The dimension along which to split. Must be in the range
+// `[0, rank(shape))`.
+// indices: 2-D tensor represents the indices of the sparse tensor.
+// values: 1-D tensor represents the values of the sparse tensor.
+// shape: 1-D. tensor represents the shape of the sparse tensor.
+// output indices: A list of 1-D tensors represents the indices of the output
+// sparse tensors.
+// num_split: The number of ways to split.
+//
+// Returns A list of 1-D tensors represents the values of the output sparse
+// tensors.A list of 1-D tensors represents the shape of the output sparse
+// tensors.
+func SparseSplit(scope *Scope, split_dim tf.Output, indices tf.Output, values tf.Output, shape tf.Output, num_split int64) (output_indices []tf.Output, output_values []tf.Output, output_shape []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_split": num_split}
+ opspec := tf.OpSpec{
+ Type: "SparseSplit",
+ Input: []tf.Input{
+ split_dim, indices, values, shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if output_indices, idx, err = makeOutputList(op, idx, "output_indices"); err != nil {
+ scope.UpdateErr("SparseSplit", err)
+ return
+ }
+ if output_values, idx, err = makeOutputList(op, idx, "output_values"); err != nil {
+ scope.UpdateErr("SparseSplit", err)
+ return
+ }
+ if output_shape, idx, err = makeOutputList(op, idx, "output_shape"); err != nil {
+ scope.UpdateErr("SparseSplit", err)
+ return
+ }
+ return output_indices, output_values, output_shape
+}
+
+// Returns the element-wise sum of a list of tensors.
+//
+// `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not
+// wait for all of its inputs to be ready before beginning to sum. This can
+// save memory if inputs are ready at different times, since minimum temporary
+// storage is proportional to the output size rather than the inputs size.
+//
+// Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable.
+//
+// Returns a `Tensor` of same shape and type as the elements of `inputs`.
+//
+// Arguments:
+// inputs: A list of `Tensor` objects, each with same shape and type.
+// shape: Shape of elements of `inputs`.
+func AccumulateNV2(scope *Scope, inputs []tf.Output, shape tf.Shape) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"shape": shape}
+ opspec := tf.OpSpec{
+ Type: "AccumulateNV2",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// StatelessTruncatedNormalAttr is an optional argument to StatelessTruncatedNormal.
type StatelessTruncatedNormalAttr func(optionalAttr)
@@ -13925,67 +14025,6 @@ func CudnnRNNBackpropV2(scope *Scope, input tf.Output, input_h tf.Output, input_
return op.Output(0), op.Output(1), op.Output(2), op.Output(3)
}
-// Locks a mutex resource. The output is the lock. So long as the lock tensor
-//
-// is alive, any other request to use `MutexLock` with this mutex will wait.
-//
-// This is particularly useful for creating a critical section when used in
-// conjunction with `MutexLockIdentity`:
-//
-// ```python
-//
-// mutex = mutex_v2(
-// shared_name=handle_name, container=container, name=name)
-//
-// def execute_in_critical_section(fn, *args, **kwargs):
-// lock = gen_resource_variable_ops.mutex_lock(mutex)
-//
-// with ops.control_dependencies([lock]):
-// r = fn(*args, **kwargs)
-//
-// with ops.control_dependencies(nest.flatten(r)):
-// with ops.colocate_with(mutex):
-// ensure_lock_exists = mutex_lock_identity(lock)
-//
-// # Make sure that if any element of r is accessed, all of
-// # them are executed together.
-// r = nest.map_structure(tf.identity, r)
-//
-// with ops.control_dependencies([ensure_lock_exists]):
-// return nest.map_structure(tf.identity, r)
-// ```
-//
-// While `fn` is running in the critical section, no other functions which wish to
-// use this critical section may run.
-//
-// Often the use case is that two executions of the same graph, in parallel,
-// wish to run `fn`; and we wish to ensure that only one of them executes
-// at a time. This is especially important if `fn` modifies one or more
-// variables at a time.
-//
-// It is also useful if two separate functions must share a resource, but we
-// wish to ensure the usage is exclusive.
-//
-// Arguments:
-// mutex: The mutex resource to lock.
-//
-// Returns A tensor that keeps a shared pointer to a lock on the mutex;
-// when the Tensor is destroyed, the use count on the shared pointer is decreased
-// by 1. When it reaches 0, the lock is released.
-func MutexLock(scope *Scope, mutex tf.Output) (mutex_lock tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MutexLock",
- Input: []tf.Input{
- mutex,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// StringFormatAttr is an optional argument to StringFormat.
type StringFormatAttr func(optionalAttr)
@@ -16807,26 +16846,6 @@ func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values
return op.Output(0), op.Output(1)
}
-// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
-//
-// The Hurwitz zeta function is defined as:
-//
-//
-// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\)
-func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Zeta",
- Input: []tf.Input{
- x, q,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns a list of tensors with the same shapes and contents as the input
//
// tensors.
@@ -18873,6 +18892,26 @@ func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.D
return op.Output(0)
}
+// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
+//
+// The Hurwitz zeta function is defined as:
+//
+//
+// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\)
+func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Zeta",
+ Input: []tf.Input{
+ x, q,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Inverse fast Fourier transform.
//
// Computes the inverse 1-dimensional discrete Fourier transform over the
@@ -21413,43 +21452,6 @@ func QuantizedResizeBilinear(scope *Scope, images tf.Output, size tf.Output, min
return op.Output(0), op.Output(1), op.Output(2)
}
-// Computes the minimum along segments of a tensor.
-//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
-//
-// Computes a tensor such that
-// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
-// that `segment_ids[j] == i`.
-//
-// If the min is empty for a given segment ID `i`, `output[i] = 0`.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt>
-// </div>
-//
-// Arguments:
-//
-// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
-// first dimension. Values should be sorted and can be repeated.
-//
-// Returns Has same shape as data, except for dimension 0 which
-// has size `k`, the number of segments.
-func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SegmentMin",
- Input: []tf.Input{
- data, segment_ids,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SdcaOptimizerAttr is an optional argument to SdcaOptimizer.
type SdcaOptimizerAttr func(optionalAttr)
@@ -21924,6 +21926,43 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
+// Computes the minimum along segments of a tensor.
+//
+// Read
+// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// for an explanation of segments.
+//
+// Computes a tensor such that
+// \\(output_i = \min_j(data_j)\\) where `min` is over `j` such
+// that `segment_ids[j] == i`.
+//
+// If the min is empty for a given segment ID `i`, `output[i] = 0`.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt>
+// </div>
+//
+// Arguments:
+//
+// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
+// first dimension. Values should be sorted and can be repeated.
+//
+// Returns Has same shape as data, except for dimension 0 which
+// has size `k`, the number of segments.
+func SegmentMin(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SegmentMin",
+ Input: []tf.Input{
+ data, segment_ids,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the sum along segments of a tensor.
//
// Read
@@ -22757,6 +22796,21 @@ func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output
return op.Output(0)
}
+// Computes hyperbolic tangent of `x` element-wise.
+func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Tanh",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the maximum along segments of a tensor.
//
// Read
@@ -22794,21 +22848,6 @@ func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.
return op.Output(0)
}
-// Computes hyperbolic tangent of `x` element-wise.
-func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Tanh",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a dataset that skips `count` elements from the `input_dataset`.
//
// Arguments:
@@ -29094,6 +29133,17 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source
return op.Output(0)
}
+// SubstrAttr is an optional argument to Substr.
+type SubstrAttr func(optionalAttr)
+
+// SubstrUnit sets the optional unit attribute to value.
+// If not specified, defaults to "BYTE"
+func SubstrUnit(value string) SubstrAttr {
+ return func(m optionalAttr) {
+ m["unit"] = value
+ }
+}
+
// Return substrings from `Tensor` of strings.
//
// For each string in the input `Tensor`, creates a substring starting at index
@@ -29178,15 +29228,20 @@ func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source
// len: Scalar defining the number of characters to include in each substring
//
// Returns Tensor of substrings
-func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output) (output tf.Output) {
+func Substr(scope *Scope, input tf.Output, pos tf.Output, len tf.Output, optional ...SubstrAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
Type: "Substr",
Input: []tf.Input{
input, pos, len,
},
+ Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
@@ -29862,28 +29917,6 @@ func Cross(scope *Scope, a tf.Output, b tf.Output) (product tf.Output) {
return op.Output(0)
}
-// Writes the given dataset to the given file using the TFRecord format.
-//
-// Arguments:
-// input_dataset: A variant tensor representing the dataset to write.
-// filename: A scalar string tensor representing the filename to use.
-// compression_type: A scalar string tensor containing either (i) the empty string (no
-// compression), (ii) "ZLIB", or (iii) "GZIP".
-//
-// Returns the created operation.
-func DatasetToTFRecord(scope *Scope, input_dataset tf.Output, filename tf.Output, compression_type tf.Output) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "DatasetToTFRecord",
- Input: []tf.Input{
- input_dataset, filename, compression_type,
- },
- }
- return scope.AddOperation(opspec)
-}
-
// AvgPool3DAttr is an optional argument to AvgPool3D.
type AvgPool3DAttr func(optionalAttr)
@@ -31676,23 +31709,6 @@ func TensorArraySizeV3(scope *Scope, handle tf.Output, flow_in tf.Output) (size
return op.Output(0)
}
-// Deprecated. Use TensorArrayGradV3
-//
-// DEPRECATED at GraphDef version 26: Use TensorArrayWriteV3
-func TensorArrayWriteV2(scope *Scope, handle tf.Output, index tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "TensorArrayWriteV2",
- Input: []tf.Input{
- handle, index, value, flow_in,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// SparseReduceMaxAttr is an optional argument to SparseReduceMax.
type SparseReduceMaxAttr func(optionalAttr)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index fe81254ef7..822d596995 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -2152,6 +2152,7 @@ py_library(
":array_grad",
":array_ops",
":bitwise_ops",
+ ":check_ops",
":cond_v2_impl",
":control_flow_grad",
":control_flow_ops",
@@ -2172,8 +2173,11 @@ py_library(
":random_grad",
":resource_variable_ops",
":spectral_grad",
+ ":tensor_array_ops",
+ ":tensor_util",
":util",
":variable_scope",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
@@ -5193,6 +5197,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "control_flow_ops_benchmark",
+ srcs = ["ops/control_flow_ops_benchmark.py"],
+ additional_deps = [
+ ":client_testlib",
+ ":constant_op",
+ ":control_flow_ops",
+ ":framework_ops",
+ "//tensorflow/python/eager:function",
+ ],
+ main = "ops/control_flow_ops_benchmark.py",
+)
+
+cuda_py_test(
name = "conv2d_benchmark",
size = "large",
srcs = ["ops/conv2d_benchmark.py"],
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index dc2d419d34..fcdbd0a82c 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -128,7 +128,13 @@ class TestCase(test.TestCase):
@contextlib.contextmanager
def converted(self, entity, converter_module, namespace, *tf_symbols):
node, ctx = self.prepare(entity, namespace)
- node = converter_module.transform(node, ctx)
+
+ if not isinstance(converter_module, (list, tuple)):
+ converter_module = (converter_module,)
+ for m in converter_module:
+ node = m.transform(node, ctx)
+ node = converter.standard_analysis(node, ctx, is_initial=True)
+
with self.compiled(node, namespace, *tf_symbols) as result:
yield result
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index 91a2a22cc2..70e59272a9 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -228,5 +228,6 @@ BUILTIN_FUINCTIONS_MAP = {
'len': len_,
'print': print_,
'range': range_,
+ # TODO(mdan): This might make more sense as tf.data.range.
'xrange': range_,
}
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index eef74599a7..29c406c248 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -30,10 +30,14 @@ from tensorflow.python.util import tf_inspect
def isbuiltin(f):
+ """Returns True if the argument is a built-in function."""
# Note these return false for isinstance(f, types.BuiltinFunctionType) so we
# need to specifically check for them.
if f in (range, int, float):
return True
+ if six.PY2:
+ if f in (xrange,):
+ return True
if isinstance(f, types.BuiltinFunctionType):
return True
if tf_inspect.isbuiltin(f):
@@ -63,6 +67,40 @@ def getnamespace(f):
return namespace
+def getqualifiedname(namespace, object_, max_depth=2):
+ """Returns the name by which a value can be referred to in a given namespace.
+
+ This function will recurse inside modules, but it will not search objects for
+ attributes. The recursion depth is controlled by max_depth.
+
+ Args:
+ namespace: Dict[str, Any], the namespace to search into.
+ object_: Any, the value to search.
+ max_depth: Optional[int], a limit to the recursion depth when searching
+ inside modules.
+ Returns: Union[str, None], the fully-qualified name that resolves to the value
+ o, or None if it couldn't be found.
+ """
+ for name, value in namespace.items():
+ # The value may be referenced by more than one symbol, case in which
+ # any symbol will be fine. If the program contains symbol aliases that
+ # change over time, this may capture a symbol that will later point to
+ # something else.
+ # TODO(mdan): Prefer the symbol that matches the value type name.
+ if object_ is value:
+ return name
+
+ # TODO(mdan): Use breadth-first search and avoid visiting modules twice.
+ if max_depth:
+ for name, value in namespace.items():
+ if tf_inspect.ismodule(value):
+ name_in_module = getqualifiedname(value.__dict__, object_,
+ max_depth - 1)
+ if name_in_module is not None:
+ return '{}.{}'.format(name, name_in_module)
+ return None
+
+
def _get_unbound_function(m):
# TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
# The failure case is for tf.keras.Model.
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index f3eb027822..11074debfc 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
+import imp
import six
@@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase):
self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
self.assertTrue('local_var' not in ns)
+ def test_getqualifiedname(self):
+ foo = object()
+ qux = imp.new_module('quxmodule')
+ bar = imp.new_module('barmodule')
+ baz = object()
+ bar.baz = baz
+
+ ns = {
+ 'foo': foo,
+ 'bar': bar,
+ 'qux': qux,
+ }
+
+ self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
+ self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
+
def test_getmethodclass(self):
self.assertEqual(
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 36b9e7074d..4ceddce53b 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import gast
+import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import transformer
@@ -35,6 +36,9 @@ from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# These symbols are legal in Python, but don't appear in the namespace.
_SPECIAL_SYMBOLS = {'range': range, 'print': print}
+if six.PY2:
+ _SPECIAL_SYMBOLS['xrange'] = xrange
+
class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index d833defb8e..349c84e13c 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 3)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 8)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD
new file mode 100644
index 0000000000..b9398aebe7
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "map_benchmark",
+ size = "medium",
+ srcs = ["map_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
index 2f0bd1456b..ad253cffa5 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
@@ -19,7 +19,6 @@ from __future__ import print_function
import hashlib
import itertools
-import os
import time
import numpy as np
@@ -27,128 +26,15 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.experimental.ops import optimization
-from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test_base.DatasetTestBase):
-
- def testMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.check_numerics(x, "message")).apply(
- error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testParallelMapIgnoreError(self):
- components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).map(
- lambda x: array_ops.check_numerics(x, "message"),
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for x in [1., 2., 3., 5.]:
- self.assertEqual(x, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testReadFileIgnoreError(self):
-
- def write_string_to_file(value, filename):
- with open(filename, "w") as f:
- f.write(value)
-
- filenames = [
- os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
- ]
- for filename in filenames:
- write_string_to_file(filename, filename)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(filenames).map(
- io_ops.read_file,
- num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # All of the files are present.
- sess.run(init_op)
- for filename in filenames:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Delete one of the files.
- os.remove(filenames[0])
-
- # Attempting to read filenames[0] will fail, but ignore_errors()
- # will catch the error.
- sess.run(init_op)
- for filename in filenames[1:]:
- self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testCaptureResourceInMapFn(self):
-
- def _build_ds(iterator):
-
- def _map_fn(x):
- get_next = iterator.get_next()
- return x * get_next
-
- return dataset_ops.Dataset.range(10).map(_map_fn)
-
- def _build_graph():
- captured_iterator = dataset_ops.Dataset.range(
- 10).make_initializable_iterator()
- ds = _build_ds(captured_iterator)
- iterator = ds.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- return captured_iterator.initializer, init_op, get_next
-
- with ops.Graph().as_default() as g:
- captured_init_op, init_op, get_next = _build_graph()
- with self.session(graph=g) as sess:
- sess.run(captured_init_op)
- sess.run(init_op)
- for i in range(10):
- self.assertEquals(i * i, sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
-
class MapDatasetBenchmark(test.Benchmark):
# The purpose of this benchmark is to compare the performance of chaining vs
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index f56127f3ef..4eef9580ad 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -8,75 +8,62 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "batch_dataset_op_test",
+ name = "bucket_by_sequence_length_test",
size = "medium",
- srcs = ["batch_dataset_op_test.py"],
+ srcs = ["bucket_by_sequence_length_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss", # (b/79552534)
- "no_pip",
- "no_windows",
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
],
)
+cuda_py_test(
+ name = "copy_to_device_test",
+ size = "small",
+ srcs = ["copy_to_device_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
py_test(
- name = "bucketing_test",
- size = "medium",
- srcs = ["bucketing_test.py"],
+ name = "counter_test",
+ size = "small",
+ srcs = ["counter_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/experimental/ops:counter",
"//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
py_test(
- name = "csv_dataset_op_test",
+ name = "csv_dataset_test",
size = "medium",
- srcs = ["csv_dataset_op_test.py"],
+ srcs = ["csv_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -97,25 +84,18 @@ py_test(
)
py_test(
- name = "dataset_constructor_op_test",
- size = "medium",
- srcs = ["dataset_constructor_op_test.py"],
+ name = "dense_to_sparse_batch_test",
+ srcs = ["dense_to_sparse_batch_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "manual",
- "no_oss",
- "no_pip",
- "no_windows",
- "nomac", # b/62040583
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
"//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
],
)
@@ -124,11 +104,6 @@ py_test(
size = "medium",
srcs = ["directed_interleave_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
deps = [
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -141,14 +116,67 @@ py_test(
)
py_test(
+ name = "enumerate_dataset_test",
+ size = "small",
+ srcs = ["enumerate_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "function_buffering_resource_test",
+ size = "small",
+ srcs = ["function_buffering_resource_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_test(
name = "get_single_element_test",
size = "small",
srcs = ["get_single_element_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -165,19 +193,20 @@ py_test(
)
py_test(
- name = "indexed_dataset_ops_test",
- srcs = ["indexed_dataset_ops_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ name = "group_by_reducer_test",
+ size = "medium",
+ srcs = ["group_by_reducer_test.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:experimental_dataset_ops_gen",
- "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
@@ -185,107 +214,134 @@ py_test(
)
py_test(
- name = "interleave_dataset_op_test",
+ name = "group_by_window_test",
size = "medium",
- srcs = ["interleave_dataset_op_test.py"],
+ srcs = ["group_by_window_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- "notap",
- ],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "@six_archive//:six",
+ "//third_party/py/numpy",
],
)
py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
+ name = "ignore_errors_test",
+ srcs = ["ignore_errors_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
],
+)
+
+py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator:estimator_py",
+ "//third_party/py/numpy",
],
)
py_test(
- name = "map_dataset_op_test",
+ name = "make_batched_features_dataset_test",
size = "medium",
- srcs = ["map_dataset_op_test.py"],
+ srcs = ["make_batched_features_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- "noasan", # times out
- "optonly",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
],
+)
+
+py_test(
+ name = "make_csv_dataset_test",
+ size = "medium",
+ srcs = ["make_csv_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:batching",
- "//tensorflow/python/data/experimental/ops:error_ops",
- "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/kernel_tests:test_base",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
"//third_party/py/numpy",
],
)
py_test(
- name = "filter_dataset_op_test",
+ name = "make_tf_record_dataset_test",
size = "medium",
- srcs = ["filter_dataset_op_test.py"],
+ srcs = ["make_tf_record_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/util:nest",
],
+)
+
+py_test(
+ name = "map_and_batch_test",
+ size = "medium",
+ srcs = ["map_and_batch_test.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
"//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
@@ -294,11 +350,7 @@ py_test(
size = "small",
srcs = ["map_defun_op_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
@@ -317,16 +369,57 @@ py_test(
)
py_test(
- name = "parsing_ops_test",
+ name = "override_threadpool_test",
size = "small",
- srcs = ["parsing_ops_test.py"],
+ srcs = ["override_threadpool_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python/data/experimental/ops:threadpool",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "parallel_interleave_test",
+ size = "medium",
+ srcs = ["parallel_interleave_test.py"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
"no_pip",
- "no_windows",
+ "notap",
],
deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "parse_example_dataset_test",
+ size = "small",
+ srcs = ["parse_example_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -344,53 +437,20 @@ py_test(
)
cuda_py_test(
- name = "prefetching_ops_test",
+ name = "prefetch_to_device_test",
size = "small",
- srcs = ["prefetching_ops_test.py"],
+ srcs = ["prefetch_to_device_test.py"],
additional_deps = [
"//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- "no_windows_gpu",
- ],
-)
-
-py_test(
- name = "range_dataset_op_test",
- size = "small",
- srcs = ["range_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/experimental/ops:counter",
- "//tensorflow/python/data/experimental/ops:enumerate_ops",
- "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
+ tags = ["no_windows_gpu"],
)
py_library(
@@ -421,41 +481,12 @@ py_library(
)
py_test(
- name = "reader_dataset_ops_test",
- size = "medium",
- srcs = ["reader_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
- deps = [
- ":reader_dataset_ops_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/experimental/ops:readers",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resample_test",
+ name = "rejection_resample_test",
size = "medium",
- srcs = ["resample_test.py"],
+ srcs = ["rejection_resample_test.py"],
shard_count = 2,
srcs_version = "PY2AND3",
tags = [
- "no_oss",
- "no_pip",
- "no_windows",
"noasan",
"optonly",
],
@@ -477,15 +508,27 @@ py_test(
)
py_test(
- name = "scan_dataset_op_test",
- size = "small",
- srcs = ["scan_dataset_op_test.py"],
+ name = "restructured_dataset_test",
+ size = "medium",
+ srcs = ["restructured_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
],
+)
+
+py_test(
+ name = "scan_test",
+ size = "small",
+ srcs = ["scan_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -503,14 +546,12 @@ py_test(
)
py_test(
- name = "shuffle_dataset_op_test",
+ name = "shuffle_and_repeat_test",
size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["shuffle_and_repeat_test.py"],
srcs_version = "PY2AND3",
tags = [
- "no_oss",
"no_pip",
- "no_windows",
"optonly",
],
deps = [
@@ -525,8 +566,8 @@ py_test(
)
py_library(
- name = "sql_dataset_op_test_base",
- srcs = ["sql_dataset_op_test_base.py"],
+ name = "sql_dataset_test_base",
+ srcs = ["sql_dataset_test_base.py"],
srcs_version = "PY2AND3",
visibility = [
"//tensorflow/python/data/experimental/kernel_tests:__pkg__",
@@ -543,17 +584,13 @@ py_library(
)
py_test(
- name = "sql_dataset_op_test",
+ name = "sql_dataset_test",
size = "small",
- srcs = ["sql_dataset_op_test.py"],
+ srcs = ["sql_dataset_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
- ":sql_dataset_op_test_base",
+ ":sql_dataset_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@@ -565,11 +602,7 @@ py_test(
size = "medium",
srcs = ["stats_dataset_ops_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ tags = ["no_pip"],
deps = [
":reader_dataset_ops_test_base",
":stats_dataset_test_base",
@@ -595,68 +628,60 @@ py_library(
)
py_test(
- name = "threadpool_dataset_ops_test",
+ name = "tf_record_writer_test",
size = "small",
- srcs = ["threadpool_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ srcs = ["tf_record_writer_test.py"],
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:script_ops",
- "//tensorflow/python/data/experimental/ops:threadpool",
- "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:writers",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:readers",
],
)
py_test(
- name = "unique_dataset_op_test",
- size = "small",
- srcs = ["unique_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ name = "unbatch_test",
+ size = "medium",
+ srcs = ["unbatch_test.py"],
deps = [
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
py_test(
- name = "writer_ops_test",
+ name = "unique_test",
size = "small",
- srcs = ["writer_ops_test.py"],
- tags = [
- "no_oss",
- "no_pip",
- "no_windows",
- ],
+ srcs = ["unique_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
deps = [
- "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
+ "//tensorflow/python:errors",
"//tensorflow/python:util",
- "//tensorflow/python/data/experimental/ops:writers",
+ "//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
],
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
deleted file mode 100644
index 956b4518f6..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
+++ /dev/null
@@ -1,686 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import math
-import time
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.python.client import session
-from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-from tensorflow.python.util import compat
-
-
-class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
-
- def testDenseToSparseBatchDataset(self):
- components = np.random.randint(12, size=(100,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x], x)).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)], results.indices)
- self.assertAllEqual(
- [c for c in components[start:start + 4] for _ in range(c)],
- results.values)
- self.assertAllEqual([min(4,
- len(components) - start), 12],
- results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithUnknownShape(self):
- components = np.random.randint(5, size=(40,)).astype(np.int32)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([x, x], x)).apply(
- batching.dense_to_sparse_batch(
- 4, [5, None])).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- for start in range(0, len(components), 4):
- results = sess.run(get_next)
- self.assertAllEqual([[i, j, z]
- for i, c in enumerate(components[start:start + 4])
- for j in range(c)
- for z in range(c)], results.indices)
- self.assertAllEqual([
- c
- for c in components[start:start + 4] for _ in range(c)
- for _ in range(c)
- ], results.values)
- self.assertAllEqual([
- min(4,
- len(components) - start), 5,
- np.max(components[start:start + 4])
- ], results.dense_shape)
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testDenseToSparseBatchDatasetWithInvalidShape(self):
- input_tensor = array_ops.constant([[1]])
- with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
-
- def testDenseToSparseBatchDatasetShapeErrors(self):
- input_tensor = array_ops.placeholder(dtypes.int32)
- iterator = (
- dataset_ops.Dataset.from_tensors(input_tensor).apply(
- batching.dense_to_sparse_batch(4, [12]))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- # Initialize with an input tensor of incompatible rank.
- sess.run(init_op, feed_dict={input_tensor: [[1]]})
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "incompatible with the row shape"):
- sess.run(get_next)
-
- # Initialize with an input tensor that is larger than `row_shape`.
- sess.run(init_op, feed_dict={input_tensor: range(13)})
- with self.assertRaisesRegexp(errors.DataLossError,
- "larger than the row shape"):
- sess.run(get_next)
-
- def testUnbatchWithUnknownRankInput(self):
- placeholder = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
- batching.unbatch())
- iterator = dataset.make_initializable_iterator()
- next_elem = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
- for i in range(4):
- self.assertEqual(i, sess.run(next_elem))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_elem)
-
- def testUnbatchScalarDataset(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = (dtypes.int32,) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i,) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithStrings(self):
- data = tuple([math_ops.range(10) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
- expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchDatasetWithSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors(st)
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- st_row = sess.run(next_element)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchDatasetWithDenseAndSparseTensor(self):
- st = sparse_tensor.SparseTensorValue(
- indices=[[i, i] for i in range(10)],
- values=list(range(10)),
- dense_shape=[10, 10])
- data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
- data = data.apply(batching.unbatch())
- data = data.batch(5)
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- dense_elem, st_row = sess.run(next_element)
- self.assertEqual(i, dense_elem)
- self.assertEqual([i], st_row.indices)
- self.assertEqual([i], st_row.values)
- self.assertEqual([10], st_row.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchSingleElementTupleDataset(self):
- data = tuple([(math_ops.range(10),) for _ in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32,),) * 3
- data = data.batch(2)
- self.assertEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i,),) * 3, sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchMultiElementTupleDataset(self):
- data = tuple([(math_ops.range(10 * i, 10 * i + 10),
- array_ops.fill([10], "hi")) for i in range(3)])
- data = dataset_ops.Dataset.from_tensor_slices(data)
- expected_types = ((dtypes.int32, dtypes.string),) * 3
- data = data.batch(2)
- self.assertAllEqual(expected_types, data.output_types)
- data = data.apply(batching.unbatch())
- self.assertAllEqual(expected_types, data.output_types)
-
- iterator = data.make_one_shot_iterator()
- op = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
- sess.run(op))
-
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(op)
-
- def testUnbatchEmpty(self):
- data = dataset_ops.Dataset.from_tensors(
- (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
- constant_op.constant([], shape=[0, 4, 0])))
- data = data.apply(batching.unbatch())
- iterator = data.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testUnbatchStaticShapeMismatch(self):
- data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
- np.arange(9)))
- with self.assertRaises(ValueError):
- data.apply(batching.unbatch())
-
- def testUnbatchDynamicShapeMismatch(self):
- ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
- ph2 = array_ops.placeholder(dtypes.int32, shape=None)
- data = dataset_ops.Dataset.from_tensors((ph1, ph2))
- data = data.apply(batching.unbatch())
- iterator = data.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- # Mismatch in the 0th dimension.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: np.arange(8).astype(np.int32)
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- # No 0th dimension (i.e. scalar value) for one component.
- sess.run(
- iterator.initializer,
- feed_dict={
- ph1: np.arange(7).astype(np.int32),
- ph2: 7
- })
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(next_element)
-
- @parameterized.named_parameters(
- ("Default", None, None),
- ("SequentialCalls", 1, None),
- ("ParallelCalls", 2, None),
- ("ParallelBatches", None, 10),
- )
- def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
- """Test a dataset that maps a TF function across its input elements."""
- # The pipeline is TensorSliceDataset ->
- # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- count = array_ops.placeholder(dtypes.int64, shape=[])
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
- batching.map_and_batch(
- map_func=_map_fn,
- batch_size=batch_size,
- num_parallel_calls=num_parallel_calls,
- num_parallel_batches=num_parallel_batches))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- self.assertEqual([[None] + list(c.shape[1:]) for c in components],
- [t.shape.as_list() for t in get_next])
-
- with self.cached_session() as sess:
- # Batch of a finite input, where the batch_size divides the
- # total number of elements.
- sess.run(init_op, feed_dict={count: 28, batch_size: 14})
- num_batches = (28 * 7) // 14
- for i in range(num_batches):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(14):
- self.assertAllEqual(component[(i * 14 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of a finite input, where the batch_size does not
- # divide the total number of elements.
- sess.run(init_op, feed_dict={count: 14, batch_size: 8})
-
- # We expect (num_batches - 1) full-sized batches.
- num_batches = int(math.ceil((14 * 7) / 8))
- for i in range(num_batches - 1):
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range(8):
- self.assertAllEqual(component[(i * 8 + j) % 7]**2,
- result_component[j])
- result = sess.run(get_next)
- for component, result_component in zip(components, result):
- for j in range((14 * 7) % 8):
- self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Batch of an empty input should fail straight away.
- sess.run(init_op, feed_dict={count: 0, batch_size: 8})
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- # Empty batch should be an initialization time error.
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(init_op, feed_dict={count: 14, batch_size: 0})
-
- @parameterized.named_parameters(
- ("Even", False),
- ("Uneven", True),
- )
- def testMapAndBatchPartialBatch(self, drop_remainder):
- iterator = (
- dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]),
- batch_size=4,
- drop_remainder=drop_remainder)).make_one_shot_iterator())
- if drop_remainder:
- self.assertEqual([4, 1], iterator.output_shapes.as_list())
- else:
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- if not drop_remainder:
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchYieldsPartialBatch(self):
- iterator = (dataset_ops.Dataset.range(10)
- .apply(batching.map_and_batch(
- lambda x: array_ops.reshape(x * x, [1]), 4))
- .make_one_shot_iterator())
- self.assertEqual([None, 1], iterator.output_shapes.as_list())
- next_element = iterator.get_next()
- with self.cached_session() as sess:
- self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
- self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
- self.assertAllEqual([[64], [81]], sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testMapAndBatchParallelGetNext(self):
- iterator = (dataset_ops.Dataset.range(50000)
- .apply(batching.map_and_batch(lambda x: x, batch_size=100))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(5):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchParallelGetNextDropRemainder(self):
- iterator = (
- dataset_ops.Dataset.range(49999).apply(
- batching.map_and_batch(
- lambda x: x, batch_size=100, drop_remainder=True))
- .make_one_shot_iterator())
- elements = []
- for _ in range(100):
- elements.append(iterator.get_next())
- with self.cached_session() as sess:
- for i in range(4):
- got = sess.run(elements)
- got.sort(key=lambda x: x[0])
- expected = []
- for j in range(100):
- expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
- self.assertAllEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(elements)
-
- def testMapAndBatchSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(10).apply(
- batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testMapAndBatchFails(self):
- """Test a dataset that maps a TF function across its input elements."""
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.check_numerics(
- constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- with self.cached_session() as sess:
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
- sess.run(init_op, feed_dict={batch_size: 14})
-
- def testMapAndBatchShapeMismatch(self):
- """Test a dataset that maps a TF function across its input elements."""
-
- def generator():
- yield [1]
- yield [2]
- yield [3]
- yield [[4, 5, 6]]
-
- dataset = dataset_ops.Dataset.from_generator(
- generator, output_types=dtypes.int32)
- batch_size = 4
- iterator = (
- dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(errors.InvalidArgumentError,
- "number of elements does not match"):
- sess.run(get_next)
-
- def testMapAndBatchImplicitDispose(self):
- # Tests whether a map and batch dataset will be cleaned up correctly when
- # the pipeline does not run it until exhaustion.
- # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
- # MapAndBatchDataset(f=square_3, batch_size=100).
- components = (np.arange(1000),
- np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
- np.array(37.0) * np.arange(1000))
-
- def _map_fn(x, y, z):
- return math_ops.square(x), math_ops.square(y), math_ops.square(z)
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
- 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
- dataset = dataset.prefetch(5)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for _ in range(3):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", 0),
- ("2", 5),
- ("3", 10),
- ("4", 90),
- ("5", 95),
- ("6", 99),
- )
- def testMapAndBatchOutOfRangeError(self, threshold):
-
- def raising_py_fn(i):
- if i >= threshold:
- raise StopIteration()
- else:
- return i
-
- iterator = (
- dataset_ops.Dataset.range(100).apply(
- batching.map_and_batch(
- lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
- batch_size=10)).make_one_shot_iterator())
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(threshold // 10):
- self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
- if threshold % 10 != 0:
- self.assertAllEqual(
- [threshold // 10 * 10 + j for j in range(threshold % 10)],
- sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- @parameterized.named_parameters(
- ("1", False, dtypes.bool),
- ("2", -42, dtypes.int8),
- ("3", -42, dtypes.int16),
- ("4", -42, dtypes.int32),
- ("5", -42, dtypes.int64),
- ("6", 42, dtypes.uint8),
- ("7", 42, dtypes.uint16),
- ("8", 42.0, dtypes.float16),
- ("9", 42.0, dtypes.float32),
- ("10", 42.0, dtypes.float64),
- ("11", b"hello", dtypes.string),
- )
- def testMapAndBatchTypes(self, element, dtype):
- def gen():
- yield element
-
- dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
- batching.map_and_batch(lambda x: x, batch_size=10))
-
- get_next = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- for _ in range(10):
- self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-
-
-class UnbatchDatasetBenchmark(test.Benchmark):
-
- def benchmarkNativeUnbatch(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.apply(batching.unbatch())
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (native) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_native_batch_size_%d" %
- batch_size)
-
- # Include a benchmark of the previous `unbatch()` implementation that uses
- # a composition of more primitive ops. Eventually we'd hope to generate code
- # that is as good in both cases.
- def benchmarkOldUnbatchImplementation(self):
- batch_sizes = [1, 2, 5, 10, 20, 50]
- elems_per_trial = 10000
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
- batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset.batch(batch_size_placeholder)
- dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
- dataset = dataset.skip(elems_per_trial)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for batch_size in batch_sizes:
- deltas = []
- for _ in range(5):
- sess.run(
- iterator.initializer,
- feed_dict={batch_size_placeholder: batch_size})
- start = time.time()
- sess.run(next_element.op)
- end = time.time()
- deltas.append((end - start) / elems_per_trial)
-
- median_wall_time = np.median(deltas)
- print("Unbatch (unfused) batch size: %d Median wall time per element:"
- " %f microseconds" % (batch_size, median_wall_time * 1e6))
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
- batch_size)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
new file mode 100644
index 0000000000..3903ec49b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
@@ -0,0 +1,322 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.bucket_by_sequence_length()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
+class BucketBySequenceLengthTest(test_base.DatasetTestBase):
+
+ def testBucket(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25, 35]
+
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
+
+ def testPadToBoundary(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25]
+
+ def element_gen():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes[:-1], lengths):
+ for _ in range(batch_size):
+ elements.append([1] * length)
+ random.shuffle(elements)
+ for el in elements:
+ yield (el,)
+ for _ in range(batch_sizes[-1]):
+ el = [1] * (boundaries[-1] + 5)
+ yield (el,)
+
+ element_len = lambda el: array_ops.shape(el)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(3):
+ batches.append(sess.run(batch))
+ with self.assertRaisesOpError("bucket_boundaries"):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ batch_size = batch.shape[0]
+ length = batch.shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ batch_sizes = batch_sizes[:-1]
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
+ sorted(lengths_val))
+
+ def testPadToBoundaryNoExtraneousPadding(self):
+
+ boundaries = [3, 7, 11]
+ batch_sizes = [2, 2, 2, 2]
+ lengths = range(1, 11)
+
+ def element_gen():
+ for length in lengths:
+ yield ([1] * length,)
+
+ element_len = lambda element: array_ops.shape(element)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(5):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+
+ self.assertAllEqual(batches[0], [[1, 0],
+ [1, 1]])
+ self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1]])
+ self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
+
+ def testTupleElements(self):
+
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+ Test runs on following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
deleted file mode 100644
index 153a03989b..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
+++ /dev/null
@@ -1,824 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import random
-
-import numpy as np
-
-from tensorflow.python.data.experimental.ops import grouping
-from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.platform import test
-
-
-class GroupByReducerTest(test_base.DatasetTestBase):
-
- def checkResults(self, dataset, shapes, values):
- self.assertEqual(shapes, dataset.output_shapes)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- for expected in values:
- got = sess.run(get_next)
- self.assertEqual(got, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testSum(self):
- reducer = grouping.Reducer(
- init_func=lambda _: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(lambda x: x % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testAverage(self):
-
- def reduce_fn(x, y):
- return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
- x[1] + 1), x[1] + 1
-
- reducer = grouping.Reducer(
- init_func=lambda _: (0.0, 0.0),
- reduce_func=reduce_fn,
- finalize_func=lambda x, _: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).apply(
- grouping.group_by_reducer(
- lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
-
- def testConcat(self):
- components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
- reducer = grouping.Reducer(
- init_func=lambda x: "",
- reduce_func=lambda x, y: x + y[0],
- finalize_func=lambda x: x)
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensor_slices(components),
- dataset_ops.Dataset.range(2 * i))).apply(
- grouping.group_by_reducer(lambda x, y: y % 2, reducer))
- self.checkResults(
- dataset,
- shapes=tensor_shape.scalar(),
- values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
-
- def testSparseSum(self):
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=np.array([[0, 0]]),
- values=(i * np.array([1], dtype=np.int64)),
- dense_shape=np.array([1, 1]))
-
- reducer = grouping.Reducer(
- init_func=lambda _: _sparse(np.int64(0)),
- reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
- finalize_func=lambda x: x.values[0])
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
- grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
- self.checkResults(
- dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
-
- def testChangingStateShape(self):
-
- def reduce_fn(x, _):
- # Statically known rank, but dynamic length.
- larger_dim = array_ops.concat([x[0], x[0]], 0)
- # Statically unknown rank.
- larger_rank = array_ops.expand_dims(x[1], 0)
- return larger_dim, larger_rank
-
- reducer = grouping.Reducer(
- init_func=lambda x: ([0], 1),
- reduce_func=reduce_fn,
- finalize_func=lambda x, y: (x, y))
-
- for i in range(1, 11):
- dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
- grouping.group_by_reducer(lambda x: x, reducer))
- self.assertEqual([None], dataset.output_shapes[0].as_list())
- self.assertIs(None, dataset.output_shapes[1].ndims)
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual([0] * (2**i), x)
- self.assertAllEqual(np.array(1, ndmin=i), y)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testTypeMismatch(self):
- reducer = grouping.Reducer(
- init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
- reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- TypeError,
- "The element types for the new state must match the initial state."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64(0), reducer))
-
- # TODO(b/78665031): Remove once non-scalar keys are supported.
- def testInvalidKeyShape(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
-
- # TODO(b/78665031): Remove once non-int64 keys are supported.
- def testInvalidKeyType(self):
- reducer = grouping.Reducer(
- init_func=lambda x: np.int64(0),
- reduce_func=lambda x, y: x + y,
- finalize_func=lambda x: x)
-
- dataset = dataset_ops.Dataset.range(10)
- with self.assertRaisesRegexp(
- ValueError, "`key_func` must return a single tf.int64 tensor."):
- dataset.apply(
- grouping.group_by_reducer(lambda _: "wrong", reducer))
-
- def testTuple(self):
- def init_fn(_):
- return np.array([], dtype=np.int64), np.int64(0)
-
- def reduce_fn(state, value):
- s1, s2 = state
- v1, v2 = value
- return array_ops.concat([s1, [v1]], 0), s2 + v2
-
- def finalize_fn(s1, s2):
- return s1, s2
-
- reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
- dataset = dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
- grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- x, y = sess.run(get_next)
- self.assertAllEqual(x, np.asarray([x for x in range(10)]))
- self.assertEqual(y, 45)
-
-
-class GroupByWindowTest(test_base.DatasetTestBase):
-
- def testSimple(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
- .apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- result = sess.run(get_next)
- self.assertTrue(
- all(x % 2 == 0
- for x in result) or all(x % 2 == 1)
- for x in result)
- counts.append(result.shape[0])
-
- self.assertEqual(len(components), sum(counts))
- num_full_batches = len([c for c in counts if c == 4])
- self.assertGreaterEqual(num_full_batches, 24)
- self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
-
- def testImmediateOutput(self):
- components = np.array(
- [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
- grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- # The input is infinite, so this test demonstrates that:
- # 1. We produce output without having to consume the entire input,
- # 2. Different buckets can produce output at different rates, and
- # 3. For deterministic input, the output is deterministic.
- for _ in range(3):
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
-
- def testSmallGroups(self):
- components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
- 4)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
- # The small outputs at the end are deterministically produced in key
- # order.
- self.assertAllEqual([0, 0, 0], sess.run(get_next))
- self.assertAllEqual([1], sess.run(get_next))
-
- def testEmpty(self):
- iterator = (
- dataset_ops.Dataset.range(4).apply(
- grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- "Window size must be greater than zero, but got 0."):
- print(sess.run(get_next))
-
- def testReduceFuncError(self):
- components = np.random.randint(100, size=(200,)).astype(np.int64)
-
- def reduce_func(_, xs):
- # Introduce an incorrect padded shape that cannot (currently) be
- # detected at graph construction time.
- return xs.padded_batch(
- 4,
- padded_shapes=(tensor_shape.TensorShape([]),
- constant_op.constant([5], dtype=dtypes.int64) * -1))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
- grouping.group_by_window(lambda x, _: x % 2, reduce_func,
- 32)).make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def testConsumeWindowDatasetMoreThanOnce(self):
- components = np.random.randint(50, size=(200,)).astype(np.int64)
-
- def reduce_func(key, window):
- # Apply two different kinds of padding to the input: tight
- # padding, and quantized (to a multiple of 10) padding.
- return dataset_ops.Dataset.zip((
- window.padded_batch(
- 4, padded_shapes=tensor_shape.TensorShape([None])),
- window.padded_batch(
- 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
- ))
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
- .apply(grouping.group_by_window(
- lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
- reduce_func, 4))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- counts = []
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- tight_result, multiple_of_10_result = sess.run(get_next)
- self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
- self.assertAllEqual(tight_result,
- multiple_of_10_result[:, :tight_result.shape[1]])
- counts.append(tight_result.shape[0])
- self.assertEqual(len(components), sum(counts))
-
-
-# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
-# Currently, they use a constant batch size, though should be made to use a
-# different batch size per key.
-class BucketTest(test_base.DatasetTestBase):
-
- def _dynamicPad(self, bucket, window, window_size):
- # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
- # generic form of padded_batch that pads every component
- # dynamically and does not rely on static shape information about
- # the arguments.
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
- [None]), tensor_shape.TensorShape([3])))))
-
- def testSingleBucket(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: 0,
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- which_bucket, bucketed_values = sess.run(get_next)
-
- self.assertEqual(0, which_bucket)
-
- expected_scalar_int = np.arange(32, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
- for i in range(32):
- expected_unk_int64[i, :i] = i
- expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values[2])
-
- def testEvenOddBuckets(self):
-
- def _map_fn(v):
- return (v, array_ops.fill([v], v),
- array_ops.fill([3], string_ops.as_string(v)))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
- lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches (one containing even values, one containing odds)
- which_bucket_even, bucketed_values_even = sess.run(get_next)
- which_bucket_odd, bucketed_values_odd = sess.run(get_next)
-
- # Count number of bucket_tensors.
- self.assertEqual(3, len(bucketed_values_even))
- self.assertEqual(3, len(bucketed_values_odd))
-
- # Ensure bucket 0 was used for all minibatch entries.
- self.assertAllEqual(0, which_bucket_even)
- self.assertAllEqual(1, which_bucket_odd)
-
- # Test the first bucket outputted, the events starting at 0
- expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i] = 2 * i
- expected_vec3_str = np.vstack(
- 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
-
- # Test the second bucket outputted, the odds starting at 1
- expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
- expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
- for i in range(0, 32):
- expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
- expected_vec3_str = np.vstack(
- 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
-
- self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
- self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
- self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
-
- def testEvenOddBucketsFilterOutAllOdd(self):
-
- def _map_fn(v):
- return {
- "x": v,
- "y": array_ops.fill([v], v),
- "z": array_ops.fill([3], string_ops.as_string(v))
- }
-
- def _dynamic_pad_fn(bucket, window, _):
- return dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.from_tensors(bucket),
- window.padded_batch(
- 32, {
- "x": tensor_shape.TensorShape([]),
- "y": tensor_shape.TensorShape([None]),
- "z": tensor_shape.TensorShape([3])
- })))
-
- input_dataset = (
- dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
- .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
-
- bucketed_dataset = input_dataset.apply(
- grouping.group_by_window(
- lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
- lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
-
- iterator = bucketed_dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
-
- # Get two minibatches ([0, 2, ...] and [64, 66, ...])
- which_bucket0, bucketed_values_even0 = sess.run(get_next)
- which_bucket1, bucketed_values_even1 = sess.run(get_next)
-
- # Ensure that bucket 1 was completely filtered out
- self.assertAllEqual(0, which_bucket0)
- self.assertAllEqual(0, which_bucket1)
- self.assertAllEqual(
- np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
- self.assertAllEqual(
- np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
-
- def testDynamicWindowSize(self):
- components = np.arange(100).astype(np.int64)
-
- # Key fn: even/odd
- # Reduce fn: batches of 5
- # Window size fn: even=5, odd=10
-
- def window_size_func(key):
- window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
- return window_sizes[key]
-
- dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
- grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
- None, window_size_func))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.OutOfRangeError):
- batches = 0
- while True:
- result = sess.run(get_next)
- is_even = all(x % 2 == 0 for x in result)
- is_odd = all(x % 2 == 1 for x in result)
- self.assertTrue(is_even or is_odd)
- expected_batch_size = 5 if is_even else 10
- self.assertEqual(expected_batch_size, result.shape[0])
- batches += 1
-
- self.assertEqual(batches, 15)
-
-
-def _element_length_fn(x, y=None):
- del y
- return array_ops.shape(x)[0]
-
-
-def _to_sparse_tensor(record):
- return sparse_tensor.SparseTensor(**record)
-
-
-def _format_record(array, sparse):
- if sparse:
- return {
- "values": array,
- "indices": [[i] for i in range(len(array))],
- "dense_shape": (len(array),)
- }
- return array
-
-
-def _get_record_type(sparse):
- if sparse:
- return {
- "values": dtypes.int64,
- "indices": dtypes.int64,
- "dense_shape": dtypes.int64
- }
- return dtypes.int32
-
-
-def _get_record_shape(sparse):
- if sparse:
- return {
- "values": tensor_shape.TensorShape([None,]),
- "indices": tensor_shape.TensorShape([None, 1]),
- "dense_shape": tensor_shape.TensorShape([1,])
- }
- return tensor_shape.TensorShape([None])
-
-
-class BucketBySequenceLength(test_base.DatasetTestBase):
-
- def testBucket(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25, 35]
-
- def build_dataset(sparse):
- def _generator():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes, lengths):
- record_len = length - 1
- for _ in range(batch_size):
- elements.append([1] * record_len)
- record_len = length
- random.shuffle(elements)
- for el in elements:
- yield (_format_record(el, sparse),)
- dataset = dataset_ops.Dataset.from_generator(
- _generator,
- (_get_record_type(sparse),),
- (_get_record_shape(sparse),))
- if sparse:
- dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
- return dataset
-
- def _test_bucket_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(
- grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- batch_sizes,
- no_padding=no_padding))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(4):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- shape = batch.dense_shape if no_padding else batch.shape
- batch_size = shape[0]
- length = shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- sum_check = batch.values.sum() if no_padding else batch.sum()
- self.assertEqual(sum_check, batch_size * length - 1)
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual(sorted(lengths), sorted(lengths_val))
-
- for no_padding in (True, False):
- _test_bucket_by_padding(no_padding)
-
- def testPadToBoundary(self):
-
- boundaries = [10, 20, 30]
- batch_sizes = [10, 8, 4, 2]
- lengths = [8, 13, 25]
-
- def element_gen():
- # Produce 1 batch for each bucket
- elements = []
- for batch_size, length in zip(batch_sizes[:-1], lengths):
- for _ in range(batch_size):
- elements.append([1] * length)
- random.shuffle(elements)
- for el in elements:
- yield (el,)
- for _ in range(batch_sizes[-1]):
- el = [1] * (boundaries[-1] + 5)
- yield (el,)
-
- element_len = lambda el: array_ops.shape(el)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(3):
- batches.append(sess.run(batch))
- with self.assertRaisesOpError("bucket_boundaries"):
- sess.run(batch)
- batch_sizes_val = []
- lengths_val = []
- for batch in batches:
- batch_size = batch.shape[0]
- length = batch.shape[1]
- batch_sizes_val.append(batch_size)
- lengths_val.append(length)
- batch_sizes = batch_sizes[:-1]
- self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
- self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
- self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
- sorted(lengths_val))
-
- def testPadToBoundaryNoExtraneousPadding(self):
-
- boundaries = [3, 7, 11]
- batch_sizes = [2, 2, 2, 2]
- lengths = range(1, 11)
-
- def element_gen():
- for length in lengths:
- yield ([1] * length,)
-
- element_len = lambda element: array_ops.shape(element)[0]
- dataset = dataset_ops.Dataset.from_generator(
- element_gen, (dtypes.int64,), ([None],)).apply(
- grouping.bucket_by_sequence_length(
- element_len, boundaries, batch_sizes,
- pad_to_bucket_boundary=True))
- batch, = dataset.make_one_shot_iterator().get_next()
-
- with self.cached_session() as sess:
- batches = []
- for _ in range(5):
- batches.append(sess.run(batch))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(batch)
-
- self.assertAllEqual(batches[0], [[1, 0],
- [1, 1]])
- self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1]])
- self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
- self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
- [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
-
- def testTupleElements(self):
-
- def build_dataset(sparse):
- def _generator():
- text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
- label = [1, 2, 1, 2]
- for x, y in zip(text, label):
- yield (_format_record(x, sparse), y)
- dataset = dataset_ops.Dataset.from_generator(
- generator=_generator,
- output_types=(_get_record_type(sparse), dtypes.int32),
- output_shapes=(_get_record_shape(sparse),
- tensor_shape.TensorShape([])))
- if sparse:
- dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
- return dataset
-
- def _test_tuple_elements_by_padding(no_padding):
- dataset = build_dataset(sparse=no_padding)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- element_length_func=_element_length_fn,
- bucket_batch_sizes=[2, 2, 2],
- bucket_boundaries=[0, 8],
- no_padding=no_padding))
- shapes = dataset.output_shapes
- self.assertEqual([None, None], shapes[0].as_list())
- self.assertEqual([None], shapes[1].as_list())
-
- for no_padding in (True, False):
- _test_tuple_elements_by_padding(no_padding)
-
- def testBucketSparse(self):
- """Tests bucketing of sparse tensors (case where `no_padding` == True).
-
- Test runs on following dataset:
- [
- [0],
- [0, 1],
- [0, 1, 2]
- ...
- [0, ..., max_len - 1]
- ]
- Sequences are bucketed by length and batched with
- `batch_size` < `bucket_size`.
- """
-
- min_len = 0
- max_len = 100
- batch_size = 7
- bucket_size = 10
-
- def _build_dataset():
- input_data = [range(i+1) for i in range(min_len, max_len)]
- def generator_fn():
- for record in input_data:
- yield _format_record(record, sparse=True)
- dataset = dataset_ops.Dataset.from_generator(
- generator=generator_fn,
- output_types=_get_record_type(sparse=True))
- dataset = dataset.map(_to_sparse_tensor)
- return dataset
-
- def _compute_expected_batches():
- """Computes expected batch outputs and stores in a set."""
- all_expected_sparse_tensors = set()
- for bucket_start_len in range(min_len, max_len, bucket_size):
- for batch_offset in range(0, bucket_size, batch_size):
- batch_start_len = bucket_start_len + batch_offset
- batch_end_len = min(batch_start_len + batch_size,
- bucket_start_len + bucket_size)
- expected_indices = []
- expected_values = []
- for length in range(batch_start_len, batch_end_len):
- for val in range(length + 1):
- expected_indices.append((length - batch_start_len, val))
- expected_values.append(val)
- expected_sprs_tensor = (tuple(expected_indices),
- tuple(expected_values))
- all_expected_sparse_tensors.add(expected_sprs_tensor)
- return all_expected_sparse_tensors
-
- def _compute_batches(dataset):
- """Computes actual batch outputs of dataset and stores in a set."""
- batch = dataset.make_one_shot_iterator().get_next()
- all_sparse_tensors = set()
- with self.cached_session() as sess:
- with self.assertRaises(errors.OutOfRangeError):
- while True:
- output = sess.run(batch)
- sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
- tuple(output.values))
- all_sparse_tensors.add(sprs_tensor)
- return all_sparse_tensors
-
- dataset = _build_dataset()
- boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
- dataset = dataset.apply(grouping.bucket_by_sequence_length(
- _element_length_fn,
- boundaries,
- [batch_size] * (len(boundaries) + 1),
- no_padding=True))
- batches = _compute_batches(dataset)
- expected_batches = _compute_expected_batches()
- self.assertEqual(batches, expected_batches)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
index 7d7b842c17..adfacf1c9f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
@@ -12,440 +12,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for prefetching_ops."""
+"""Tests for `tf.data.experimental.copy_to_device()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
-from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
-
- def setUp(self):
- self._event = threading.Event()
-
- def _create_ds_and_iterator(self, device0, initializable=False):
-
- def gen():
- for i in range(1, 10):
- yield [float(i)]
- if i == 6:
- self._event.set()
-
- with ops.device(device0):
- ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
- if initializable:
- ds_iterator = ds.make_initializable_iterator()
- else:
- ds_iterator = ds.make_one_shot_iterator()
- return (ds, ds_iterator)
-
- def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.float32],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name=buffer_name)
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.float32])
- reset_op = prefetching_ops.function_buffering_resource_reset(
- function_buffer_resource=buffer_resource_handle)
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- return (prefetch_op, reset_op, destroy_op)
-
- def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
- prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
- device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testSameDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("same_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:0")
-
- def testDifferentDeviceCPU(self):
- self._prefetch_fn_helper_one_shot("diff_device_cpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/cpu:1")
-
- def testDifferentDeviceCPUGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- self._prefetch_fn_helper_one_shot("cpu_gpu",
- "/job:localhost/replica:0/task:0/cpu:0",
- "/job:localhost/replica:0/task:0/gpu:0")
-
- def testReinitialization(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- # Lets reset the function buffering resource and reinitialize the
- # iterator. Should be able to go through this again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [1.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [2.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [3.0])
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [4.0])
- self._event.wait()
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [5.0])
- sess.run(destroy_op)
-
- def testReinitializationOutOfRange(self):
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/cpu:1"
- ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
- prefetch_op, reset_op, destroy_op = self._create_ops(
- ds, ds_iterator, "reinit", device0, device1)
-
- with self.test_session(config=worker_config) as sess:
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
- # Try fetching after its over twice to test out end of sequence.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- # Now reset everything and try it out again.
- self._event.clear()
- sess.run(reset_op)
- sess.run(ds_iterator.initializer)
- for i in range(1, 10):
- elem = sess.run(prefetch_op)
- self.assertEqual(elem, [float(i)])
- # Try fetching after its over twice to test out end of sequence.
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
- def testStringsGPU(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- device0 = "/job:localhost/replica:0/task:0/cpu:0"
- device1 = "/job:localhost/replica:0/task:0/gpu:0"
-
- ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
- ds_iterator = ds.make_one_shot_iterator()
- ds_iterator_handle = ds_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _remote_fn(h):
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- h, ds.output_types, ds.output_shapes)
- return remote_iterator.get_next()
-
- target = constant_op.constant(device0)
- with ops.device(device1):
- buffer_resource_handle = prefetching_ops.function_buffering_resource(
- f=_remote_fn,
- output_types=[dtypes.string],
- target_device=target,
- string_arg=ds_iterator_handle,
- buffer_size=3,
- shared_name="strings")
-
- with ops.device(device1):
- prefetch_op = prefetching_ops.function_buffering_resource_get_next(
- function_buffer_resource=buffer_resource_handle,
- output_types=[dtypes.string])
- destroy_op = resource_variable_ops.destroy_resource_op(
- buffer_resource_handle, ignore_lookup_error=True)
-
- with self.cached_session() as sess:
- self.assertEqual([b"a"], sess.run(prefetch_op))
- self.assertEqual([b"b"], sess.run(prefetch_op))
- self.assertEqual([b"c"], sess.run(prefetch_op))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(prefetch_op)
-
- sess.run(destroy_op)
-
-
-class PrefetchToDeviceTest(test_base.DatasetTestBase):
-
- def testPrefetchToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToSameDevice(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device(
- "/job:localhost/replica:0/task:0/device:CPU:0"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchDictToDevice(self):
- host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element["a"].dtype)
- self.assertEqual([], next_element["a"].shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- self.assertEqual({"a": i}, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchSparseTensorsToDevice(self):
- def make_tensor(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
- host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
-
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_one_shot_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- for i in range(10):
- actual = sess.run(next_element)
- self.assertAllEqual([i], actual.values)
- self.assertAllEqual([[0, 0]], actual.indices)
- self.assertAllEqual([2, 2], actual.dense_shape)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpu(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceWithReInit(self):
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/cpu:1"))
-
- # NOTE(mrry): This device block creates the "host" dataset and iterator on
- # /cpu:0, and ensures that the prefetching is across devices. In typical use
- # this would not be necessary, because the GPU device would not support any
- # of the dataset-related ops.
- with ops.device("/cpu:0"):
- iterator = device_dataset.make_initializable_iterator()
-
- self.assertEqual(host_dataset.output_types, device_dataset.output_types)
- self.assertEqual(host_dataset.output_types, iterator.output_types)
- self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
- self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
- self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
- self.assertEqual(host_dataset.output_classes, iterator.output_classes)
-
- next_element = iterator.get_next()
- self.assertEqual(dtypes.int64, next_element.dtype)
- self.assertEqual([], next_element.shape)
-
- worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
- with self.test_session(config=worker_config) as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPrefetchToDeviceGpuWithReInit(self):
- if not test_util.is_gpu_available():
- self.skipTest("No GPU available")
-
- host_dataset = dataset_ops.Dataset.range(10)
- device_dataset = host_dataset.apply(
- prefetching_ops.prefetch_to_device("/gpu:0"))
-
- iterator = device_dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(iterator.initializer)
- for i in range(5):
- self.assertEqual(i, sess.run(next_element))
- sess.run(iterator.initializer)
- for i in range(10):
- self.assertEqual(i, sess.run(next_element))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
-
class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/counter_test.py b/tensorflow/python/data/experimental/kernel_tests/counter_test.py
new file mode 100644
index 0000000000..4e114ac479
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/counter_test.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.Counter`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
+class CounterTest(test_base.DatasetTestBase):
+
+ def testCounter(self):
+ """Test dataset construction using `count`."""
+ iterator = (counter.Counter(start=3, step=4)
+ .make_one_shot_iterator())
+ get_next = iterator.get_next()
+ self.assertEqual([], get_next.shape.as_list())
+ self.assertEqual(dtypes.int64, get_next.dtype)
+
+ negative_iterator = (counter.Counter(start=0, step=-1)
+ .make_one_shot_iterator())
+ negative_get_next = negative_iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertEqual(3, sess.run(get_next))
+ self.assertEqual(3 + 4, sess.run(get_next))
+ self.assertEqual(3 + 2 * 4, sess.run(get_next))
+
+ self.assertEqual(0, sess.run(negative_get_next))
+ self.assertEqual(-1, sess.run(negative_get_next))
+ self.assertEqual(-2, sess.run(negative_get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
index 4ee1779710..fb75be1fbc 100644
--- a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for CsvDatasetOp."""
+"""Tests for `tf.data.experimental.CsvDataset`."""
from __future__ import absolute_import
from __future__ import division
@@ -44,7 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test_base.DatasetTestBase):
+class CsvDatasetTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
deleted file mode 100644
index 7f435b8239..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
+++ /dev/null
@@ -1,692 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Base class for testing serializable datasets."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-import numpy as np
-
-from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import lookup_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import test
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.util import nest
-
-
-def remove_variants(get_next_op):
- # TODO(b/72408568): Remove this once session.run can get
- # variant tensors.
- """Remove variants from a nest structure, so sess.run will execute."""
-
- def _remove_variant(x):
- if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
- return ()
- else:
- return x
-
- return nest.map_structure(_remove_variant, get_next_op)
-
-
-class DatasetSerializationTestBase(test.TestCase):
- """Base class for testing serializable datasets."""
-
- def tearDown(self):
- self._delete_ckpt()
-
- # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
- # (deprecated) saveable `SparseTensorSliceDataset`, once the API
- # `from_sparse_tensor_slices()`and related tests are deleted.
- def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
- """Runs the core tests.
-
- Args:
- ds_fn1: 0-argument function that returns a Dataset.
- ds_fn2: 0-argument function that returns a Dataset different from
- ds_fn1. If None, verify_restore_in_modified_graph test is not run.
- num_outputs: Total number of outputs expected from this Dataset.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_unused_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_fully_used_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_exhausted_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_init_before_restore(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_multiple_breaks(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_reset_restored_iterator(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- self.verify_restore_in_empty_graph(
- ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
- if ds_fn2:
- self.verify_restore_in_modified_graph(
- ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_unused_iterator(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that saving and restoring an unused iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [0],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_fully_used_iterator(self, ds_fn, num_outputs,
- sparse_tensors=False):
- """Verifies that saving and restoring a fully used iterator works.
-
- Note that this only checks saving and restoring an iterator from which
- `num_outputs` items have been produced but does not check for an
- exhausted iterator, i.e., one from which an OutOfRange error has been
- returned.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if test fails.
- """
- self.verify_run_with_breaks(
- ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
-
- def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
- """Verifies that saving and restoring an exhausted iterator works.
-
- An exhausted iterator is one which has returned an OutOfRange error.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.gen_outputs(
- ds_fn, [],
- num_outputs,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- actual = self.gen_outputs(
- ds_fn, [],
- 0,
- ckpt_saved=True,
- verify_exhausted=True,
- sparse_tensors=sparse_tensors)
- self.assertEqual(len(actual), 0)
-
- def verify_init_before_restore(self,
- ds_fn,
- num_outputs,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that restoring into an already initialized iterator works.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs),
- num_outputs,
- init_before_restore=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_multiple_breaks(self,
- ds_fn,
- num_outputs,
- num_breaks=10,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to save/restore at multiple break points.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- num_breaks: The number of break points. These are uniformly spread in
- [0, num_outputs] both inclusive.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- self.verify_run_with_breaks(
- ds_fn,
- self.gen_break_points(num_outputs, num_breaks),
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- def verify_reset_restored_iterator(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to re-initialize a restored iterator.
-
- This is useful when restoring a training checkpoint during validation.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Collect ground truth containing all outputs.
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Skip some items and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Restore from checkpoint and then run init_op.
- with ops.Graph().as_default() as g:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- self._initialize(init_op, sess)
- for _ in range(num_outputs):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- self.match(expected, actual)
-
- def verify_restore_in_modified_graph(self,
- ds_fn1,
- ds_fn2,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in a modified graph.
-
- Builds an input pipeline using ds_fn1, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new graph using ds_fn2, restores
- the checkpoint from ds_fn1 and verifies that the restore is successful.
-
- Args:
- ds_fn1: See `run_core_tests`.
- ds_fn2: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn1
- # in `expected`.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn1, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn1 and save checkpoint.
- self.gen_outputs(
- ds_fn1, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build graph for ds_fn2 but load checkpoint for ds_fn1.
- with ops.Graph().as_default() as g:
- _, get_next_op, saver = self._build_graph(
- ds_fn2, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_restore_in_empty_graph(self,
- ds_fn,
- num_outputs,
- break_point=None,
- sparse_tensors=False,
- verify_exhausted=True):
- """Attempts to restore an iterator in an empty graph.
-
- Builds an input pipeline using ds_fn, runs it for `break_point` steps
- and saves a checkpoint. Then builds a new empty graph, restores
- the checkpoint from ds_fn and verifies that the restore is successful.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- break_point = num_outputs // 2 if not break_point else break_point
-
- # Skip `break_point` items and store the remaining produced from ds_fn
- # in `expected`.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs - break_point,
- ckpt_saved=True,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- # Generate `break_point` items from ds_fn and save checkpoint.
- self.gen_outputs(
- ds_fn, [],
- break_point,
- sparse_tensors=sparse_tensors,
- verify_exhausted=False)
-
- actual = []
- # Build an empty graph but load checkpoint for ds_fn.
- with ops.Graph().as_default() as g:
- get_next_op, saver = self._build_empty_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._restore(saver, sess)
- for _ in range(num_outputs - break_point):
- actual.append(sess.run(get_next_op))
- if verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
-
- self.match(expected, actual)
-
- def verify_error_on_save(self,
- ds_fn,
- num_outputs,
- error,
- break_point=None,
- sparse_tensors=False):
- """Attempts to save a non-saveable iterator.
-
- Args:
- ds_fn: See `run_core_tests`.
- num_outputs: See `run_core_tests`.
- error: Declared error when trying to save iterator.
- break_point: Break point. Optional. Defaults to num_outputs/2.
- sparse_tensors: See `run_core_tests`.
-
- Raises:
- AssertionError if any test fails.
- """
-
- break_point = num_outputs // 2 if not break_point else break_point
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- self._initialize(init_op, sess)
- for _ in range(break_point):
- sess.run(get_next_op)
- with self.assertRaises(error):
- self._save(sess, saver)
-
- def verify_run_with_breaks(self,
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True):
- """Verifies that ds_fn() produces the same outputs with and without breaks.
-
- 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- *without* stopping at break points.
- 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
- with stopping at break points.
-
- Deep matches outputs from 1 and 2.
-
- Args:
- ds_fn: See `gen_outputs`.
- break_points: See `gen_outputs`.
- num_outputs: See `gen_outputs`.
- init_before_restore: See `gen_outputs`.
- sparse_tensors: See `run_core_tests`.
- verify_exhausted: See `gen_outputs`.
-
- Raises:
- AssertionError if any test fails.
- """
- expected = self.gen_outputs(
- ds_fn, [],
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- actual = self.gen_outputs(
- ds_fn,
- break_points,
- num_outputs,
- init_before_restore=init_before_restore,
- sparse_tensors=sparse_tensors,
- verify_exhausted=verify_exhausted)
-
- self.match(expected, actual)
-
- def gen_outputs(self,
- ds_fn,
- break_points,
- num_outputs,
- ckpt_saved=False,
- init_before_restore=False,
- sparse_tensors=False,
- verify_exhausted=True,
- save_checkpoint_at_end=True):
- """Generates elements from input dataset while stopping at break points.
-
- Produces `num_outputs` outputs and saves the state of the iterator in the
- Saver checkpoint.
-
- Args:
- ds_fn: 0-argument function that returns the dataset.
- break_points: A list of integers. For each `break_point` in
- `break_points`, we produce outputs till `break_point` number of items
- have been produced and then checkpoint the state. The current graph
- and session are destroyed and a new graph and session are used to
- produce outputs till next checkpoint or till `num_outputs` elements
- have been produced. `break_point` must be <= `num_outputs`.
- num_outputs: The total number of outputs to produce from the iterator.
- ckpt_saved: Whether a checkpoint already exists. If False, we build the
- graph from ds_fn.
- init_before_restore: Whether init should be called before saver.restore.
- This is just so that we can verify that restoring an already initialized
- iterator works.
- sparse_tensors: Whether dataset is built from SparseTensor(s).
- verify_exhausted: Whether to verify that the iterator has been exhausted
- after producing `num_outputs` elements.
- save_checkpoint_at_end: Whether to save a checkpoint after producing all
- outputs. If False, checkpoints are saved each break point but not at the
- end. Note that checkpoints overwrite each other so there is always only
- a single checkpoint available. Defaults to True.
-
- Returns:
- A list of `num_outputs` items.
- """
- outputs = []
-
- def get_ops():
- if ckpt_saved:
- saver = self._import_meta_graph()
- init_op, get_next_op = self._get_iterator_ops_from_collection(
- ds_fn, sparse_tensors=sparse_tensors)
- else:
- init_op, get_next_op, saver = self._build_graph(
- ds_fn, sparse_tensors=sparse_tensors)
- return init_op, get_next_op, saver
-
- for i in range(len(break_points) + 1):
- with ops.Graph().as_default() as g:
- init_op, get_next_op, saver = get_ops()
- get_next_op = remove_variants(get_next_op)
- with self.session(graph=g) as sess:
- if ckpt_saved:
- if init_before_restore:
- self._initialize(init_op, sess)
- self._restore(saver, sess)
- else:
- self._initialize(init_op, sess)
- start = break_points[i - 1] if i > 0 else 0
- end = break_points[i] if i < len(break_points) else num_outputs
- num_iters = end - start
- for _ in range(num_iters):
- outputs.append(sess.run(get_next_op))
- if i == len(break_points) and verify_exhausted:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next_op)
- if save_checkpoint_at_end or i < len(break_points):
- self._save(sess, saver)
- ckpt_saved = True
-
- return outputs
-
- def match(self, expected, actual):
- """Matches nested structures.
-
- Recursively matches shape and values of `expected` and `actual`.
- Handles scalars, numpy arrays and other python sequence containers
- e.g. list, dict.
-
- Args:
- expected: Nested structure 1.
- actual: Nested structure 2.
-
- Raises:
- AssertionError if matching fails.
- """
- if isinstance(expected, np.ndarray):
- expected = expected.tolist()
- if isinstance(actual, np.ndarray):
- actual = actual.tolist()
- self.assertEqual(type(expected), type(actual))
-
- if nest.is_sequence(expected):
- self.assertEqual(len(expected), len(actual))
- if isinstance(expected, dict):
- for key1, key2 in zip(sorted(expected), sorted(actual)):
- self.assertEqual(key1, key2)
- self.match(expected[key1], actual[key2])
- else:
- for item1, item2 in zip(expected, actual):
- self.match(item1, item2)
- else:
- self.assertEqual(expected, actual)
-
- def does_not_match(self, expected, actual):
- with self.assertRaises(AssertionError):
- self.match(expected, actual)
-
- def gen_break_points(self, num_outputs, num_samples=10):
- """Generates `num_samples` breaks points in [0, num_outputs]."""
- return np.linspace(0, num_outputs, num_samples, dtype=int)
-
- def _build_graph(self, ds_fn, sparse_tensors=False):
- iterator = ds_fn().make_initializable_iterator()
-
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- init_op = iterator.initializer
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
- sparse_tensors)
- saver = saver_lib.Saver(allow_empty=True)
- return init_op, get_next, saver
-
- def _build_empty_graph(self, ds_fn, sparse_tensors=False):
- iterator = iterator_ops.Iterator.from_structure(
- self._get_output_types(ds_fn),
- output_shapes=self._get_output_shapes(ds_fn),
- output_classes=self._get_output_classes(ds_fn))
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- if sparse_tensors:
- get_next = sparse_tensor.SparseTensor(*iterator.get_next())
- else:
- get_next = iterator.get_next()
- saver = saver_lib.Saver(allow_empty=True)
- return get_next, saver
-
- def _add_iterator_ops_to_collection(self,
- init_op,
- get_next,
- ds_fn,
- sparse_tensors=False):
- ops.add_to_collection("iterator_ops", init_op)
- # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
- # do not support tuples we flatten the tensors and restore the shape in
- # `_get_iterator_ops_from_collection`.
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- ops.add_to_collection("iterator_ops", get_next.indices)
- ops.add_to_collection("iterator_ops", get_next.values)
- ops.add_to_collection("iterator_ops", get_next.dense_shape)
- return
-
- get_next_list = nest.flatten(get_next)
- for i, output_class in enumerate(
- nest.flatten(self._get_output_classes(ds_fn))):
- if output_class is sparse_tensor.SparseTensor:
- ops.add_to_collection("iterator_ops", get_next_list[i].indices)
- ops.add_to_collection("iterator_ops", get_next_list[i].values)
- ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
- else:
- ops.add_to_collection("iterator_ops", get_next_list[i])
-
- def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
- all_ops = ops.get_collection("iterator_ops")
- if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
- init_op, indices, values, dense_shape = all_ops
- return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
- get_next_list = []
- i = 1
- for output_class in nest.flatten(self._get_output_classes(ds_fn)):
- if output_class is sparse_tensor.SparseTensor:
- indices, values, dense_shape = all_ops[i:i + 3]
- i += 3
- get_next_list.append(
- sparse_tensor.SparseTensor(indices, values, dense_shape))
- else:
- get_next_list.append(all_ops[i])
- i += 1
- return all_ops[0], nest.pack_sequence_as(
- self._get_output_types(ds_fn), get_next_list)
-
- def _get_output_types(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_types
-
- def _get_output_shapes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_shapes
-
- def _get_output_classes(self, ds_fn):
- with ops.Graph().as_default():
- return ds_fn().output_classes
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def _latest_ckpt(self):
- return checkpoint_management.latest_checkpoint(self.get_temp_dir())
-
- def _save(self, sess, saver):
- saver.save(sess, self._ckpt_path())
-
- def _restore(self, saver, sess):
- sess.run(lookup_ops.tables_initializer())
- saver.restore(sess, self._latest_ckpt())
-
- def _initialize(self, init_op, sess):
- sess.run(variables.global_variables_initializer())
- sess.run(lookup_ops.tables_initializer())
- sess.run(init_op)
-
- def _import_meta_graph(self):
- meta_file_path = self._ckpt_path() + ".meta"
- return saver_lib.import_meta_graph(meta_file_path)
-
- def _delete_ckpt(self):
- # Remove all checkpoint files.
- prefix = self._ckpt_path()
- pattern = prefix + "*"
- files = gfile.Glob(pattern)
- map(gfile.Remove, files)
diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
new file mode 100644
index 0000000000..73be6cbcca
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.dense_to_sparse_batch()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DenseToSparseBatchTest(test_base.DatasetTestBase):
+
+ def testDenseToSparseBatchDataset(self):
+ components = np.random.randint(12, size=(100,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)], results.indices)
+ self.assertAllEqual(
+ [c for c in components[start:start + 4] for _ in range(c)],
+ results.values)
+ self.assertAllEqual([min(4,
+ len(components) - start), 12],
+ results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithUnknownShape(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).apply(
+ batching.dense_to_sparse_batch(
+ 4, [5, None])).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j, z]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)
+ for z in range(c)], results.indices)
+ self.assertAllEqual([
+ c
+ for c in components[start:start + 4] for _ in range(c)
+ for _ in range(c)
+ ], results.values)
+ self.assertAllEqual([
+ min(4,
+ len(components) - start), 5,
+ np.max(components[start:start + 4])
+ ], results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithInvalidShape(self):
+ input_tensor = array_ops.constant([[1]])
+ with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
+
+ def testDenseToSparseBatchDatasetShapeErrors(self):
+ input_tensor = array_ops.placeholder(dtypes.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Initialize with an input tensor of incompatible rank.
+ sess.run(init_op, feed_dict={input_tensor: [[1]]})
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "incompatible with the row shape"):
+ sess.run(get_next)
+
+ # Initialize with an input tensor that is larger than `row_shape`.
+ sess.run(init_op, feed_dict={input_tensor: range(13)})
+ with self.assertRaisesRegexp(errors.DataLossError,
+ "larger than the row shape"):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
index 22412c3965..e54235d9f8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Test RangeDataset."""
+"""Tests for `tf.data.experimental.enumerate_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.data.experimental.ops import enumerate_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -28,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test_base.DatasetTestBase):
+class EnumerateDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
@@ -52,27 +51,6 @@ class RangeDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testCounter(self):
- """Test dataset construction using `count`."""
- iterator = (counter.Counter(start=3, step=4)
- .make_one_shot_iterator())
- get_next = iterator.get_next()
- self.assertEqual([], get_next.shape.as_list())
- self.assertEqual(dtypes.int64, get_next.dtype)
-
- negative_iterator = (counter.Counter(start=0, step=-1)
- .make_one_shot_iterator())
- negative_get_next = negative_iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertEqual(3, sess.run(get_next))
- self.assertEqual(3 + 4, sess.run(get_next))
- self.assertEqual(3 + 2 * 4, sess.run(get_next))
-
- self.assertEqual(0, sess.run(negative_get_next))
- self.assertEqual(-1, sess.run(negative_get_next))
- self.assertEqual(-2, sess.run(negative_get_next))
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py b/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py
new file mode 100644
index 0000000000..399fd284f4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py
@@ -0,0 +1,247 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the private `FunctionBufferingResource` used in prefetching."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+class FunctionBufferingResourceTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ self._event = threading.Event()
+
+ def _create_ds_and_iterator(self, device0, initializable=False):
+
+ def gen():
+ for i in range(1, 10):
+ yield [float(i)]
+ if i == 6:
+ self._event.set()
+
+ with ops.device(device0):
+ ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
+ if initializable:
+ ds_iterator = ds.make_initializable_iterator()
+ else:
+ ds_iterator = ds.make_one_shot_iterator()
+ return (ds, ds_iterator)
+
+ def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.float32],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name=buffer_name)
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.float32])
+ reset_op = prefetching_ops.function_buffering_resource_reset(
+ function_buffer_resource=buffer_resource_handle)
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ return (prefetch_op, reset_op, destroy_op)
+
+ def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
+ prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
+ device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testSameDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("same_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:0")
+
+ def testDifferentDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("diff_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:1")
+
+ def testDifferentDeviceCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ self._prefetch_fn_helper_one_shot("cpu_gpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/gpu:0")
+
+ def testReinitialization(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ # Lets reset the function buffering resource and reinitialize the
+ # iterator. Should be able to go through this again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testReinitializationOutOfRange(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+ # Try fetching after its over twice to test out end of sequence.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ # Now reset everything and try it out again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+ # Try fetching after its over twice to test out end of sequence.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+ def testStringsGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/gpu:0"
+
+ ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
+ ds_iterator = ds.make_one_shot_iterator()
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.string],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name="strings")
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.string])
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ with self.cached_session() as sess:
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
new file mode 100644
index 0000000000..9030328593
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.group_by_reducer()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerTest(test_base.DatasetTestBase):
+
+ def checkResults(self, dataset, shapes, values):
+ self.assertEqual(shapes, dataset.output_shapes)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ for expected in values:
+ got = sess.run(get_next)
+ self.assertEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSum(self):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(lambda x: x % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testAverage(self):
+
+ def reduce_fn(x, y):
+ return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+ x[1] + 1), x[1] + 1
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: (0.0, 0.0),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, _: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(
+ lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+ def testConcat(self):
+ components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+ reducer = grouping.Reducer(
+ init_func=lambda x: "",
+ reduce_func=lambda x, y: x + y[0],
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(components),
+ dataset_ops.Dataset.range(2 * i))).apply(
+ grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+ self.checkResults(
+ dataset,
+ shapes=tensor_shape.scalar(),
+ values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
+
+ def testSparseSum(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1], dtype=np.int64)),
+ dense_shape=np.array([1, 1]))
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: _sparse(np.int64(0)),
+ reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+ finalize_func=lambda x: x.values[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+ grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testChangingStateShape(self):
+
+ def reduce_fn(x, _):
+ # Statically known rank, but dynamic length.
+ larger_dim = array_ops.concat([x[0], x[0]], 0)
+ # Statically unknown rank.
+ larger_rank = array_ops.expand_dims(x[1], 0)
+ return larger_dim, larger_rank
+
+ reducer = grouping.Reducer(
+ init_func=lambda x: ([0], 1),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, y: (x, y))
+
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+ grouping.group_by_reducer(lambda x: x, reducer))
+ self.assertEqual([None], dataset.output_shapes[0].as_list())
+ self.assertIs(None, dataset.output_shapes[1].ndims)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual([0] * (2**i), x)
+ self.assertAllEqual(np.array(1, ndmin=i), y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testTypeMismatch(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+ reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+ # TODO(b/78665031): Remove once non-scalar keys are supported.
+ def testInvalidKeyShape(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+ # TODO(b/78665031): Remove once non-int64 keys are supported.
+ def testInvalidKeyType(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+ def testTuple(self):
+ def init_fn(_):
+ return np.array([], dtype=np.int64), np.int64(0)
+
+ def reduce_fn(state, value):
+ s1, s2 = state
+ v1, v2 = value
+ return array_ops.concat([s1, [v1]], 0), s2 + v2
+
+ def finalize_fn(s1, s2):
+ return s1, s2
+
+ reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
+ grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual(x, np.asarray([x for x in range(10)]))
+ self.assertEqual(y, 45)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
new file mode 100644
index 0000000000..557d56e8b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
@@ -0,0 +1,367 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.group_by_window()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
+# Currently, they use a constant batch size, though should be made to use a
+# different batch size per key.
+class GroupByWindowTest(test_base.DatasetTestBase):
+
+ def _dynamicPad(self, bucket, window, window_size):
+ # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
+ # generic form of padded_batch that pads every component
+ # dynamically and does not rely on static shape information about
+ # the arguments.
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
+ [None]), tensor_shape.TensorShape([3])))))
+
+ def testSingleBucket(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: 0,
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ which_bucket, bucketed_values = sess.run(get_next)
+
+ self.assertEqual(0, which_bucket)
+
+ expected_scalar_int = np.arange(32, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
+ for i in range(32):
+ expected_unk_int64[i, :i] = i
+ expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values[2])
+
+ def testEvenOddBuckets(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches (one containing even values, one containing odds)
+ which_bucket_even, bucketed_values_even = sess.run(get_next)
+ which_bucket_odd, bucketed_values_odd = sess.run(get_next)
+
+ # Count number of bucket_tensors.
+ self.assertEqual(3, len(bucketed_values_even))
+ self.assertEqual(3, len(bucketed_values_odd))
+
+ # Ensure bucket 0 was used for all minibatch entries.
+ self.assertAllEqual(0, which_bucket_even)
+ self.assertAllEqual(1, which_bucket_odd)
+
+ # Test the first bucket outputted, the events starting at 0
+ expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i] = 2 * i
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
+
+ # Test the second bucket outputted, the odds starting at 1
+ expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
+
+ def testEvenOddBucketsFilterOutAllOdd(self):
+
+ def _map_fn(v):
+ return {
+ "x": v,
+ "y": array_ops.fill([v], v),
+ "z": array_ops.fill([3], string_ops.as_string(v))
+ }
+
+ def _dynamic_pad_fn(bucket, window, _):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, {
+ "x": tensor_shape.TensorShape([]),
+ "y": tensor_shape.TensorShape([None]),
+ "z": tensor_shape.TensorShape([3])
+ })))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
+ .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
+ lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches ([0, 2, ...] and [64, 66, ...])
+ which_bucket0, bucketed_values_even0 = sess.run(get_next)
+ which_bucket1, bucketed_values_even1 = sess.run(get_next)
+
+ # Ensure that bucket 1 was completely filtered out
+ self.assertAllEqual(0, which_bucket0)
+ self.assertAllEqual(0, which_bucket1)
+ self.assertAllEqual(
+ np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
+ self.assertAllEqual(
+ np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
+
+ def testDynamicWindowSize(self):
+ components = np.arange(100).astype(np.int64)
+
+ # Key fn: even/odd
+ # Reduce fn: batches of 5
+ # Window size fn: even=5, odd=10
+
+ def window_size_func(key):
+ window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
+ return window_sizes[key]
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
+ None, window_size_func))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ batches = 0
+ while True:
+ result = sess.run(get_next)
+ is_even = all(x % 2 == 0 for x in result)
+ is_odd = all(x % 2 == 1 for x in result)
+ self.assertTrue(is_even or is_odd)
+ expected_batch_size = 5 if is_even else 10
+ self.assertEqual(expected_batch_size, result.shape[0])
+ batches += 1
+
+ self.assertEqual(batches, 15)
+
+ def testSimple(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
+ .apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ result = sess.run(get_next)
+ self.assertTrue(
+ all(x % 2 == 0
+ for x in result) or all(x % 2 == 1)
+ for x in result)
+ counts.append(result.shape[0])
+
+ self.assertEqual(len(components), sum(counts))
+ num_full_batches = len([c for c in counts if c == 4])
+ self.assertGreaterEqual(num_full_batches, 24)
+ self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
+
+ def testImmediateOutput(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # The input is infinite, so this test demonstrates that:
+ # 1. We produce output without having to consume the entire input,
+ # 2. Different buckets can produce output at different rates, and
+ # 3. For deterministic input, the output is deterministic.
+ for _ in range(3):
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+
+ def testSmallGroups(self):
+ components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ # The small outputs at the end are deterministically produced in key
+ # order.
+ self.assertAllEqual([0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1], sess.run(get_next))
+
+ def testEmpty(self):
+ iterator = (
+ dataset_ops.Dataset.range(4).apply(
+ grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Window size must be greater than zero, but got 0."):
+ print(sess.run(get_next))
+
+ def testReduceFuncError(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+
+ def reduce_func(_, xs):
+ # Introduce an incorrect padded shape that cannot (currently) be
+ # detected at graph construction time.
+ return xs.padded_batch(
+ 4,
+ padded_shapes=(tensor_shape.TensorShape([]),
+ constant_op.constant([5], dtype=dtypes.int64) * -1))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
+ grouping.group_by_window(lambda x, _: x % 2, reduce_func,
+ 32)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def testConsumeWindowDatasetMoreThanOnce(self):
+ components = np.random.randint(50, size=(200,)).astype(np.int64)
+
+ def reduce_func(key, window):
+ # Apply two different kinds of padding to the input: tight
+ # padding, and quantized (to a multiple of 10) padding.
+ return dataset_ops.Dataset.zip((
+ window.padded_batch(
+ 4, padded_shapes=tensor_shape.TensorShape([None])),
+ window.padded_batch(
+ 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
+ ))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
+ .apply(grouping.group_by_window(
+ lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
+ reduce_func, 4))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ tight_result, multiple_of_10_result = sess.run(get_next)
+ self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
+ self.assertAllEqual(tight_result,
+ multiple_of_10_result[:, :tight_result.shape[1]])
+ counts.append(tight_result.shape[0])
+ self.assertEqual(len(components), sum(counts))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
new file mode 100644
index 0000000000..c0ec1486ab
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.ignore_errors()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+_NUMPY_RANDOM_SEED = 42
+
+
+class IgnoreErrorsTest(test_base.DatasetTestBase):
+
+ def testMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testParallelMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message"),
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testReadFileIgnoreError(self):
+
+ def write_string_to_file(value, filename):
+ with open(filename, "w") as f:
+ f.write(value)
+
+ filenames = [
+ os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
+ ]
+ for filename in filenames:
+ write_string_to_file(filename, filename)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(filenames).map(
+ io_ops.read_file,
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # All of the files are present.
+ sess.run(init_op)
+ for filename in filenames:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Delete one of the files.
+ os.remove(filenames[0])
+
+ # Attempting to read filenames[0] will fail, but ignore_errors()
+ # will catch the error.
+ sess.run(init_op)
+ for filename in filenames[1:]:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
new file mode 100644
index 0000000000..5ee94e14dc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
@@ -0,0 +1,239 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+
+
+class MakeBatchedFeaturesDatasetTest(
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 0.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 1.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[1],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, num_epochs=num_epochs)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testReadWithEquivalentDataset(self):
+ features = {
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ }
+ dataset = (
+ core_readers.TFRecordDataset(self.test_filenames)
+ .map(lambda x: parsing_ops.parse_single_example(x, features))
+ .repeat(10).batch(2))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
+ range(self._num_files), 2, 10):
+ actual_batch = sess.run(next_element)
+ self.assertAllEqual(file_batch, actual_batch["file"])
+ self.assertAllEqual(record_batch, actual_batch["record"])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testReadWithFusedShuffleRepeatDataset(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ for batch_size in [1, 2]:
+ # Test that shuffling with same seed produces the same result.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ # Test that shuffling with different seeds produces a different order.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=15).make_one_shot_iterator().get_next()
+ all_equal = True
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testParallelReadersAndParsers(self):
+ num_epochs = 5
+ for batch_size in [1, 2]:
+ for reader_num_threads in [2, 4]:
+ for parser_num_threads in [2, 4]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default():
+ # Basic test: read from file 0.
+ outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ drop_final_batch=True).make_one_shot_iterator().get_next()
+ for tensor in nest.flatten(outputs):
+ if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
+ self.assertEqual(tensor.shape[0], batch_size)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
+ for shape, clazz in zip(nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_classes)):
+ if issubclass(clazz, ops.Tensor):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
index a02f4bd14f..e4bf089184 100644
--- a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.make_csv_dataset()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -23,226 +23,16 @@ import zlib
import numpy as np
-from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.kernel_tests import test_base
-from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
-from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class ReadBatchFeaturesTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 0.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 0,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from file 1.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames[1],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- 1,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- # Basic test: read from both files.
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size).make_one_shot_iterator().get_next()
- self.verify_records(sess, batch_size, num_epochs=num_epochs)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testReadWithEquivalentDataset(self):
- features = {
- "file": parsing_ops.FixedLenFeature([], dtypes.int64),
- "record": parsing_ops.FixedLenFeature([], dtypes.int64),
- }
- dataset = (
- core_readers.TFRecordDataset(self.test_filenames)
- .map(lambda x: parsing_ops.parse_single_example(x, features))
- .repeat(10).batch(2))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
- range(self._num_files), 2, 10):
- actual_batch = sess.run(next_element)
- self.assertAllEqual(file_batch, actual_batch["file"])
- self.assertAllEqual(record_batch, actual_batch["record"])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testReadWithFusedShuffleRepeatDataset(self):
- num_epochs = 5
- total_records = num_epochs * self._num_records
- for batch_size in [1, 2]:
- # Test that shuffling with same seed produces the same result.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- self.assertAllEqual(batch1[i], batch2[i])
-
- # Test that shuffling with different seeds produces a different order.
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs1 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=5).make_one_shot_iterator().get_next()
- outputs2 = self.make_batch_feature(
- filenames=self.test_filenames[0],
- num_epochs=num_epochs,
- batch_size=batch_size,
- shuffle=True,
- shuffle_seed=15).make_one_shot_iterator().get_next()
- all_equal = True
- for _ in range(total_records // batch_size):
- batch1 = self._run_actual_batch(outputs1, sess)
- batch2 = self._run_actual_batch(outputs2, sess)
- for i in range(len(batch1)):
- all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
- self.assertFalse(all_equal)
-
- def testParallelReadersAndParsers(self):
- num_epochs = 5
- for batch_size in [1, 2]:
- for reader_num_threads in [2, 4]:
- for parser_num_threads in [2, 4]:
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- label_key_provided=True,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess, label_key_provided=True)
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- self.outputs = self.make_batch_feature(
- filenames=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- reader_num_threads=reader_num_threads,
- parser_num_threads=parser_num_threads).make_one_shot_iterator(
- ).get_next()
- self.verify_records(
- sess,
- batch_size,
- num_epochs=num_epochs,
- interleave_cycle_length=reader_num_threads)
- with self.assertRaises(errors.OutOfRangeError):
- self._next_actual_batch(sess)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 10]:
- with ops.Graph().as_default():
- # Basic test: read from file 0.
- outputs = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=num_epochs,
- batch_size=batch_size,
- drop_final_batch=True).make_one_shot_iterator().get_next()
- for tensor in nest.flatten(outputs):
- if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
- self.assertEqual(tensor.shape[0], batch_size)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = self.make_batch_feature(
- filenames=self.test_filenames[0],
- label_key="label",
- num_epochs=None,
- batch_size=32)
- for shape, clazz in zip(nest.flatten(dataset.output_shapes),
- nest.flatten(dataset.output_classes)):
- if issubclass(clazz, ops.Tensor):
- self.assertEqual(32, shape[0])
-
-
class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
@@ -866,218 +656,5 @@ class MakeCsvDatasetTest(test_base.DatasetTestBase):
self.assertEqual(32, shape[0])
-class MakeTFRecordDatasetTest(
- reader_dataset_ops_test_base.TFRecordDatasetTestBase):
-
- def _interleave(self, iterators, cycle_length):
- pending_iterators = iterators
- open_iterators = []
- num_open = 0
- for i in range(cycle_length):
- if pending_iterators:
- open_iterators.append(pending_iterators.pop(0))
- num_open += 1
-
- while num_open:
- for i in range(min(cycle_length, len(open_iterators))):
- if open_iterators[i] is None:
- continue
- try:
- yield next(open_iterators[i])
- except StopIteration:
- if pending_iterators:
- open_iterators[i] = pending_iterators.pop(0)
- else:
- open_iterators[i] = None
- num_open -= 1
-
- def _next_expected_batch(self,
- file_indices,
- batch_size,
- num_epochs,
- cycle_length,
- drop_final_batch,
- use_parser_fn):
-
- def _next_record(file_indices):
- for j in file_indices:
- for i in range(self._num_records):
- yield j, i
-
- def _next_record_interleaved(file_indices, cycle_length):
- return self._interleave([_next_record([i]) for i in file_indices],
- cycle_length)
-
- record_batch = []
- batch_index = 0
- for _ in range(num_epochs):
- if cycle_length == 1:
- next_records = _next_record(file_indices)
- else:
- next_records = _next_record_interleaved(file_indices, cycle_length)
- for f, r in next_records:
- record = self._record(f, r)
- if use_parser_fn:
- record = record[1:]
- record_batch.append(record)
- batch_index += 1
- if len(record_batch) == batch_size:
- yield record_batch
- record_batch = []
- batch_index = 0
- if record_batch and not drop_final_batch:
- yield record_batch
-
- def _verify_records(self,
- sess,
- outputs,
- batch_size,
- file_index,
- num_epochs,
- interleave_cycle_length,
- drop_final_batch,
- use_parser_fn):
- if file_index is not None:
- file_indices = [file_index]
- else:
- file_indices = range(self._num_files)
-
- for expected_batch in self._next_expected_batch(
- file_indices, batch_size, num_epochs, interleave_cycle_length,
- drop_final_batch, use_parser_fn):
- actual_batch = sess.run(outputs)
- self.assertAllEqual(expected_batch, actual_batch)
-
- def _read_test(self, batch_size, num_epochs, file_index=None,
- num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
- if file_index is None:
- file_pattern = self.test_filenames
- else:
- file_pattern = self.test_filenames[file_index]
-
- if parser_fn:
- fn = lambda x: string_ops.substr(x, 1, 999)
- else:
- fn = None
-
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- outputs = readers.make_tf_record_dataset(
- file_pattern=file_pattern,
- num_epochs=num_epochs,
- batch_size=batch_size,
- parser_fn=fn,
- num_parallel_reads=num_parallel_reads,
- drop_final_batch=drop_final_batch,
- shuffle=False).make_one_shot_iterator().get_next()
- self._verify_records(
- sess, outputs, batch_size, file_index, num_epochs=num_epochs,
- interleave_cycle_length=num_parallel_reads,
- drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(outputs)
-
- def testRead(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- # Basic test: read from file 0.
- self._read_test(batch_size, num_epochs, 0)
-
- # Basic test: read from file 1.
- self._read_test(batch_size, num_epochs, 1)
-
- # Basic test: read from both files.
- self._read_test(batch_size, num_epochs)
-
- # Basic test: read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8)
-
- def testDropFinalBatch(self):
- for batch_size in [1, 2, 10]:
- for num_epochs in [1, 3]:
- # Read from file 0.
- self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
-
- # Read from both files.
- self._read_test(batch_size, num_epochs, drop_final_batch=True)
-
- # Read from both files, with parallel reads.
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- drop_final_batch=True)
-
- def testParserFn(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for drop_final_batch in [False, True]:
- self._read_test(batch_size, num_epochs, parser_fn=True,
- drop_final_batch=drop_final_batch)
- self._read_test(batch_size, num_epochs, num_parallel_reads=8,
- parser_fn=True, drop_final_batch=drop_final_batch)
-
- def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
- seed=None):
- with ops.Graph().as_default() as g:
- with self.session(graph=g) as sess:
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames,
- num_epochs=num_epochs,
- batch_size=batch_size,
- num_parallel_reads=num_parallel_reads,
- shuffle=True,
- shuffle_seed=seed)
- iterator = dataset.make_initializable_iterator()
- next_element = iterator.get_next()
-
- sess.run(iterator.initializer)
- first_batches = []
- try:
- while True:
- first_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- sess.run(iterator.initializer)
- second_batches = []
- try:
- while True:
- second_batches.append(sess.run(next_element))
- except errors.OutOfRangeError:
- pass
-
- self.assertEqual(len(first_batches), len(second_batches))
- if seed is not None:
- # if you set a seed, should get the same results
- for i in range(len(first_batches)):
- self.assertAllEqual(first_batches[i], second_batches[i])
-
- expected = []
- for f in range(self._num_files):
- for r in range(self._num_records):
- expected.extend([self._record(f, r)] * num_epochs)
-
- for batches in (first_batches, second_batches):
- actual = []
- for b in batches:
- actual.extend(b)
- self.assertAllEqual(sorted(expected), sorted(actual))
-
- def testShuffle(self):
- for batch_size in [1, 2]:
- for num_epochs in [1, 3]:
- for num_parallel_reads in [1, 2]:
- # Test that all expected elements are produced
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
- # Test that elements are produced in a consistent order if
- # you specify a seed.
- self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
- seed=21345)
-
- def testIndefiniteRepeatShapeInference(self):
- dataset = readers.make_tf_record_dataset(
- file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
- for shape in nest.flatten(dataset.output_shapes):
- self.assertEqual(32, shape[0])
-
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
new file mode 100644
index 0000000000..657cf3c00e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_tf_record_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class MakeTFRecordDatasetTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase):
+
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length,
+ drop_final_batch,
+ use_parser_fn):
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ record_batch = []
+ batch_index = 0
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for f, r in next_records:
+ record = self._record(f, r)
+ if use_parser_fn:
+ record = record[1:]
+ record_batch.append(record)
+ batch_index += 1
+ if len(record_batch) == batch_size:
+ yield record_batch
+ record_batch = []
+ batch_index = 0
+ if record_batch and not drop_final_batch:
+ yield record_batch
+
+ def _verify_records(self,
+ sess,
+ outputs,
+ batch_size,
+ file_index,
+ num_epochs,
+ interleave_cycle_length,
+ drop_final_batch,
+ use_parser_fn):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length,
+ drop_final_batch, use_parser_fn):
+ actual_batch = sess.run(outputs)
+ self.assertAllEqual(expected_batch, actual_batch)
+
+ def _read_test(self, batch_size, num_epochs, file_index=None,
+ num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
+ if file_index is None:
+ file_pattern = self.test_filenames
+ else:
+ file_pattern = self.test_filenames[file_index]
+
+ if parser_fn:
+ fn = lambda x: string_ops.substr(x, 1, 999)
+ else:
+ fn = None
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs = readers.make_tf_record_dataset(
+ file_pattern=file_pattern,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ parser_fn=fn,
+ num_parallel_reads=num_parallel_reads,
+ drop_final_batch=drop_final_batch,
+ shuffle=False).make_one_shot_iterator().get_next()
+ self._verify_records(
+ sess, outputs, batch_size, file_index, num_epochs=num_epochs,
+ interleave_cycle_length=num_parallel_reads,
+ drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(outputs)
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ # Basic test: read from file 0.
+ self._read_test(batch_size, num_epochs, 0)
+
+ # Basic test: read from file 1.
+ self._read_test(batch_size, num_epochs, 1)
+
+ # Basic test: read from both files.
+ self._read_test(batch_size, num_epochs)
+
+ # Basic test: read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2, 10]:
+ for num_epochs in [1, 3]:
+ # Read from file 0.
+ self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
+
+ # Read from both files.
+ self._read_test(batch_size, num_epochs, drop_final_batch=True)
+
+ # Read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ drop_final_batch=True)
+
+ def testParserFn(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for drop_final_batch in [False, True]:
+ self._read_test(batch_size, num_epochs, parser_fn=True,
+ drop_final_batch=drop_final_batch)
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ parser_fn=True, drop_final_batch=drop_final_batch)
+
+ def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
+ seed=None):
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ num_parallel_reads=num_parallel_reads,
+ shuffle=True,
+ shuffle_seed=seed)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ sess.run(iterator.initializer)
+ first_batches = []
+ try:
+ while True:
+ first_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ sess.run(iterator.initializer)
+ second_batches = []
+ try:
+ while True:
+ second_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ self.assertEqual(len(first_batches), len(second_batches))
+ if seed is not None:
+ # if you set a seed, should get the same results
+ for i in range(len(first_batches)):
+ self.assertAllEqual(first_batches[i], second_batches[i])
+
+ expected = []
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ expected.extend([self._record(f, r)] * num_epochs)
+
+ for batches in (first_batches, second_batches):
+ actual = []
+ for b in batches:
+ actual.extend(b)
+ self.assertAllEqual(sorted(expected), sorted(actual))
+
+ def testShuffle(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for num_parallel_reads in [1, 2]:
+ # Test that all expected elements are produced
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
+ # Test that elements are produced in a consistent order if
+ # you specify a seed.
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
+ seed=21345)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
new file mode 100644
index 0000000000..d444c4082e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.map_and_batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
+ )
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
+ """Test a dataset that maps a TF function across its input elements."""
+ # The pipeline is TensorSliceDataset ->
+ # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ num_parallel_batches=num_parallel_batches))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ # Batch of a finite input, where the batch_size divides the
+ # total number of elements.
+ sess.run(init_op, feed_dict={count: 28, batch_size: 14})
+ num_batches = (28 * 7) // 14
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(14):
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of a finite input, where the batch_size does not
+ # divide the total number of elements.
+ sess.run(init_op, feed_dict={count: 14, batch_size: 8})
+
+ # We expect (num_batches - 1) full-sized batches.
+ num_batches = int(math.ceil((14 * 7) / 8))
+ for i in range(num_batches - 1):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(8):
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
+ result_component[j])
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range((14 * 7) % 8):
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of an empty input should fail straight away.
+ sess.run(init_op, feed_dict={count: 0, batch_size: 8})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Empty batch should be an initialization time error.
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(init_op, feed_dict={count: 14, batch_size: 0})
+
+ @parameterized.named_parameters(
+ ("Even", False),
+ ("Uneven", True),
+ )
+ def testMapAndBatchPartialBatch(self, drop_remainder):
+ iterator = (
+ dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]),
+ batch_size=4,
+ drop_remainder=drop_remainder)).make_one_shot_iterator())
+ if drop_remainder:
+ self.assertEqual([4, 1], iterator.output_shapes.as_list())
+ else:
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ if not drop_remainder:
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchYieldsPartialBatch(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .apply(batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]), 4))
+ .make_one_shot_iterator())
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchParallelGetNext(self):
+ iterator = (dataset_ops.Dataset.range(50000)
+ .apply(batching.map_and_batch(lambda x: x, batch_size=100))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(5):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchParallelGetNextDropRemainder(self):
+ iterator = (
+ dataset_ops.Dataset.range(49999).apply(
+ batching.map_and_batch(
+ lambda x: x, batch_size=100, drop_remainder=True))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(4):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for i in range(2):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testMapAndBatchFails(self):
+ """Test a dataset that maps a TF function across its input elements."""
+ dataset = dataset_ops.Dataset.from_tensors(
+ array_ops.check_numerics(
+ constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+ sess.run(init_op, feed_dict={batch_size: 14})
+
+ def testMapAndBatchShapeMismatch(self):
+ """Test a dataset that maps a TF function across its input elements."""
+
+ def generator():
+ yield [1]
+ yield [2]
+ yield [3]
+ yield [[4, 5, 6]]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, output_types=dtypes.int32)
+ batch_size = 4
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "number of elements does not match"):
+ sess.run(get_next)
+
+ def testMapAndBatchImplicitDispose(self):
+ # Tests whether a map and batch dataset will be cleaned up correctly when
+ # the pipeline does not run it until exhaustion.
+ # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
+ # MapAndBatchDataset(f=square_3, batch_size=100).
+ components = (np.arange(1000),
+ np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+ np.array(37.0) * np.arange(1000))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
+ 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
+ dataset = dataset.prefetch(5)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(3):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
+ def testMapAndBatchOutOfRangeError(self, threshold):
+
+ def raising_py_fn(i):
+ if i >= threshold:
+ raise StopIteration()
+ else:
+ return i
+
+ iterator = (
+ dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10)).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(threshold // 10):
+ self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
+ if threshold % 10 != 0:
+ self.assertAllEqual(
+ [threshold // 10 * 10 + j for j in range(threshold % 10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
+ )
+ def testMapAndBatchTypes(self, element, dtype):
+ def gen():
+ yield element
+
+ dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
+ batching.map_and_batch(lambda x: x, batch_size=10))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(10):
+ self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Identity", None, lambda x: x, None),
+ ("Replicate", None, lambda x: (x, x), None),
+ ("Swap", (None, None), lambda x, y: (y, x), None),
+ ("Project", (None, None), lambda x, y: x, None),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(
+ *sess.run(self.structuredElement(structure, shape=[10])))
+ else:
+ expected = map_fn(
+ sess.run(self.structuredElement(structure, shape=[10])))
+ self.assertAllEqual(expected, sess.run(get_next))
+
+ def testShortCircuitCapturedInput(self):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().apply(
+ batching.map_and_batch(lambda x: captured_t, batch_size=10))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertAllEqual([42] * 10, sess.run(get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 612ee332c4..ae9dedb0ab 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -235,6 +235,18 @@ class MapDefunTest(test_base.DatasetTestBase):
sess.close()
thread.join()
+ def testMapDefunWithCapturedInputs(self):
+ c = constant_op.constant(2)
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x + c
+
+ x = constant_op.constant([1, 2, 3, 4])
+ map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]
+ expected = x + c
+ self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
+
class MapDefunBenchmark(test.Benchmark):
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 32ebc49c40..971a2d94b9 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -78,6 +78,7 @@ class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Basic", lambda x: (x, x + 1), None),
+ ("Const", lambda x: 2, 12),
("Parallel", lambda x: (x, x + 1), 12),
("Gather", lambda x: array_ops.gather(x, 0), 12),
)
@@ -207,6 +208,9 @@ class MapVectorizationBenchmark(test.Benchmark):
def benchmarkAddConst(self):
self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+ def benchmarkReturnConst(self):
+ self._benchmark_helper(lambda *args: [constant_op.constant(2)], "ret_const")
+
def benchmarkSelect(self):
self._benchmark_helper(lambda *args: args[0], "select")
diff --git a/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
index 4432dcb05a..5e419a9b2f 100644
--- a/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline statistics gathering ops."""
+"""Tests for the private `override_threadpool()` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -32,8 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
- parameterized.TestCase):
+class OverrideThreadpoolTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
index 560902caad..90ac250df7 100644
--- a/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.parallel_interleave()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -37,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
+class ParallelInterleaveTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py
index 13f924b656..723e709ae8 100644
--- a/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.ops.parsing_ops."""
+"""Tests for `tf.data.experimental.parse_example_dataset()."""
from __future__ import absolute_import
from __future__ import division
@@ -73,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test_base.DatasetTestBase):
+class ParseExampleDatasetTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
new file mode 100644
index 0000000000..f73725366c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.prefetch_to_device()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
+
+ def testPrefetchToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchSparseTensorsToDevice(self):
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
index b6ab80d132..77df8310d4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -63,11 +63,11 @@ class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
return filenames
-class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
- """Base class for setting up and testing `make_batched_feature_dataset`."""
+class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing `make_batched_features_dataset`."""
def setUp(self):
- super(ReadBatchFeaturesTestBase, self).setUp()
+ super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
self._num_files = 2
self._num_records = 7
self.test_filenames = self._createFiles()
diff --git a/tensorflow/python/data/experimental/kernel_tests/resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
index 775648c943..4c879dbae6 100644
--- a/tensorflow/python/data/experimental/kernel_tests/resample_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.rejection_resample()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -58,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
+class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
index 3fc7157bc5..516e489d04 100644
--- a/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for the private `_RestructuredDataset` transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -26,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test_base.DatasetTestBase):
+class RestructuredDatasetTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
index 78ec80de23..0730455431 100644
--- a/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.scan()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -34,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test_base.DatasetTestBase):
+class ScanTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index 58a335ae4f..e556b65b7c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -70,6 +70,26 @@ py_test(
)
py_test(
+ name = "checkpoint_input_pipeline_hook_test",
+ size = "small",
+ srcs = ["checkpoint_input_pipeline_hook_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
name = "concatenate_dataset_serialization_test",
size = "small",
srcs = ["concatenate_dataset_serialization_test.py"],
@@ -580,7 +600,7 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
- "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
+ "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_test_base",
"//tensorflow/python/data/experimental/ops:readers",
],
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
index 94393d6d4b..94393d6d4b 100644
--- a/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
index a0dd6960b0..b3dfe21486 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -23,7 +23,7 @@ from tensorflow.python.platform import test
class ParseExampleDatasetSerializationTest(
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase,
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
dataset_serialization_test_base.DatasetSerializationTestBase):
def ParseExampleDataset(self, num_repeat, batch_size):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
index b179770ce3..006279bbe1 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import dtypes
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class SqlDatasetSerializationTest(
- sql_dataset_op_test_base.SqlDatasetTestBase,
+ sql_dataset_test_base.SqlDatasetTestBase,
dataset_serialization_test_base.DatasetSerializationTestBase):
def _build_dataset(self, num_repeats):
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
deleted file mode 100644
index 88d5c896c9..0000000000
--- a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Integration test for dataset serialization."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-
-from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
-from tensorflow.python.training import saver as saver_lib
-
-
-class SerializationIntegrationTest(test.TestCase):
-
- def _build_input_pipeline(self, name, num_outputs):
- with ops.name_scope(name):
- ds = dataset_ops.Dataset.range(num_outputs).shuffle(
- 10, reshuffle_each_iteration=False).prefetch(10)
- iterator = ds.make_initializable_iterator()
- saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
- return iterator.initializer, iterator.get_next()
-
- def _build_graph(self, num_pipelines, num_outputs):
- init_ops = []
- get_next_ops = []
- for i in range(num_pipelines):
- name = "input_pipeline_%d" % i
- init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
- init_ops.append(init_op)
- get_next_ops.append(get_next_op)
- saver = saver_lib.Saver()
- return init_ops, get_next_ops, saver
-
- def _ckpt_path(self):
- return os.path.join(self.get_temp_dir(), "iterator")
-
- def testConcurrentSaves(self):
- num_pipelines = 100
- num_outputs = 100
- break_point = 10
- all_outputs = [[] for _ in range(num_pipelines)]
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- sess.run(init_ops)
- for _ in range(break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
- saver.save(sess, self._ckpt_path())
-
- with ops.Graph().as_default() as g:
- init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
- num_outputs)
- with self.session(graph=g) as sess:
- saver.restore(sess, self._ckpt_path())
- for _ in range(num_outputs - break_point):
- output = sess.run(get_next_ops)
- for i in range(num_pipelines):
- all_outputs[i].append(output[i])
-
- for output in all_outputs:
- self.assertSequenceEqual(sorted(output), range(num_outputs))
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
index 50895b5945..c208963a86 100644
--- a/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.shuffle_and_repeat()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
index 301f75488a..a2c1169638 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
@@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for experimental sql input op."""
+"""Tests for `tf.data.experimental.SqlDataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class SqlDatasetTest(sql_dataset_op_test_base.SqlDatasetTestBase):
+class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
# Test that SqlDataset can read from a database table.
def testReadResultSet(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py
index a135c357f0..6aaaa90c65 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Base class for testing SqlDataset."""
-
+"""Base class for testing `tf.data.experimental.SqlDataset`."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index 19f5a62d45..427654cd76 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -280,7 +280,7 @@ class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
- reader_dataset_ops_test_base.ReadBatchFeaturesTestBase):
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
def testFeaturesStats(self):
num_epochs = 5
diff --git a/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
index 25a2e63ba1..8fd0ad50c4 100644
--- a/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.TFRecordWriter`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
new file mode 100644
index 0000000000..0278a208cb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
@@ -0,0 +1,300 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.unbatch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testUnbatchWithUnknownRankInput(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
+ batching.unbatch())
+ iterator = dataset.make_initializable_iterator()
+ next_elem = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
+ for i in range(4):
+ self.assertEqual(i, sess.run(next_elem))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_elem)
+
+ def testUnbatchScalarDataset(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = (dtypes.int32,) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i,) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithStrings(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
+ expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors(st)
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ st_row = sess.run(next_element)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchDatasetWithDenseAndSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ dense_elem, st_row = sess.run(next_element)
+ self.assertEqual(i, dense_elem)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchSingleElementTupleDataset(self):
+ data = tuple([(math_ops.range(10),) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32,),) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i,),) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchMultiElementTupleDataset(self):
+ data = tuple([(math_ops.range(10 * i, 10 * i + 10),
+ array_ops.fill([10], "hi")) for i in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32, dtypes.string),) * 3
+ data = data.batch(2)
+ self.assertAllEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertAllEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
+ sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchEmpty(self):
+ data = dataset_ops.Dataset.from_tensors(
+ (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
+ constant_op.constant([], shape=[0, 4, 0])))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchStaticShapeMismatch(self):
+ data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
+ np.arange(9)))
+ with self.assertRaises(ValueError):
+ data.apply(batching.unbatch())
+
+ def testUnbatchDynamicShapeMismatch(self):
+ ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
+ ph2 = array_ops.placeholder(dtypes.int32, shape=None)
+ data = dataset_ops.Dataset.from_tensors((ph1, ph2))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Mismatch in the 0th dimension.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: np.arange(8).astype(np.int32)
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+ # No 0th dimension (i.e. scalar value) for one component.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: 7
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+
+class UnbatchBenchmark(test.Benchmark):
+
+ def benchmarkNativeUnbatch(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.apply(batching.unbatch())
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (native) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_native_batch_size_%d" %
+ batch_size)
+
+ # Include a benchmark of the previous `unbatch()` implementation that uses
+ # a composition of more primitive ops. Eventually we'd hope to generate code
+ # that is as good in both cases.
+ def benchmarkOldUnbatchImplementation(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (unfused) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
+ batch_size)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
index b5a0b20f3f..847cff26b0 100644
--- a/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.experimental.unique()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -26,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test_base.DatasetTestBase):
+class UniqueTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
index 3d0d0993c9..3ac1158d8b 100644
--- a/tensorflow/python/data/experimental/ops/map_defun.py
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
@@ -47,10 +47,12 @@ def map_defun(fn, elems, output_dtypes, output_shapes):
if not isinstance(elems, list):
raise ValueError("`elems` must be a list of tensors.")
if not isinstance(output_dtypes, list):
- raise ValueError("`output_dtypes` must be a list of tensors.")
+ raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
if not isinstance(output_shapes, list):
- raise ValueError("`output_shapes` must be a list of tensors.")
+ raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
+ "objects.")
elems = [ops.convert_to_tensor(e) for e in elems]
output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
- return gen_dataset_ops.map_defun(elems, output_dtypes, output_shapes, fn)
+ return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes,
+ output_shapes, fn)
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 6b7afafa5d..a0c6b37a6d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -156,7 +156,7 @@ class FilterDatasetTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testReturnComponent(self):
+ def testShortCircuit(self):
iterator = (
dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10),
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index 230ae3f3fd..4683b1db91 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.map()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -267,6 +267,35 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testCaptureIterator(self):
+
+ def _build_ds(iterator):
+
+ def _map_fn(x):
+ get_next = iterator.get_next()
+ return x * get_next
+
+ return dataset_ops.Dataset.range(10).map(_map_fn)
+
+ def _build_graph():
+ captured_iterator = dataset_ops.Dataset.range(
+ 10).make_initializable_iterator()
+ ds = _build_ds(captured_iterator)
+ iterator = ds.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return captured_iterator.initializer, init_op, get_next
+
+ with ops.Graph().as_default() as g:
+ captured_init_op, init_op, get_next = _build_graph()
+ with self.session(graph=g) as sess:
+ sess.run(captured_init_op)
+ sess.run(init_op)
+ for i in range(10):
+ self.assertEqual(i * i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testCaptureHashTable(self):
# NOTE(mrry): We must use the V2 variants of `HashTable`
# etc. because these produce a `tf.resource`-typed output that is
@@ -593,7 +622,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -620,7 +649,7 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -754,19 +783,72 @@ class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
+ @parameterized.named_parameters(
+ ("SequentialIdentity", None, lambda x: x, None),
+ ("SequentialReplicate", None, lambda x: (x, x), None),
+ ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
+ ("SequentialProject", (None, None), lambda x, y: x, None),
+ ("ParallelIdentity", None, lambda x: x, 10),
+ ("ParallelReplicate", None, lambda x: (x, x), 10),
+ ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
+ ("ParallelProject", (None, None), lambda x, y: x, 10),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().map(
+ map_fn, num_parallel_calls=num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(*sess.run(self.structuredElement(structure)))
+ else:
+ expected = map_fn(sess.run(self.structuredElement(structure)))
+ self.assertEqual(expected, sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Sequential", None),
+ ("Parallel", 10),
+ )
+ def testShortCircuitCapturedInput(self, num_parallel_calls):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().map(
+ lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertEqual(42, sess.run(get_next))
+
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda x: x
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
for _ in range(chain_length):
dataset = dataset_ops.MapDataset(
dataset,
- lambda x: x,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -784,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset chain length%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", chain_length, median_wall_time))
+ (print_label, chain_length, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ (chain_length, benchmark_label))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda *xs: xs
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(
tuple(0 for _ in range(fan_out))).repeat(None)
dataset = dataset_ops.MapDataset(
dataset,
- lambda *xs: xs,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -820,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset fan out%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", fan_out, median_wall_time))
+ (print_label, fan_out, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" %
- (fan_out, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
+ benchmark_label))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b730e10949..b73a94e683 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -19,10 +19,13 @@ from __future__ import print_function
import re
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -107,3 +110,29 @@ class DatasetTestBase(test.TestCase):
with self.assertRaisesRegexp(exception_class,
re.escape(expected_message)):
self.evaluate(next2())
+
+ def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns a singleton dataset with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self.structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns an element with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self.structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d0c1a93118..cae809a7c3 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -251,6 +251,7 @@ py_library(
"//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:execute",
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 3fe79ef244..2b0118c07f 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -353,7 +353,7 @@ class MicroBenchmarks(test.Benchmark):
num_iters,
execution_mode=None):
f = function.defun(math_ops.matmul)
- func = lambda: f(m, m, transpose_b)
+ func = lambda: f(m, m, transpose_b=transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
def _benchmark_defun_matmul_forward_backward(self,
@@ -366,7 +366,7 @@ class MicroBenchmarks(test.Benchmark):
def func():
with backprop.GradientTape() as gt:
gt.watch(m)
- y = f(m, m, transpose_b)
+ y = f(m, m, transpose_b=transpose_b)
_ = gt.gradient(y, m)
self._run(func, num_iters, execution_mode=execution_mode)
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index fb5442b646..e601aa376f 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -631,6 +631,34 @@ class TFETest(test_util.TensorFlowTestCase):
for t in tensors:
self.assertIsInstance(t, ops.EagerTensor)
+ def testSmallIntegerOpsForcedToCPU(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+
+ a = constant_op.constant((1, 2, 3, 4, 5), dtype=dtypes.int64)
+ b = constant_op.constant((2, 3, 4, 5, 6), dtype=dtypes.int64)
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op forced to CPU since all constants are integers and small.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
+ a = array_ops.zeros((8, 10), dtype=dtypes.int64)
+ b = array_ops.ones((8, 10), dtype=dtypes.int64)
+
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op not forced to CPU since the tensors are larger than 64 elements.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:GPU:0')
+
+ a = constant_op.constant((1, 2, 3, 4, 5), dtype=dtypes.float32)
+ b = constant_op.constant((2, 3, 4, 5, 6), dtype=dtypes.float32)
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op not forced to CPU since the constants are not integers.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:GPU:0')
class SendRecvTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dd9f5e233c..99bf375ea7 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -31,6 +31,7 @@ import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
+from tensorflow.python import autograph
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
@@ -45,6 +46,7 @@ from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
@@ -80,49 +82,10 @@ def _create_substitute_placeholder(value, name=None, dtype=None):
with ops.control_dependencies(None):
placeholder = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- _copy_handle_data(value, placeholder)
+ custom_gradient.copy_handle_data(value, placeholder)
return placeholder
-def _copy_handle_data(source_t, target_t):
- """Copies HandleData for variant and resource type tensors if available.
-
- The CppShapeInferenceResult::HandleData proto contains information about the
- shapes and types of the element tensors of resource/variant type tensors.
- We need to copy this across function boundaries, i.e., when capturing a
- placeholder or when returning a function tensor as output. If we don't do this
- the element tensors will have unknown shapes, e.g., if a TensorList variant
- tensor is captured as a placeholder, elements popped from that list would have
- unknown shape.
-
- Args:
- source_t: The tensor to copy HandleData from.
- target_t: The tensor to copy HandleData to.
- """
- if (target_t.dtype == dtypes_module.resource or
- target_t.dtype == dtypes_module.variant):
- if isinstance(source_t, ops.EagerTensor):
- handle_data = source_t._handle_data # pylint: disable=protected-access
- else:
- handle_data = resource_variable_ops.get_resource_handle_data(source_t)
- if handle_data is not None and handle_data.is_set:
- # pylint: disable=protected-access
- pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
- target_t._as_tf_output(),
- handle_data.SerializeToString())
- # pylint: enable=protected-access
- # Ensure that shapes and dtypes are propagated.
- shapes, types = zip(*[(pair.shape, pair.dtype)
- for pair in handle_data.shape_and_type])
- ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
- shapes = [[d.size for d in s.dim]
- if not s.unknown_rank else None for s in shapes]
- pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
- target_t._op._graph._c_graph, # pylint: disable=protected-access
- target_t._as_tf_output(), # pylint: disable=protected-access
- shapes, ranks, types)
-
-
def _get_device_functions(ctx, graph):
"""Returns a tuple of device functions representing the device stack."""
if ctx.executing_eagerly():
@@ -269,15 +232,6 @@ class FuncGraph(ops.Graph):
def variables(self, var_list):
self._weak_variables = [weakref.ref(v) for v in var_list]
- def control_dependencies(self, control_inputs):
- # Drop control dependencies to outside of the graph. TODO(b/117109273)
- # unclear how to capture an op, not a tensor.
- if not control_inputs:
- return super(FuncGraph, self).control_dependencies(control_inputs)
- return super(FuncGraph, self).control_dependencies(
- [c for c in control_inputs
- if getattr(c, "graph", None) is self])
-
def create_op(
self,
op_type,
@@ -503,6 +457,9 @@ class _EagerDefinedFunction(object):
Returns:
The outputs of the function call.
+
+ Raises:
+ ValueError: if the number of arguments is incorrect.
"""
executing_eagerly = ctx.executing_eagerly()
@@ -536,6 +493,10 @@ class _EagerDefinedFunction(object):
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
# branch if a TPU kernel is registered for `PartitionedCall`.
+ if len(args) != len(self.signature.input_arg):
+ raise ValueError(
+ "Arguments and signature arguments do not match: %s %s " %
+ (len(args), len(list(self.signature.input_arg))))
outputs = functional_ops.partitioned_call(
args=args,
f=self,
@@ -548,7 +509,7 @@ class _EagerDefinedFunction(object):
for i, shape in enumerate(self._output_shapes):
outputs[i].set_shape(shape)
for i, func_graph_output in enumerate(self._func_graph_outputs):
- _copy_handle_data(func_graph_output, outputs[i])
+ custom_gradient.copy_handle_data(func_graph_output, outputs[i])
return outputs
@@ -659,7 +620,48 @@ class Function(object):
if tape.should_record(tensor_inputs) or tape.should_record(captures):
return self._backprop_call(args)
- outputs = self._inference_function.call(ctx, args)
+ # Only need to override the gradient in graph mode and when we have outputs.
+ if context.executing_eagerly() or not self.outputs:
+ outputs = self._inference_function.call(ctx, args)
+ else:
+ name = "PartitionedCall-%s" % ops.uid()
+
+ @ops.RegisterGradient(name)
+ def grad_fn(op, *doutputs): # pylint: disable=unused-variable
+ """Gradients of this function."""
+ if op.graph is not ops.get_default_graph():
+ # TODO(apassos) this will still emit SymbolicGradient ops when
+ # nested defuns are being differentiated. We need to somehow figure
+ # out a way to update the FunctionDef corresponding to the calling
+ # function when mutating a call to the forward pass.
+ return gradients_impl._SymGrad(op, list(doutputs)) # pylint: disable=protected-access
+ if self._backward_graph_function is None:
+ self._construct_backprop_function()
+ self._forward_function.add_to_graph(op.graph)
+ func = attr_value_pb2.AttrValue(
+ func=attr_value_pb2.NameAttrList(
+ name=self._forward_function.name))
+ # pylint: disable=protected-access
+ op._set_attr("f", func)
+ types = attr_value_pb2.AttrValue.ListValue(
+ type=self._forward_function._output_types)
+ op._set_attr("Tout", attr_value_pb2.AttrValue(list=types))
+ for i in range(
+ len(outputs), len(self._forward_function._output_types)):
+ t = ops.Tensor(op, i, self._forward_function._output_types[i])
+ t.set_shape(self._forward_function._output_shapes[i])
+ func_graph_output = self._forward_function._func_graph_outputs[i]
+ custom_gradient.copy_handle_data(func_graph_output, t)
+ op._outputs.append(t)
+ # pylint: enable=protected-access
+ side_outputs = op.outputs[len(outputs):]
+ return self._backward_graph_function(
+ *(list(doutputs) + list(side_outputs)))
+
+ with ops.get_default_graph().gradient_override_map(
+ {"PartitionedCall": name}):
+ outputs = self._inference_function.call(ctx, args)
+
return self._build_call_outputs(outputs)
@property
@@ -756,7 +758,6 @@ class Function(object):
BACKWARD_FUNCTION_ATTRIBUTE_NAME:
self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
forward_function_attr.update(self._attrs)
-
self._forward_function = _EagerDefinedFunction(
forward_function_name, self._func_graph, self._func_graph.inputs,
self._func_graph.outputs + backwards_graph_captures,
@@ -857,20 +858,12 @@ class Function(object):
return ret
-def _get_defun_inputs_from_signature(signature):
- """Maps a signature to graph-construction inputs."""
- function_inputs = [
- graph_placeholder(spec.dtype, spec.shape)
- for spec in nest.flatten(signature)
- ]
- return nest.pack_sequence_as(signature, function_inputs)
-
-
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
graph_placeholder(arg.dtype, arg.shape)
- if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
+ if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+ else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
@@ -880,7 +873,8 @@ def func_graph_from_py_func(name,
args,
kwargs,
signature=None,
- func_graph=None):
+ func_graph=None,
+ experimental_autograph=False):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -897,6 +891,8 @@ def func_graph_from_py_func(name,
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
this graph else a new one is built and returned.
+ experimental_autograph: whether to use autograph to compile `python_func`.
+ See https://www.tensorflow.org/guide/autograph for more information.
Returns:
A FuncGraph.
@@ -911,12 +907,12 @@ def func_graph_from_py_func(name,
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
- if signature is None:
- func_args = _get_defun_inputs_from_args(args)
- func_kwargs = _get_defun_inputs_from_args(kwargs)
- else:
- func_args = _get_defun_inputs_from_signature(signature)
- func_kwargs = {}
+ if signature is not None:
+ args = signature
+ kwargs = {}
+
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
@@ -942,7 +938,17 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwargs)
+ if experimental_autograph:
+ func_outputs = autograph.converted_call(
+ python_func,
+ autograph.ConversionOptions(
+ verbose=True,
+ recursive=True,
+ force_conversion=False,
+ strip_decorators=(defun,),
+ arg_types={}), *func_args, **func_kwargs)
+ else:
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -1038,7 +1044,8 @@ class PolymorphicFunction(object):
python_function,
name,
input_signature=None,
- attributes=None):
+ attributes=None,
+ experimental_autograph=False):
"""Initializes a polymorphic function.
Args:
@@ -1048,7 +1055,10 @@ class PolymorphicFunction(object):
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
attributes: dict, extra keyword arguments that will be added as attribute
- of the function.
+ of the function.
+ experimental_autograph: whether to use autograph to compile
+ `python_function`. See https://www.tensorflow.org/guide/autograph for
+ more information.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -1064,6 +1074,7 @@ class PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwargs_to_include = {}
self._name = name
+ self._experimental_autograph = experimental_autograph
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1289,8 +1300,13 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
- func_graph_from_py_func(self._name, self._python_function, args,
- kwargs, self._input_signature),
+ func_graph_from_py_func(
+ self._name,
+ self._python_function,
+ args,
+ kwargs,
+ self._input_signature,
+ experimental_autograph=self._experimental_autograph),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
@@ -1351,7 +1367,7 @@ def _validate_signature(signature):
"a possibly nested sequence of TensorSpec objects.")
-def defun(func=None, input_signature=None):
+def defun(func=None, input_signature=None, experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1660,6 +1676,10 @@ def defun(func=None, input_signature=None):
function is instantiated for each inferred input signature. If a
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
+ experimental_autograph: Whether `func` should be compiled before
+ constructing the graph. See https://www.tensorflow.org/guide/autograph
+ for more information.
+
Returns:
If `func` is not None, returns a callable that will execute the compiled
@@ -1671,10 +1691,16 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
- return defun_with_attributes(func=func, input_signature=input_signature)
+ return defun_with_attributes(
+ func=func,
+ input_signature=input_signature,
+ experimental_autograph=experimental_autograph)
-def defun_with_attributes(func=None, input_signature=None, attributes=None):
+def defun_with_attributes(func=None,
+ input_signature=None,
+ attributes=None,
+ experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
This function supports adding extra function attributes. See detailed
@@ -1689,6 +1715,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
attributes. Currently only support primitive types as value, and only
whitelisted attribute name is allowed. Unwhitelisted attribute name or
unsupported value will result into ValueError.
+ experimental_autograph: same as defun()'s experimental_autograph.
Returns:
Same as the return value of defun, with attributes added to the function in
@@ -1705,8 +1732,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature,
- attributes=attributes))
+ PolymorphicFunction(
+ function,
+ name,
+ input_signature=input_signature,
+ attributes=attributes,
+ experimental_autograph=experimental_autograph))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1909,8 +1940,10 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue
+ found_resource = False
for inp in op.inputs:
if inp.dtype == dtypes_module.resource:
+ found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
@@ -1925,6 +1958,11 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
+ if (op.op_def.is_stateful and not found_resource
+ and op._control_flow_context is None): # pylint: disable=protected-access
+ if None in last_op_using_resource_tensor:
+ op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
+ last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 9ce367a837..e46bde098b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -172,6 +172,43 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def testInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(a):
+ return matmul(a, a)
+
+ sq_op = sq.get_concrete_function(
+ tensor_spec.TensorSpec((None, None), dtypes.float32))
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out1 = sq_op(t1)
+ self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
+
+ t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out2 = sq_op(t2)
+ self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
+
+ def testNestedInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(mats):
+ ((a, b),) = mats
+ return matmul(a, b)
+
+ sq_op = sq.get_concrete_function(
+ [(tensor_spec.TensorSpec((None, None), dtypes.float32),
+ tensor_spec.TensorSpec((None, None), dtypes.float32))])
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
+ out = sq_op(t1, t2) # Flattened structure for inputs to the graph function
+ self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
+
def testExecutingStatelessDefunConcurrently(self):
@function.defun
@@ -249,7 +286,23 @@ class FunctionTest(test.TestCase):
c = constant_op.constant([[2.]])
f_c = f(c)
g, = gradients_impl.gradients(f_c, c)
- self.assertAllEqual(sess.run(g), [[1.0]])
+ self.assertAllEqual(sess.run(g).values, [[1.0]])
+
+ def testNoSymGradNestedDefun(self):
+
+ @function.defun
+ def outer():
+
+ @function.defun
+ def f(x):
+ return array_ops.gather_nd(x, [[0]])
+
+ c = constant_op.constant([[2.]])
+ f_c = f(c)
+ g, = gradients_impl.gradients(f_c, c)
+ self.assertTrue(isinstance(g, ops.IndexedSlices))
+
+ outer()
def testNestedInputsGraphFunction(self):
matmul = function.defun(math_ops.matmul)
@@ -1255,6 +1308,44 @@ class FunctionTest(test.TestCase):
defined(Foo())
self.assertEqual(len(defined._function_cache), 2)
+ def testCacheTensorShapeDtypeCollision(self):
+
+ def func(t):
+ return t + t
+
+ defined = function.defun(func)
+ t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 1)
+
+ t = constant_op.constant([1.0], dtype=dtypes.complex128)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 2)
+
+ def testCacheTensorUnknownShapesCollision(self):
+
+ def func(t):
+ return t + t
+
+ with context.graph_mode(), self.cached_session():
+ defined = function.defun(func)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 1)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 2)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 3)
+
+ t = constant_op.constant(1.0, dtype=dtypes.float32)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 4)
+
def testPythonFunctionWithDefaultArgs(self):
def func(foo, bar=1, baz=2):
@@ -1271,17 +1362,17 @@ class FunctionTest(test.TestCase):
return tuple(key[0] for key in defined._function_cache)
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn(('tRRR', (0, 1, 20)), cache_keys())
+ self.assertIn(('URRR', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn(('tRRR', (1, 1, 2)), cache_keys())
+ self.assertIn(('URRR', (1, 1, 2)), cache_keys())
# This matches the previous call.
defined(foo=1)
self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn(('tRRR', (1, 2, 3)), cache_keys())
+ self.assertIn(('URRR', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f5af4ab6c..5c35860e9d 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -51,11 +51,6 @@ def imperative_grad(
Raises:
RuntimeError: if something goes wrong.
- ValueError: if there is no sequence of differentiable operations connecting
- a source and any target Tensor. This can happen either if the target is
- not computed based on the source, if the tracing was set up incorrectly,
- or if only non-differentiable functions of the source were used in the
- computation of target.
"""
return pywrap_tensorflow.TFE_Py_TapeGradient(
tape._tape, # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index ae1e12f9c3..9789dbadee 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -1228,8 +1228,9 @@ static PyTypeObject TFE_Py_Tape_Type = {
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
// stack.
-static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr;
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
+ thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{
+ nullptr};
if (tape_set == nullptr) {
tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
}
@@ -1264,27 +1265,10 @@ class SafeTapeSet {
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
};
-// xcode 7 doesn't define thread_local, so for compatibility we implement our
-// own. TODO(apassos) remove once we can deprecate xcode 7.
-#ifndef __APPLE__
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
}
-#else
-static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
-bool* ThreadTapeIsStopped() {
- if (tape_is_stopped == nullptr) {
- tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
- }
- auto it = tape_is_stopped->find(std::this_thread::get_id());
- if (it != tape_is_stopped->end()) {
- return &(it->second);
- }
- return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
- .first->second);
-}
-#endif
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
@@ -1852,6 +1836,8 @@ bool OpGradientDoesntRequireOutputIndices(
{"SoftplusGrad", {true, {}}},
{"Softsign", {true, {}}},
{"ReluGrad", {true, {}}},
+ {"LeakyRelu", {true, {}}},
+ {"LeakyReluGrad", {true, {}}},
{"Conv2D", {true, {}}},
{"DepthwiseConv2dNative", {true, {}}},
{"Dilation2D", {true, {}}},
@@ -2747,11 +2733,15 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
}
namespace {
-
-tensorflow::int64 GetPyNoneHash() {
- tensorflow::int64 py_none_hash = PyObject_Hash(Py_None);
- return py_none_hash;
-}
+const char kTensor[] = "T";
+const char kIndexedSlices[] = "I";
+const char kList[] = "L";
+const char kTuple[] = "U";
+const char kDict[] = "D";
+const char kRaw[] = "R";
+const char kShape[] = "s";
+const char kDType[] = "d";
+const char kNone[] = "n";
struct EncodeResult {
string str;
@@ -2784,8 +2774,10 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
TFE_TensorHandle* t = EagerTensor_Handle(arg);
tensorflow::TensorShape tensor_shape;
TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
- absl::StrAppend(&result->str, t->handle->dtype);
+ absl::StrAppend(&result->str, kDType, t->handle->dtype);
+
+ absl::StrAppend(&result->str, kShape);
for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
absl::StrAppend(&result->str, dim_size);
}
@@ -2812,7 +2804,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
tensorflow::DataType dtype =
static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
- absl::StrAppend(&result->str, dtype);
+ absl::StrAppend(&result->str, kDType, dtype);
static char _shape_tuple[] = "_shape_tuple";
tensorflow::Safe_PyObjectPtr shape_tuple(
PyObject_CallMethod(arg, _shape_tuple, nullptr));
@@ -2824,10 +2816,11 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
if (shape_tuple.get() == Py_None) {
// Unknown shape, encode that directly.
- absl::StrAppend(&result->str, GetPyNoneHash());
+ absl::StrAppend(&result->str, kNone);
return tensorflow::Status::OK();
}
+ absl::StrAppend(&result->str, kShape);
tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
shape_tuple.get(), "shape_tuple didn't return a sequence"));
@@ -2835,7 +2828,7 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
if (item == Py_None) {
- absl::StrAppend(&result->str, GetPyNoneHash());
+ absl::StrAppend(&result->str, kNone);
} else {
absl::StrAppend(&result->str, MakeInt(item));
}
@@ -2844,13 +2837,6 @@ tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
return tensorflow::Status::OK();
}
-const char kTensor[] = "T";
-const char kIndexedSlices[] = "I";
-const char kList[] = "L";
-const char kTuple[] = "t";
-const char kDict[] = "D";
-const char kRaw[] = "R";
-
tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
// This function doesn't set the type of sequence before
@@ -2864,7 +2850,7 @@ tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
for (int i = 0; i < len; ++i) {
PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
if (item == Py_None) {
- absl::StrAppend(&result->str, GetPyNoneHash());
+ absl::StrAppend(&result->str, kNone);
} else {
TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
}
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 5352796174..28a8286544 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -2660,6 +2660,7 @@ class _EmbeddingColumn(
inputs=inputs,
weight_collections=weight_collections,
trainable=trainable)
+
sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
sequence_length = _sequence_length_from_sparse_tensor(
sparse_tensors.id_tensor)
@@ -3383,6 +3384,16 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
def _verify_static_batch_size_equality(tensors, columns):
+ """Validates that the first dim (batch size) of all tensors are equal or None.
+
+ Args:
+ tensors: list of tensors to check.
+ columns: list of feature columns matching tensors. Will be used for error
+ messaging.
+
+ Raises:
+ ValueError: if one of the tensors has a variant batch size
+ """
# bath_size is a tf.Dimension object.
expected_batch_size = None
for i in range(0, len(tensors)):
@@ -3403,9 +3414,18 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
with ops.name_scope(None, 'sequence_length') as name_scope:
row_ids = sp_tensor.indices[:, 0]
column_ids = sp_tensor.indices[:, 1]
+ # Add one to convert column indices to element length
column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+ # Get the number of elements we will have per example/row
+ seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids)
+
+ # The raw values are grouped according to num_elements;
+ # how many entities will we have after grouping?
+ # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1),
+ # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2,
+ # these will get grouped, and the final seq_length is [1, 1]
+ seq_length = math_ops.to_int64(math_ops.ceil(seq_length / num_elements))
+
# If the last n rows do not have ids, seq_length will have shape
# [batch_size - n]. Pad the remaining values with zeros.
n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
@@ -3439,25 +3459,14 @@ class _SequenceCategoricalColumn(
sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
id_tensor = sparse_tensors.id_tensor
weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+
+ # Expands third dimension, if necessary so that embeddings are not
+ # combined during embedding lookup. If the tensor is already 3D, leave
+ # as-is.
+ shape = array_ops.shape(id_tensor)
+ target_shape = [shape[0], shape[1], -1]
+ id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index c6595918ae..c9ac27e788 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -370,7 +370,8 @@ def import_graph_def(graph_def,
Returns:
A list of `Operation` and/or `Tensor` objects from the imported graph,
- corresponding to the names in `return_elements`.
+ corresponding to the names in `return_elements`,
+ and None if `returns_elements` is None.
Raises:
TypeError: If `graph_def` is not a `GraphDef` proto,
diff --git a/tensorflow/python/framework/op_def_library.py b/tensorflow/python/framework/op_def_library.py
index e85bba11cd..9955a9a2cd 100644
--- a/tensorflow/python/framework/op_def_library.py
+++ b/tensorflow/python/framework/op_def_library.py
@@ -482,7 +482,8 @@ class OpDefLibrary(object):
else:
raise TypeError("%s that don't all match." % prefix)
else:
- raise TypeError("%s that are invalid." % prefix)
+ raise TypeError(
+ "%s that are invalid. Tensors: %s" % (prefix, values))
types = [x.dtype for x in values]
inputs.extend(values)
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8bb177939e..77c2bc930e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4140,10 +4140,7 @@ class Graph(object):
if op is None and not ignore_existing:
raise ValueError("Trying to reset colocation (op is None) but "
"ignore_existing is not True")
-
- if op is not None and not isinstance(op, Operation):
- # We always want to colocate with the reference op.
- op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
+ op = _op_to_colocate_with(op)
# By default, colocate_with resets the device function stack,
# since colocate_with is typically used in specific internal
@@ -6168,4 +6165,27 @@ def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
name, as_ref))
+def _op_to_colocate_with(v):
+ """Operation object corresponding to v to use for colocation constraints."""
+ if v is None:
+ return None
+ if isinstance(v, Operation):
+ return v
+ # We always want to colocate with the reference op.
+ # When 'v' is a ResourceVariable, the reference op is the handle creating op.
+ #
+ # What this should be is:
+ # if isinstance(v, ResourceVariable):
+ # return v.handle.op
+ # However, that would require a circular import dependency.
+ # As of October 2018, there were attempts underway to remove
+ # colocation constraints altogether. Assuming that will
+ # happen soon, perhaps this hack to work around the circular
+ # import dependency is acceptable.
+ if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
+ v.handle.op, Operation):
+ return v.handle.op
+ return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op
+
+
register_tensor_conversion_function(Operation, _operation_conversion_error)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 4ec4b41b5e..95925bb471 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -506,9 +506,9 @@ def disable_control_flow_v2(unused_msg):
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
- Runs the test multiple times executing eagerly, first as a warmup and then
- several times to let objects accumulate. The warmup helps ignore caches which
- do not grow as the test is run repeatedly.
+ Runs the test multiple times executing eagerly, first as a warmup and then to
+ let objects accumulate. The warmup helps ignore caches which do not grow as
+ the test is run repeatedly.
Useful for checking that there are no missing Py_DECREFs in the C exercised by
a bit of Python.
@@ -518,7 +518,14 @@ def assert_no_new_pyobjects_executing_eagerly(f):
"""Warms up, gets an object count, runs the test, checks for new objects."""
with context.eager_mode():
gc.disable()
- f(self, **kwargs)
+ # Run the test 2 times as warmup, in an attempt to fill up caches, which
+ # should not grow as the test is run repeatedly below.
+ #
+ # TODO(b/117156879): Running warmup twice is black magic; we have seen
+ # tests that fail with 1 warmup run, and pass with 2, on various versions
+ # of python2.7.x.
+ for _ in range(2):
+ f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
if ops.has_default_graph():
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 4a72c4b3f3..c4d23f117f 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -62,6 +62,7 @@ py_library(
":backend",
":engine",
":layers",
+ ":optimizer_v2",
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
@@ -189,6 +190,30 @@ py_library(
],
)
+py_library(
+ name = "optimizer_v2",
+ srcs = [
+ "optimizer_v2/adadelta.py",
+ "optimizer_v2/adagrad.py",
+ "optimizer_v2/adam.py",
+ "optimizer_v2/optimizer_v2.py",
+ "optimizer_v2/rmsprop.py",
+ "optimizer_v2/sgd.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distribute",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_test(
name = "integration_test",
size = "medium",
@@ -827,3 +852,133 @@ py_library(
"//third_party/py/numpy",
],
)
+
+cuda_py_test(
+ name = "adadelta_test",
+ size = "medium",
+ srcs = ["optimizer_v2/adadelta_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adagrad_test",
+ size = "small",
+ srcs = ["optimizer_v2/adagrad_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adam_test",
+ size = "small",
+ srcs = ["optimizer_v2/adam_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "checkpointable_utils_test",
+ srcs = ["optimizer_v2/checkpointable_utils_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:layers_base",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/keras",
+ ],
+ tags = ["notsan"],
+)
+
+cuda_py_test(
+ name = "sgd_test",
+ size = "medium",
+ srcs = ["optimizer_v2/sgd_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "optimizer_v2_test",
+ size = "medium",
+ srcs = ["optimizer_v2/optimizer_v2_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
+ name = "rmsprop_test",
+ size = "small",
+ srcs = ["optimizer_v2/rmsprop_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+ tags = ["optonly"],
+)
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 99645de736..d69791ce8d 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -160,6 +160,11 @@ def sigmoid(x):
return nn.sigmoid(x)
+@tf_export('keras.activations.exponential')
+def exponential(x):
+ return math_ops.exp(x)
+
+
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
"""Hard sigmoid activation function.
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index dd0bbcff39..ad238cb0a9 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -169,6 +169,16 @@ class KerasActivationsTest(test.TestCase):
expected = np.tanh(test_values)
self.assertAllClose(result, expected, rtol=1e-05)
+ def test_exponential(self):
+ with self.cached_session():
+ test_values = np.random.random((2, 5))
+ x = keras.backend.placeholder(ndim=2)
+ exp = keras.activations.exponential(x)
+ f = keras.backend.function([x], [exp])
+ result = f([test_values])[0]
+ expected = np.exp(test_values)
+ self.assertAllClose(result, expected, rtol=1e-05)
+
def test_linear(self):
x = np.random.random((10, 5))
self.assertAllClose(x, keras.activations.linear(x))
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 0d6877e4a1..13f52fbae7 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -653,6 +653,7 @@ def variable(value, dtype=None, name=None, constraint=None):
Examples:
```python
+ >>> import numpy as np
>>> from keras import backend as K
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val, dtype='float64', name='example_var')
@@ -773,6 +774,8 @@ def is_keras_tensor(x):
Examples:
```python
+ >>> import tensorflow as tf
+ >>> import numpy
>>> from keras import backend as K
>>> from keras.layers import Input, Dense
>>> np_var = numpy.array([1, 2])
@@ -2220,7 +2223,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@tf_export('keras.backend.batch_normalization')
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma.
I.e. returns:
@@ -2232,11 +2235,49 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
var: Variance of batch.
beta: Tensor with which to center the input.
gamma: Tensor by which to scale the input.
+ axis: Integer, the axis that should be normalized.
+ (typically the features axis).
epsilon: Fuzz factor.
Returns:
A tensor.
"""
+ if ndim(x) == 4:
+ # The CPU implementation of `fused_batch_norm` only supports NHWC
+ if axis == 1 or axis == -3:
+ tf_data_format = 'NCHW'
+ elif axis == 3 or axis == -1:
+ tf_data_format = 'NHWC'
+ else:
+ tf_data_format = None
+
+ if (tf_data_format == 'NHWC' or
+ tf_data_format == 'NCHW' and _has_nchw_support()):
+ # The mean / var / beta / gamma tensors may be broadcasted
+ # so they may have extra axes of size 1, which should be squeezed.
+ if ndim(mean) > 1:
+ mean = array_ops.reshape(mean, [-1])
+ if ndim(var) > 1:
+ var = array_ops.reshape(var, [-1])
+ if beta is None:
+ beta = zeros_like(mean)
+ elif ndim(beta) > 1:
+ beta = array_ops.reshape(beta, [-1])
+ if gamma is None:
+ gamma = ones_like(mean)
+ elif ndim(gamma) > 1:
+ gamma = array_ops.reshape(gamma, [-1])
+ y, _, _ = nn.fused_batch_norm(
+ x,
+ gamma,
+ beta,
+ epsilon=epsilon,
+ mean=mean,
+ variance=var,
+ data_format=tf_data_format,
+ is_training=False
+ )
+ return y
return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
@@ -2877,7 +2918,7 @@ class Function(object):
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
- 'time: %s', session_kwargs.keys())
+ 'time: %s', (session_kwargs.keys(),))
self._callable_fn = None
self._feed_arrays = None
@@ -3795,19 +3836,23 @@ def _preprocess_conv1d_input(x, data_format):
return x, tf_data_format
-def _preprocess_conv2d_input(x, data_format):
+def _preprocess_conv2d_input(x, data_format, force_transpose=False):
"""Transpose and cast the input before the conv2d.
Arguments:
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
+ force_transpose: Boolean. If True, the input will always be transposed
+ from NCHW to NHWC if `data_format` is `"channels_first"`.
+ If False, the transposition only occurs on CPU (GPU ops are
+ assumed to support NCHW).
Returns:
A tensor.
"""
tf_data_format = 'NHWC'
if data_format == 'channels_first':
- if not _has_nchw_support():
+ if not _has_nchw_support() or force_transpose:
x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
else:
tf_data_format = 'NCHW'
@@ -3955,7 +4000,8 @@ def conv2d_transpose(x,
output_shape,
strides=(1, 1),
padding='valid',
- data_format=None):
+ data_format=None,
+ dilation_rate=(1, 1)):
"""2D deconvolution (i.e.
transposed convolution).
@@ -3969,6 +4015,7 @@ def conv2d_transpose(x,
data_format: string, `"channels_last"` or `"channels_first"`.
Whether to use Theano or TensorFlow/CNTK data format
for inputs/kernels/outputs.
+ dilation_rate: Tuple of 2 integers.
Returns:
A tensor, result of transposed 2D convolution.
@@ -3984,7 +4031,13 @@ def conv2d_transpose(x,
if isinstance(output_shape, (tuple, list)):
output_shape = array_ops.stack(output_shape)
- x, tf_data_format = _preprocess_conv2d_input(x, data_format)
+ # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
+ if data_format == 'channels_first' and dilation_rate != (1, 1):
+ force_transpose = True
+ else:
+ force_transpose = False
+
+ x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
output_shape = (output_shape[0], output_shape[2], output_shape[3],
@@ -3999,13 +4052,18 @@ def conv2d_transpose(x,
else:
strides = (1, 1) + strides
- x = nn.conv2d_transpose(
- x,
- kernel,
- output_shape,
- strides,
- padding=padding,
- data_format=tf_data_format)
+ if dilation_rate == (1, 1):
+ x = nn.conv2d_transpose(x, kernel, output_shape, strides,
+ padding=padding,
+ data_format=tf_data_format)
+ else:
+ assert dilation_rate[0] == dilation_rate[1]
+ x = nn.atrous_conv2d_transpose(
+ x,
+ kernel,
+ output_shape,
+ rate=dilation_rate[0],
+ padding=padding)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
return x
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index ab71589940..0834448699 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -26,6 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import nn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -1381,6 +1382,36 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
+ def test_batch_normalization(self):
+ g_val = np.random.random((3,))
+ b_val = np.random.random((3,))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+
+ # 3D NHC case
+ val = np.random.random((10, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 3])
+
+ # 4D NHWC case
+ val = np.random.random((10, 5, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1, 2), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])
+
+ # 4D NCHW case
+ val = np.random.random((10, 3, 5, 5))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 2, 3), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])
+
class TestCTC(test.TestCase):
@@ -1506,12 +1537,13 @@ class TestRandomOps(test.TestCase):
self.assertAllClose(np.min(y), -2., atol=0.1)
def test_string_input(self):
- seq = keras.Sequential([
- keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
- keras.layers.Lambda(lambda x: x[0])
- ])
- preds = seq.predict([['tensorflow eager']])
- self.assertEqual(preds.shape, (1,))
+ with self.cached_session():
+ seq = keras.Sequential([
+ keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
+ keras.layers.Lambda(lambda x: x[0])
+ ])
+ preds = seq.predict([['tensorflow eager']])
+ self.assertEqual(preds.shape, (1,))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 6dfbbf3694..3d6000f223 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -781,6 +781,10 @@ class LearningRateScheduler(Callback):
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs or {}
+ logs['lr'] = K.get_value(self.model.optimizer.lr)
+
@tf_export('keras.callbacks.TensorBoard')
class TensorBoard(Callback):
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 8a4018a0df..6a69d0ed90 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -82,6 +82,7 @@ class InputLayer(base_layer.Layer):
self.built = True
self.sparse = sparse
self.batch_size = batch_size
+ self.supports_masking = True
if isinstance(input_shape, tensor_shape.TensorShape):
input_shape = tuple(input_shape.as_list())
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 8d34006967..5969fea2b2 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1028,7 +1028,10 @@ class Network(base_layer.Layer):
output_tensors, output_masks = layer._call_and_compute_mask(
computed_tensor, **kwargs)
else:
- output_tensors = layer.call(computed_tensor, **kwargs)
+ if context.executing_eagerly():
+ output_tensors = layer(computed_tensor, **kwargs)
+ else:
+ output_tensors = layer.call(computed_tensor, **kwargs)
if hasattr(layer, 'compute_mask'):
output_masks = layer.compute_mask(computed_tensor,
computed_mask)
@@ -1049,7 +1052,10 @@ class Network(base_layer.Layer):
output_tensors, output_masks = layer._call_and_compute_mask(
computed_tensors, **kwargs)
else:
- output_tensors = layer.call(computed_tensors, **kwargs)
+ if context.executing_eagerly():
+ output_tensors = layer(computed_tensors, **kwargs)
+ else:
+ output_tensors = layer.call(computed_tensors, **kwargs)
if hasattr(layer, 'compute_mask'):
output_masks = layer.compute_mask(computed_tensors,
computed_masks)
@@ -1635,10 +1641,11 @@ class Network(base_layer.Layer):
ValueError: if `summary()` is called before the model is built.
"""
if not self.built:
- raise ValueError('This model has never been called, thus its weights '
- 'have not yet been created, so no summary can be '
- 'displayed. Build the model first '
- '(e.g. by calling it on some data).')
+ raise ValueError('This model has not yet been built. '
+ 'Build the model first by calling `build()` or calling '
+ '`fit()` with some data, or specify '
+ 'an `input_shape` argument in the first layer(s) for '
+ 'automatic build.')
layer_utils.print_summary(self,
line_length=line_length,
positions=positions,
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index a0da96334b..b4488033cd 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
try:
import yaml # pylint:disable=g-import-not-at-top
@@ -1182,6 +1183,36 @@ class DefaultShapeInferenceBehaviorTest(test.TestCase):
output = model(sample_input)
self.assertEqual(output.shape, (1, 3))
+ @test_util.run_in_graph_and_eager_modes()
+ def test_sequential_as_downstream_of_masking_layer(self):
+ inputs = keras.layers.Input(shape=(3, 4))
+ x = keras.layers.Masking(mask_value=0., input_shape=(3, 4))(inputs)
+
+ s = keras.Sequential()
+ s.add(keras.layers.Dense(5, input_shape=(4,)))
+
+ x = keras.layers.wrappers.TimeDistributed(s)(x)
+ model = keras.Model(inputs=inputs, outputs=x)
+ model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss='mse')
+
+ model_input = np.random.randint(
+ low=1, high=5, size=(10, 3, 4)).astype('float32')
+ for i in range(4):
+ model_input[i, i:, :] = 0.
+ model.fit(model_input,
+ np.random.random((10, 3, 5)), epochs=1, batch_size=6)
+
+ if not context.executing_eagerly():
+ # Note: this doesn't work in eager due to DeferredTensor/ops compatibility
+ # issue.
+ mask_outputs = [model.layers[1].compute_mask(model.layers[1].input)]
+ mask_outputs += [model.layers[2].compute_mask(
+ model.layers[2].input, mask_outputs[-1])]
+ func = keras.backend.function([model.input], mask_outputs)
+ mask_outputs_val = func([model_input])
+ self.assertAllClose(mask_outputs_val[0], np.any(model_input, axis=-1))
+ self.assertAllClose(mask_outputs_val[1], np.any(model_input, axis=-1))
+
class GraphUtilsTest(test.TestCase):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 2ebb4cf99f..ff2ae54ad4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -563,9 +563,11 @@ class Model(Network):
for name in self.output_names:
tmp_target_tensors.append(target_tensors.get(name, None))
target_tensors = tmp_target_tensors
+ elif tensor_util.is_tensor(target_tensors):
+ target_tensors = [target_tensors]
else:
- raise TypeError('Expected `target_tensors` to be '
- 'a list or dict, but got:', target_tensors)
+ raise TypeError('Expected `target_tensors` to be a list or tuple or '
+ 'dict or a single tensor, but got:', target_tensors)
for i in range(len(self.outputs)):
if i in skip_target_indices:
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 04e8d079c0..ac759ef3aa 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -820,10 +820,6 @@ def _clone_and_build_model(model, inputs=None, targets=None):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
- # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
- # single tensor should be OK but it throws an error in that case.
- if targets is not None and not isinstance(targets, (list, dict, tuple)):
- targets = [targets]
if isinstance(targets, tuple):
targets = nest.flatten(targets)
cloned_model.compile(
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 54ad74c08b..868fd1dc69 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -1865,6 +1865,10 @@ class TestTrainingWithDataTensors(test.TestCase):
model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target])
model.train_on_batch(input_val, None)
+ # single-output, as single tensor
+ model.compile(optimizer='rmsprop', loss='mse', target_tensors=target)
+ model.train_on_batch(input_val, None)
+
# single-output, as dict
model.compile(optimizer='rmsprop', loss='mse',
target_tensors={'dense': target})
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index d00def07bb..8f5872385c 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -645,6 +645,14 @@ class Conv2DTranspose(Conv2D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 2 integers,
+ specifying the amount of padding along the height and width
+ of the output tensor.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -700,7 +708,9 @@ class Conv2DTranspose(Conv2D):
kernel_size,
strides=(1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
+ dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
@@ -717,6 +727,7 @@ class Conv2DTranspose(Conv2D):
strides=strides,
padding=padding,
data_format=data_format,
+ dilation_rate=dilation_rate,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=initializers.get(kernel_initializer),
@@ -728,6 +739,16 @@ class Conv2DTranspose(Conv2D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 2, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 4:
@@ -769,51 +790,50 @@ class Conv2DTranspose(Conv2D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, h_axis, w_axis = 1, 2, 3
+ h_axis, w_axis = 2, 3
else:
- c_axis, h_axis, w_axis = 3, 1, 2
+ h_axis, w_axis = 1, 2
height, width = inputs_shape[h_axis], inputs_shape[w_axis]
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
- strides = (1, 1, stride_h, stride_w)
else:
output_shape = (batch_size, out_height, out_width, self.filters)
- strides = (1, stride_h, stride_w, 1)
output_shape_tensor = array_ops.stack(output_shape)
- outputs = nn.conv2d_transpose(
+ outputs = backend.conv2d_transpose(
inputs,
self.kernel,
output_shape_tensor,
- strides,
- padding=self.padding.upper(),
- data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dilation_rate=self.dilation_rate)
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -837,13 +857,33 @@ class Conv2DTranspose(Conv2D):
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv2DTranspose, self).get_config()
+ config['output_padding'] = self.output_padding
+ return config
+
@tf_export('keras.layers.Conv3DTranspose',
'keras.layers.Convolution3DTranspose')
@@ -878,6 +918,14 @@ class Conv3DTranspose(Conv3D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 3 integers,
+ specifying the amount of padding along the depth, height, and
+ width.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -943,6 +991,7 @@ class Conv3DTranspose(Conv3D):
kernel_size,
strides=(1, 1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
activation=None,
use_bias=True,
@@ -971,6 +1020,16 @@ class Conv3DTranspose(Conv3D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 3, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 5:
@@ -1012,11 +1071,9 @@ class Conv3DTranspose(Conv3D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
+ d_axis, h_axis, w_axis = 2, 3, 4
else:
- c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
-
- self.input_spec = InputSpec(ndim=5, axes={c_axis: inputs_shape[c_axis]})
+ d_axis, h_axis, w_axis = 1, 2, 3
depth = inputs_shape[d_axis]
height = inputs_shape[h_axis]
@@ -1025,19 +1082,27 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_depth = conv_utils.deconv_output_length(depth,
kernel_d,
- self.padding,
- stride_d)
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_depth, out_height,
out_width)
@@ -1058,20 +1123,7 @@ class Conv3DTranspose(Conv3D):
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[d_axis] = conv_utils.deconv_output_length(out_shape[d_axis],
- kernel_d,
- self.padding,
- stride_d)
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -1109,15 +1161,38 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[d_axis] = conv_utils.deconv_output_length(
- output_shape[d_axis], kernel_d, self.padding, stride_d)
+ output_shape[d_axis],
+ kernel_d,
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv3DTranspose, self).get_config()
+ config.pop('dilation_rate')
+ config['output_padding'] = self.output_padding
+ return config
+
class SeparableConv(Conv):
"""Abstract base layer for separable nD convolution.
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index cad5e4c8bd..f88d632ab5 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -204,6 +204,9 @@ class Conv2DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1)])
+
def test_conv2dtranspose_regularizers(self):
kwargs = {
'filters': 3,
@@ -239,6 +242,31 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(layer.kernel.constraint, k_constraint)
self.assertEqual(layer.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_conv2d_transpose_dilation(self):
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ kwargs={'filters': 2,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2)},
+ input_shape=(2, 5, 6, 3))
+
+ input_data = np.arange(48).reshape((1, 4, 4, 3)).astype(np.float32)
+ expected_output = np.float32([[192, 228, 192, 228],
+ [336, 372, 336, 372],
+ [192, 228, 192, 228],
+ [336, 372, 336, 372]]).reshape((1, 4, 4, 1))
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ input_data=input_data,
+ kwargs={'filters': 1,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2),
+ 'kernel_initializer': 'ones'},
+ expected_output=expected_output)
+
class Conv3DTransposeTest(test.TestCase):
@@ -270,6 +298,9 @@ class Conv3DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1, 1)])
+
def test_conv3dtranspose_regularizers(self):
kwargs = {
'filters': 3,
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 912e8bd619..72a9c1d629 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@@ -41,16 +44,18 @@ class Pooling1D(Layer):
strides of the pooling operation.
padding: A string. The padding method, either 'valid' or 'same'.
Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
name: A string, the name of the layer.
"""
def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format=None,
+ padding='valid', data_format='channels_last',
name=None, **kwargs):
super(Pooling1D, self).__init__(name=name, **kwargs)
if data_format is None:
@@ -65,45 +70,39 @@ class Pooling1D(Layer):
self.input_spec = InputSpec(ndim=3)
def call(self, inputs):
- # There is no TF op for 1D pooling, hence we make the inputs 4D.
- if self.data_format == 'channels_last':
- # input is NWC, make it NHWC
- inputs = array_ops.expand_dims(inputs, 1)
- # pool on the W dim
- pool_shape = (1, 1) + self.pool_size + (1,)
- strides = (1, 1) + self.strides + (1,)
- data_format = 'NHWC'
- else:
- # input is NCW, make it NCHW
- inputs = array_ops.expand_dims(inputs, 2)
- # pool on the W dim
- pool_shape = (1, 1, 1) + self.pool_size
- strides = (1, 1, 1) + self.strides
- data_format = 'NCHW'
-
+ pad_axis = 2 if self.data_format == 'channels_last' else 3
+ inputs = array_ops.expand_dims(inputs, pad_axis)
outputs = self.pool_function(
inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=data_format)
-
- if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 1)
- else:
- return array_ops.squeeze(outputs, 2)
+ self.pool_size + (1,),
+ strides=self.strides + (1,),
+ padding=self.padding,
+ data_format=self.data_format)
+ return array_ops.squeeze(outputs, pad_axis)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
- self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
+ if self.data_format == 'channels_first':
+ steps = input_shape[2]
+ features = input_shape[1]
+ else:
+ steps = input_shape[1]
+ features = input_shape[2]
+ length = conv_utils.conv_output_length(steps,
+ self.pool_size[0],
+ self.padding,
+ self.strides[0])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], features, length])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], length, features])
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
- 'padding': self.padding
+ 'padding': self.padding,
+ 'data_format': self.data_format,
}
base_config = super(Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -119,19 +118,36 @@ class MaxPooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(MaxPooling1D, self).__init__(
- nn.max_pool,
+ functools.partial(backend.pool2d, pool_mode='max'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -149,18 +165,35 @@ class AveragePooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(AveragePooling1D, self).__init__(
- nn.avg_pool,
+ functools.partial(backend.pool2d, pool_mode='avg'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -561,41 +594,96 @@ class GlobalPooling1D(Layer):
"""Abstract class for different global pooling 1D layers.
"""
- def __init__(self, **kwargs):
+ def __init__(self, data_format='channels_last', **kwargs):
super(GlobalPooling1D, self).__init__(**kwargs)
self.input_spec = InputSpec(ndim=3)
+ self.data_format = conv_utils.normalize_data_format(data_format)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1]])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
def call(self, inputs):
raise NotImplementedError
+ def get_config(self):
+ config = {'data_format': self.data_format}
+ base_config = super(GlobalPooling1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export('keras.layers.GlobalAveragePooling1D',
'keras.layers.GlobalAvgPool1D')
class GlobalAveragePooling1D(GlobalPooling1D):
"""Global average pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
`(batch_size, features)`
"""
- def call(self, inputs):
- return backend.mean(inputs, axis=1)
+ def __init__(self, data_format='channels_last', **kwargs):
+ super(GlobalAveragePooling1D, self).__init__(data_format=data_format,
+ **kwargs)
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ if mask is not None:
+ mask = math_ops.cast(mask, backend.floatx())
+ input_shape = inputs.shape.as_list()
+ broadcast_shape = [-1, input_shape[steps_axis], 1]
+ mask = array_ops.reshape(mask, broadcast_shape)
+ inputs *= mask
+ return backend.sum(inputs, axis=steps_axis) / math_ops.reduce_sum(
+ mask, axis=steps_axis)
+ else:
+ return backend.mean(inputs, axis=steps_axis)
+
+ def compute_mask(self, inputs, mask=None):
+ return None
@tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D')
class GlobalMaxPooling1D(GlobalPooling1D):
"""Global max pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
@@ -603,7 +691,8 @@ class GlobalMaxPooling1D(GlobalPooling1D):
"""
def call(self, inputs):
- return backend.max(inputs, axis=1)
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ return backend.max(inputs, axis=steps_axis)
class GlobalPooling2D(Layer):
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index 2cd9939e66..936e73ecf9 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
class GlobalPoolingTest(test.TestCase):
@@ -31,8 +34,26 @@ class GlobalPoolingTest(test.TestCase):
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalAveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_globalpooling_1d_masking_support(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Masking(mask_value=0., input_shape=(3, 4)))
+ model.add(keras.layers.GlobalAveragePooling1D())
+ model.compile(loss='mae', optimizer=rmsprop.RMSPropOptimizer(0.001))
+
+ model_input = np.random.random((2, 3, 4))
+ model_input[0, 1:, :] = 0
+ output = model.predict(model_input)
+ self.assertAllClose(output[0], model_input[0, 0, :])
@tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_2d(self):
@@ -172,6 +193,10 @@ class Pooling1DTest(test.TestCase):
kwargs={'strides': stride,
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.MaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
@tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_1d(self):
@@ -183,6 +208,11 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.AveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index a1933c11b0..d19d0b5f8c 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -587,6 +587,9 @@ class Bidirectional(Wrapper):
output = y * y_rev
elif self.merge_mode is None:
output = [y, y_rev]
+ else:
+ raise ValueError(
+ 'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))
# Properly set learning phase
if (getattr(y, '_uses_learning_phase', False) or
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index f4e8419eb0..d217244e2f 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -651,7 +651,9 @@ def categorical_accuracy(y_true, y_pred):
@tf_export('keras.metrics.sparse_categorical_accuracy')
def sparse_categorical_accuracy(y_true, y_pred):
- y_true = math_ops.reduce_max(y_true, axis=-1)
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+ if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
+ y_true = array_ops.squeeze(y_true, [-1])
y_pred = math_ops.argmax(y_pred, axis=-1)
# If the expected labels are float, we need to cast the int returned by
@@ -670,11 +672,11 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
@tf_export('keras.metrics.sparse_top_k_categorical_accuracy')
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(
- nn.in_top_k(y_pred,
- math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'),
- k),
- axis=-1)
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+ if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
+ y_true = array_ops.squeeze(y_true, [-1])
+
+ return K.mean(nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), axis=-1)
# Aliases
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 4195ea18ad..5f5565d4d5 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -54,6 +54,18 @@ class KerasMetricsTest(test.TestCase):
y_pred = K.variable(np.random.random((6, 7)))
self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_true = K.variable([1., 0., 0., 0.])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
+ # Test correctness if the shape of y_true is (num_samples, 1)
+ y_true = K.variable([[1.], [0.], [0.], [0.]])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
def test_sparse_categorical_accuracy_float(self):
with self.cached_session():
metric = metrics.sparse_categorical_accuracy
@@ -79,6 +91,7 @@ class KerasMetricsTest(test.TestCase):
def test_sparse_top_k_categorical_accuracy(self):
with self.cached_session():
+ # Test correctness if the shape of y_true is (num_samples, 1)
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[1], [0]]))
result = K.eval(
@@ -91,6 +104,19 @@ class KerasMetricsTest(test.TestCase):
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([1, 0]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ self.assertEqual(result, 1)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ self.assertEqual(result, 0.5)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ self.assertEqual(result, 0.)
+
def test_top_k_categorical_accuracy(self):
with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
new file mode 100644
index 0000000000..d3b3c9c12e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -0,0 +1,116 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adadelta for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.training import training_ops
+
+
+class Adadelta(optimizer_v2.OptimizerV2):
+ """Adadelta optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
+ ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate. It is recommended
+ to leave it at the default value.
+ rho: float hyperparameter >= 0. The decay rate.
+ epsilon: float hyperparameter >= 0. Fuzz factor. A constant epsilon used
+ to better condition the grad update.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adadelta'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.95,
+ epsilon=1e-8,
+ name="Adadelta"):
+ super(Adadelta, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("epsilon", epsilon)
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ state.zeros_slot(v, "accum")
+ state.zeros_slot(v, "accum_update")
+
+ def _apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.sparse_apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_sparse_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
new file mode 100644
index 0000000000..6e48f92e4f
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -0,0 +1,166 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdadeltaOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ num_updates = 4 # number of ADADELTA steps to perform
+ for dtype in [dtypes.half, dtypes.float32]:
+ for grad in [0.2, 0.1, 0.01]:
+ for lr in [1.0, 0.5, 0.1]:
+ with self.cached_session():
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+ else:
+ var0 = variables.Variable(var0_init, dtype=dtype)
+ var1 = variables.Variable(var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ adadelta_opt = adadelta.Adadelta(lr, rho, epsilon)
+ adadelta_update = adadelta_opt.apply_gradients(
+ zip([grads, grads], [var0, var1]))
+
+ opt_vars = adadelta_opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+ self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+ self.assertEqual(4, len(opt_vars))
+
+ variables.global_variables_initializer().run()
+
+ # Assign slots
+ slot = [None] * 2
+ slot_update = [None] * 2
+ self.assertEqual(["accum", "accum_update"],
+ adadelta_opt.get_slot_names())
+ slot[0] = adadelta_opt.get_slot(var0, "accum")
+ self.assertEquals(slot[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot[0] in variables.trainable_variables())
+
+ slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+ self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+ slot[1] = adadelta_opt.get_slot(var1, "accum")
+ self.assertEquals(slot[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot[1] in variables.trainable_variables())
+
+ slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+ self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, var0.eval())
+ self.assertAllClose(var1_init, var1.eval())
+
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ adadelta_update.run()
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (accum_update * rho + (update[step]**2) *
+ (1.0 - rho))
+ tot_update += update[step] * lr
+
+ # Check that the accumulators have been updated
+ for slot_idx in range(2):
+ self.assertAllCloseAccordingToType(
+ np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
+ slot[slot_idx].eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [accum_update, accum_update],
+ dtype=dtype.as_numpy_dtype()),
+ slot_update[slot_idx].eval(),
+ rtol=1e-5)
+
+ # Check that the parameters have been updated
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var0_init[0] - tot_update, var0_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var0.eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var1_init[0] - tot_update, var1_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var1.eval(),
+ rtol=1e-5)
+
+ def testBasic(self):
+ self.doTestBasic(use_resource=False)
+
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-111, -138]], var0.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
new file mode 100644
index 0000000000..2d8cec2300
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -0,0 +1,119 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adagrad optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import training_ops
+
+
+class Adagrad(optimizer_v2.OptimizerV2):
+ """Adagrad optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+ or this
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
+
+ The learning_rate arg below is a hyperparameter, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ initial_accumulator_value: A floating point value. Starting value for the
+ accumulators, must be positive.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adagrad'.
+
+ Raises:
+ ValueError: If the `initial_accumulator_value` is invalid.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ initial_accumulator_value=0.1,
+ name="Adagrad"):
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value must be positive: %s" %
+ initial_accumulator_value)
+ super(Adagrad, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+
+ self._initial_accumulator_value = initial_accumulator_value
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ dtype = v.dtype.base_dtype
+ if v.get_shape().is_fully_defined():
+ init = init_ops.constant_initializer(self._initial_accumulator_value,
+ dtype=dtype)
+ else:
+ def init(v=v, dtype=dtype):
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
+ state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
+ "accumulator")
+
+ def _apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_sparse_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
new file mode 100644
index 0000000000..fc4ef5c399
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -0,0 +1,276 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for aggregate operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdagradOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testBasic(self):
+ self.doTestBasic()
+
+ def testBasicResource(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adagrad.Adagrad(1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType(
+ [[1.0, 2.0], [3.0, 4.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0, 1], [3, 4]], var0.eval(), atol=0.01)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(
+ constant_op.constant(3.0), initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+ # Run 3 step of sgd
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([[-1.6026098728179932], [2.0]]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[3.0], [3.715679168701172]]), var1.eval())
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def testSparseRepeatedIndicesResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var_repeated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_repeated = math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_repeated, [0, 0]))
+ var_aggregated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_aggregated = 2 * math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_aggregated, [0]))
+ update_op_repeated = adagrad.Adagrad(2.0).minimize(loss_repeated)
+ update_op_aggregated = adagrad.Adagrad(2.0).minimize(loss_aggregated)
+ variables.global_variables_initializer().run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+ for _ in range(3):
+ update_op_repeated.run()
+ update_op_aggregated.run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+
+ def testSparseStability(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ shape = [1, 6]
+ var0 = variables.Variable(
+ [[
+ 0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
+ -0.0105945
+ ]],
+ dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[
+ -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+ -8.4877e-05, -9.48906e-05
+ ]],
+ shape=shape,
+ dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant(shape))
+ ada_opt = adagrad.Adagrad(1.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ init = variables.global_variables_initializer()
+ for _ in range(100):
+ init.run()
+ ada_update.run()
+ self.assertAllCloseAccordingToType(
+ np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[
+ 0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
+ -0.01029443
+ ]]), var0.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0)
+ # Apply the optimizer twice. Both applications will use
+ # the same accums.
+ ada_update1 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ ada_update2 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Mix the first and the second adagrad for 3 steps.
+ ada_update1.run()
+ ada_update2.run()
+ ada_update1.run()
+ # Validate updated params (the same as with only 1 Adagrad).
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testDynamicShapeVariable_Ok(self):
+ with self.cached_session():
+ v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
+ validate_shape=False)
+ self.assertFalse(v.shape.is_fully_defined())
+ # Creating optimizer should cause no exception.
+ adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
new file mode 100644
index 0000000000..8367228d7a
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -0,0 +1,203 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adam optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import training_ops
+
+
+class Adam(optimizer_v2.OptimizerV2):
+ r"""Adam Optimizer.
+
+ Default parameters follow those provided in the original paper.
+
+ See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+ ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+
+ Some of the args below are hyperparameters where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
+
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
+
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ beta_1: float hyperparameter, 0 < beta_1 < 1. Generally close to 1. The
+ exponential decay rate for the 1st moment estimates.
+ beta_2: float hyperparameter, 0 < beta_2 < 1. Generally close to 1. The
+ exponential decay rate for the 2nd moment estimates.
+ epsilon: float hyperparameter >= 0. Fuzz factor. This epsilon is "epsilon
+ hat" in the Kingma and Ba paper (in the formula just before Section
+ 2.1), not the epsilon in Algorithm 1 of the paper.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-8,
+ name="Adam"):
+ super(Adam, self).__init__(name)
+
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("beta_1", beta_1)
+ self._set_hyper("beta_2", beta_2)
+ self._set_hyper("epsilon", epsilon)
+
+ def _get_beta_accumulators(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return (state.get_non_slot("beta_1_power"),
+ state.get_non_slot("beta_2_power"))
+
+ def _create_vars(self, var_list, state):
+ # Non-slot variables end up on the same device(s).
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_1"), name="beta_1_power")
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_2"), name="beta_2_power")
+
+ # Create slots for the first and second moments.
+ for v in var_list:
+ state.zeros_slot(v, "m")
+ state.zeros_slot(v, "v")
+
+ def _apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.apply_adam(
+ var,
+ m,
+ v,
+ math_ops.cast(beta_1_power, var.dtype.base_dtype),
+ math_ops.cast(beta_2_power, var.dtype.base_dtype),
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("beta_1", var.dtype.base_dtype),
+ state.get_hyper("beta_2", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(beta_1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta_2_power, grad.dtype.base_dtype),
+ state.get_hyper("learning_rate", grad.dtype.base_dtype),
+ state.get_hyper("beta_1", grad.dtype.base_dtype),
+ state.get_hyper("beta_2", grad.dtype.base_dtype),
+ state.get_hyper("epsilon", grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ beta_1_power = math_ops.cast(beta_1_power, var.dtype.base_dtype)
+ beta_2_power = math_ops.cast(beta_2_power, var.dtype.base_dtype)
+ lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
+ beta_1_t = state.get_hyper("beta_1", var.dtype.base_dtype)
+ beta_2_t = state.get_hyper("beta_2", var.dtype.base_dtype)
+ epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
+ # m_t = beta_1 * m + (1 - beta_1) * g_t
+ m = state.get_slot(var, "m")
+ m_scaled_g_values = grad * (1 - beta_1_t)
+ m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
+ with ops.control_dependencies([m_t]):
+ m_t = scatter_add(m, indices, m_scaled_g_values)
+ # v_t = beta_2 * v + (1 - beta_2) * (g_t * g_t)
+ v = state.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
+ v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
+ with ops.control_dependencies([v_t]):
+ v_t = scatter_add(v, indices, v_scaled_g_values)
+ v_sqrt = math_ops.sqrt(v_t)
+ var_update = state_ops.assign_sub(var,
+ lr * m_t / (v_sqrt + epsilon_t),
+ use_locking=self._use_locking)
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _apply_sparse(self, grad, var, state):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x, i, v, use_locking=self._use_locking),
+ state)
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(
+ x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ return self._apply_sparse_shared(
+ grad, var, indices, self._resource_scatter_add, state)
+
+ def _finish(self, state):
+ # Update the power accumulators.
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ update_beta_1 = beta_1_power.assign(
+ beta_1_power * state.get_hyper("beta_1"), use_locking=self._use_locking)
+ update_beta_2 = beta_2_power.assign(
+ beta_2_power * state.get_hyper("beta_2"), use_locking=self._use_locking)
+ return control_flow_ops.group(update_beta_1, update_beta_2)
diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py
new file mode 100644
index 0000000000..77796317a1
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -0,0 +1,333 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adam optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adam_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(test.TestCase):
+
+ def doTestSparse(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([2]))
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
+ def testSparseDevicePlacement(self):
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(force_gpu=test.is_gpu_available()):
+ # If a GPU is available, tests that all optimizer ops can be placed on
+ # it (i.e. they have GPU kernels).
+ var = variables.Variable([[1.0], [2.0]])
+ indices = constant_op.constant([0, 1], dtype=index_dtype)
+ gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+ optimizer = adam.Adam(3.0)
+ minimize_op = optimizer.minimize(gathered_sum)
+ variables.global_variables_initializer().run()
+ minimize_op.run()
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adam.Adam().apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adam.Adam().apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def doTestBasic(self, use_resource=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertTrue(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = adam.Adam()
+ g = ops.Graph()
+ with g.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = adam.Adam(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables, and two unique slot variables
+ # for v1 and v2 respectively.
+ self.assertEqual(6, len(set(opt.variables())))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
new file mode 100644
index 0000000000..338c04148b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
@@ -0,0 +1,761 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(josh11b): Forked from contrib/eager/python to test OptimizerV2 the same way
+# OptimizerV1 is tested. This file should be removed once the fork is resolved.
+
+import functools
+import os
+
+import six
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as core_saver
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+
+
+class NonLayerCheckpointable(tracking.Checkpointable):
+
+ def __init__(self):
+ super(NonLayerCheckpointable, self).__init__()
+ self.a_variable = util.add_variable(
+ self, name="a_variable", shape=[])
+
+
+# pylint: disable=not-callable
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+ # We can still track Checkpointables which aren't Layers.
+ self._non_layer = NonLayerCheckpointable()
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
+class _MirroringSaveable(
+ core_saver.BaseSaverBuilder.ResourceVariableSaveable):
+
+ def __init__(self, primary_variable, mirrored_variable, name):
+ self._primary_variable = primary_variable
+ self._mirrored_variable = mirrored_variable
+ super(_MirroringSaveable, self).__init__(
+ self._primary_variable, "", name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return control_flow_ops.group(
+ self._primary_variable.assign(tensor),
+ self._mirrored_variable.assign(tensor))
+
+
+class CheckpointingTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testNamingWithOptimizer(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ # A nuisance Model using the same optimizer. Its slot variables should not
+ # go in the checkpoint, since it is never depended on.
+ other_model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value),
+ global_step=optimizer_step)
+ optimizer.minimize(
+ lambda: other_model(input_value),
+ global_step=optimizer_step)
+ else:
+ train_op = optimizer.minimize(
+ model(input_value), global_step=optimizer_step)
+ optimizer.minimize(
+ other_model(input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ named_variables, serialized_graph, _ = (
+ util._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
+ expected_checkpoint_names = (
+ # Created in the root node, so no prefix.
+ "optimizer_step",
+ "model/_second/kernel",
+ "model/_named_dense/kernel",
+ "model/_named_dense/bias",
+ # non-Layer dependency of the model
+ "model/_non_layer/a_variable",
+ # The optimizer creates two non-slot variables
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
+ # Slot variables
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
+ )
+ suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
+ expected_checkpoint_names = [
+ name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
+ six.assertCountEqual(self, expected_checkpoint_names,
+ named_variables.keys())
+ # Check that we've mapped to the right variable objects (not exhaustive)
+ self.assertEqual(
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
+ self.assertEqual(
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
+ self.assertEqual(
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
+ # Spot check the generated protocol buffers.
+ self.assertEqual("optimizer",
+ serialized_graph.nodes[0].children[1].local_name)
+ optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
+ 1].node_id]
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .original_variable_node_id]
+ .attributes[0].full_name)
+ # We strip off the :0 suffix, as variable.name-based saving does.
+ self.assertEqual(
+ "my_model/dense/kernel/Adam",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .slot_variable_node_id]
+ .attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel/Adam:0",
+ optimizer.get_slot(
+ var=model._named_dense.kernel,
+ name="m").name)
+ self.assertEqual(
+ "model/_named_dense/kernel" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .original_variable_node_id].attributes[0].checkpoint_key)
+ self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
+ self.assertEqual(
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .slot_variable_node_id].attributes[0].checkpoint_key)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestore(self):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model)
+ input_value = constant_op.constant([[3.]])
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value))
+ else:
+ train_op = optimizer.minimize(model(input_value))
+ # TODO(allenl): Make initialization more pleasant when graph building.
+ root_checkpointable.save_counter # pylint: disable=pointless-statement
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
+ m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
+ self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
+ save_path = root_checkpointable.save(file_prefix=prefix)
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
+ self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
+ optimizer_variables = self.evaluate(optimizer.variables())
+ self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
+ # Immediate restoration
+ status = root_checkpointable.restore(save_path=save_path).assert_consumed()
+ status.run_restore_ops()
+ self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
+ self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
+ self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
+ if not context.executing_eagerly():
+ return # Restore-on-create is only supported when executing eagerly
+ on_create_model = MyModel()
+ on_create_optimizer = adam.Adam(
+ 0.001,
+ # Preserve beta_1_power and beta_2_power when appying gradients
+ # so we can test that they've been restored correctly.
+ beta_1=1.0,
+ beta_2=1.0)
+ on_create_root = util.Checkpoint(
+ optimizer=on_create_optimizer, model=on_create_model)
+ # Deferred restoration
+ status = on_create_root.restore(save_path=save_path)
+ on_create_model(constant_op.constant([[3.]])) # create variables
+ self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
+ self.assertAllEqual([42.],
+ self.evaluate(
+ on_create_model._named_dense.variables[1]))
+ on_create_m_bias_slot = on_create_optimizer.get_slot(
+ on_create_model._named_dense.variables[1], "m")
+ # Optimizer slot variables are created when the original variable is
+ # restored.
+ self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
+ self.assertAllEqual(optimizer_variables[2:],
+ self.evaluate(on_create_optimizer.variables()))
+ dummy_var = resource_variable_ops.ResourceVariable([1.])
+ on_create_optimizer.minimize(loss=dummy_var.read_value)
+ status.assert_consumed()
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
+
+ # TODO(allenl): Debug garbage created by this test in python3.
+ def testDeferredRestorationUsageEager(self):
+ """An idiomatic eager execution example."""
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ optimizer_step=training_util.get_or_create_global_step())
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
+ for _ in range(num_training_steps):
+ # TODO(allenl): Use a Dataset and serialize/checkpoint it.
+ input_value = constant_op.constant([[3.]])
+ optimizer.minimize(
+ lambda: model(input_value), # pylint: disable=cell-var-from-loop
+ global_step=root.optimizer_step)
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ root.optimizer_step.numpy())
+
+ def testUsageGraph(self):
+ """Expected usage when graph building."""
+ with context.graph_mode():
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default():
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ input_value = constant_op.constant([[3.]])
+ train_op = optimizer.minimize(
+ model(input_value),
+ global_step=root.global_step)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ with self.session(graph=ops.get_default_graph()) as session:
+ status = root.restore(save_path=checkpoint_path)
+ status.initialize_or_restore(session=session)
+ if checkpoint_path is None:
+ self.assertEqual(0, training_continuation)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+ else:
+ status.assert_consumed()
+ for _ in range(num_training_steps):
+ session.run(train_op)
+ root.save(file_prefix=checkpoint_prefix, session=session)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ session.run(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ session.run(root.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAgnosticUsage(self):
+ """Graph/eager agnostic usage."""
+ # Does create garbage when executing eagerly due to ops.Graph() creation.
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(model, input_value),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+
+ # pylint: disable=cell-var-from-loop
+ @test_util.run_in_graph_and_eager_modes
+ def testWithDefun(self):
+ num_training_steps = 2
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ # Don't actually train so we can test variable values
+ optimizer = adam.Adam(0.)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ def train_fn():
+ @function.defun
+ def _call_model(x):
+ return model(x)
+ with backprop.GradientTape() as tape:
+ loss = _call_model(constant_op.constant([[3.]]))
+ gradients = tape.gradient(loss, model.variables)
+ return optimizer.apply_gradients(zip(gradients, model.variables),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(
+ self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ if training_continuation > 0:
+ status.assert_consumed()
+ self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
+ else:
+ self.evaluate(model.variables[0].assign([[42.]]))
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+ # pylint: enable=cell-var-from-loop
+
+ def testAnonymousVarsInInit(self):
+
+ class Model(training.Model):
+
+ def __init__(self):
+ super(Model, self).__init__()
+ self.w = resource_variable_ops.ResourceVariable(0.0)
+ self.b = resource_variable_ops.ResourceVariable(0.0)
+ self.vars = [self.w, self.b]
+
+ def call(self, x):
+ return x * self.w + self.b
+
+ with context.eager_mode():
+ model = Model()
+ optimizer = adam.Adam(learning_rate=0.05)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ checkpoint = util.Checkpoint(
+ model=model, optimizer=optimizer)
+ for _ in range(2):
+ checkpoint.save(checkpoint_prefix)
+ with backprop.GradientTape() as tape:
+ loss = (constant_op.constant(1.)
+ - model(constant_op.constant(1.))) ** 2
+ grad = tape.gradient(loss, model.vars)
+ optimizer.apply_gradients(
+ [(g, v) for g, v in zip(grad, model.vars)])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDeferredSlotRestoration(self):
+ checkpoint_directory = self.get_temp_dir()
+
+ root = tracking.Checkpointable()
+ root.var = util.add_variable(
+ root, name="var", initializer=0.)
+ optimizer = adam.Adam(0.1)
+ if context.executing_eagerly():
+ optimizer.minimize(root.var.read_value)
+ else:
+ train_op = optimizer.minimize(root.var)
+ # Note that `optimizer` has not been added as a dependency of
+ # `root`. Create a one-off grouping so that slot variables for `root.var`
+ # get initialized too.
+ self.evaluate(util.gather_initializers(
+ util.Checkpoint(root=root, optimizer=optimizer)))
+ self.evaluate(train_op)
+ self.evaluate(state_ops.assign(root.var, 12.))
+ no_slots_path = util.CheckpointableSaver(root).save(
+ os.path.join(checkpoint_directory, "no_slots"))
+ root.optimizer = optimizer
+ self.evaluate(state_ops.assign(root.var, 13.))
+ self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
+ 14.))
+ slots_path = util.CheckpointableSaver(root).save(
+ os.path.join(checkpoint_directory, "with_slots"))
+ new_root = tracking.Checkpointable()
+ # Load the slot-containing checkpoint (deferred), then immediately overwrite
+ # the non-slot variable (also deferred).
+ slot_status = util.CheckpointableSaver(
+ new_root).restore(slots_path)
+ no_slot_status = util.CheckpointableSaver(
+ new_root).restore(no_slots_path)
+ with self.assertRaises(AssertionError):
+ no_slot_status.assert_consumed()
+ new_root.var = util.add_variable(
+ new_root, name="var", shape=[])
+ no_slot_status.assert_consumed()
+ no_slot_status.run_restore_ops()
+ self.assertEqual(12., self.evaluate(new_root.var))
+ new_root.optimizer = adam.Adam(0.1)
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
+ slot_status.assert_consumed()
+ self.assertEqual(12., self.evaluate(new_root.var))
+ if context.executing_eagerly():
+ # Slot variables are only created with restoring initializers when
+ # executing eagerly.
+ self.assertEqual(14., self.evaluate(
+ new_root.optimizer.get_slot(name="m", var=new_root.var)))
+ else:
+ self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
+ None)
+ if context.executing_eagerly():
+ new_root.optimizer.minimize(new_root.var.read_value)
+ else:
+ train_op = new_root.optimizer.minimize(new_root.var)
+ # The slot variable now exists; restore() didn't create it, but we should
+ # now have a restore op for it.
+ slot_status.run_restore_ops()
+ self.assertEqual(14., self.evaluate(
+ new_root.optimizer.get_slot(name="m", var=new_root.var)))
+ self.evaluate(train_op)
+ slot_status.assert_consumed()
+
+ def testManySavesGraph(self):
+ """Saves after the first should not modify the graph."""
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.session(graph):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = tracking.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ obj.opt = adam.Adam(0.1)
+ obj.opt.minimize(obj.var.read_value())
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
+ saver.save(checkpoint_prefix)
+ before_ops = graph.get_operations()
+ saver.save(checkpoint_prefix)
+ self.assertEqual(before_ops, graph.get_operations())
+
+ def testManyRestoresGraph(self):
+ """Restores after the first should not modify the graph."""
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.session(graph):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = tracking.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ obj.opt = adam.Adam(0.1)
+ obj.opt.minimize(obj.var.read_value())
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
+ save_path = saver.save(checkpoint_prefix)
+ saver.restore(save_path)
+ before_ops = graph.get_operations()
+ saver.restore(save_path)
+ self.assertEqual(before_ops, graph.get_operations())
+
+ def testMultipleGraphsNonSlotVariables(self):
+ with context.graph_mode():
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ optimizer = adam.Adam(0.001)
+ # Construct a model in one graph
+ first_graph = ops.Graph()
+ first_session = session_lib.Session(graph=first_graph)
+ with first_graph.as_default(), first_session.as_default():
+ first_variable = resource_variable_ops.ResourceVariable([1.])
+ first_root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, variable=first_variable)
+ train_op = optimizer.minimize(first_variable.read_value)
+ self.evaluate(util.gather_initializers(
+ first_root_checkpointable))
+ self.evaluate(train_op)
+ self.evaluate(first_variable.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=first_variable, name="m").assign([2.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
+
+ # Save and load in a second graph
+ second_graph = ops.Graph()
+ with second_graph.as_default(), session_lib.Session(graph=second_graph):
+ second_variable = resource_variable_ops.ResourceVariable([1.])
+ second_root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, variable=second_variable)
+ train_op = optimizer.minimize(second_variable.read_value)
+ second_root_checkpointable.restore(None).initialize_or_restore()
+ self.evaluate(train_op)
+ self.evaluate(second_variable.assign([4.]))
+ self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m").assign([5.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
+ save_path = second_root_checkpointable.save(checkpoint_prefix)
+ self.evaluate(second_variable.assign([7.]))
+ self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m").assign([8.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
+ status = second_root_checkpointable.restore(save_path)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([4.], self.evaluate(second_variable))
+ self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m")))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
+
+ # Check that the first graph is unmolested
+ with first_graph.as_default(), first_session.as_default():
+ self.assertAllEqual([1.], self.evaluate(first_variable))
+ self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
+ var=first_variable, name="m")))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+
+class TemplateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_checkpointable_save_restore(self):
+
+ def _templated():
+ v = variable_scope.get_variable(
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ v2 = variable_scope.get_variable(
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ return v, v + 1., v2
+
+ save_template = template.make_template("s1", _templated)
+ v1_save, _, v2_save = save_template()
+ optimizer = adam.Adam(0.0)
+ save_root = util.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ self.evaluate(v1_save.assign([12.]))
+ self.evaluate(v2_save.assign([14.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_root.save(checkpoint_prefix)
+
+ load_template = template.make_template("s2", _templated)
+ load_optimizer = adam.Adam(0.0)
+ load_root = util.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
+ status = load_root.restore(save_path)
+ var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
+ self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+ self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([12.], self.evaluate(var))
+ self.assertAllEqual([13.], self.evaluate(var_plus_one))
+ self.assertAllEqual([14.], self.evaluate(var2))
+
+
+class CheckpointCompatibilityTests(test.TestCase):
+
+ def _initialized_model(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ train_op = optimizer.minimize(
+ functools.partial(model, input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ # A regular variable, a slot variable, and a non-slot Optimizer variable
+ # with known values to check when loading.
+ self.evaluate(model._named_dense.bias.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=model._named_dense.bias, name="m").assign([2.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
+ return root_checkpointable
+
+ def _set_sentinels(self, root_checkpointable):
+ self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
+ self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")
+ .assign([102.]))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
+
+ def _check_sentinels(self, root_checkpointable):
+ self.assertAllEqual(
+ [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
+ self.assertAllEqual([2.], self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+ def _write_name_based_checkpoint(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ name_saver = core_saver.Saver()
+ return name_saver.save(
+ sess=session, save_path=checkpoint_prefix,
+ global_step=root.optimizer_step)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testLoadFromNameBasedSaver(self):
+ """Save a name-based checkpoint, load it using the object-based API."""
+ with test_util.device(use_gpu=True):
+ save_path = self._write_name_based_checkpoint()
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ with self.assertRaises(AssertionError):
+ self._check_sentinels(root)
+ object_saver = util.CheckpointableSaver(root)
+ self._set_sentinels(root)
+ status = object_saver.restore(save_path)
+ if context.executing_eagerly():
+ self._check_sentinels(root)
+ if context.executing_eagerly():
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_consumed()
+ else:
+ # When graph building, we haven't read any keys, so we don't know
+ # whether the restore will be complete.
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_consumed()
+ status.run_restore_ops()
+ self._check_sentinels(root)
+ self._set_sentinels(root)
+ status = object_saver.restore(save_path)
+ status.initialize_or_restore()
+ self._check_sentinels(root)
+
+ # TODO(allenl): Test for the core name-based saver loading object-based
+ # checkpoints once object-based checkpointing is in core.
+
+ def testSaveGraphLoadEager(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ save_path = root.save(
+ session=session, file_prefix=checkpoint_prefix)
+ with context.eager_mode():
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed()
+ self._check_sentinels(root)
+
+ def testSaveEagerLoadGraph(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.eager_mode():
+ root = self._initialized_model()
+ save_path = root.save(file_prefix=checkpoint_prefix)
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph):
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed().run_restore_ops()
+ self._check_sentinels(root)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
new file mode 100644
index 0000000000..bd5557f4fd
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -0,0 +1,1349 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Version 2 of class Optimizer."""
+# pylint: disable=g-bad-name
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.training import slot_creator
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
+
+
+class _OptimizableVariable(object):
+ """Interface for abstracting over variables in the optimizers."""
+
+ @abc.abstractmethod
+ def target(self):
+ """Returns the optimization target for this variable."""
+ raise NotImplementedError("Calling an abstract method.")
+
+ @abc.abstractmethod
+ def update_op(self, optimizer, g, *args):
+ """Returns the update ops for updating the variable."""
+ raise NotImplementedError("Calling an abstract method.")
+
+
+class _RefVariableProcessor(_OptimizableVariable):
+ """Processor for Variable."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v._ref() # pylint: disable=protected-access
+
+ def update_op(self, optimizer, g, *args):
+ if isinstance(g, ops.Tensor):
+ update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+ else:
+ assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
+ "tensor nor IndexedSlices.")
+ if self._v.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ # pylint: disable=protected-access
+ return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
+
+
+class _DenseReadResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ # pylint: disable=protected-access
+ update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+
+
+class _DenseResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ # pylint: disable=protected-access
+ if isinstance(g, ops.IndexedSlices):
+ if self._v.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ return optimizer._resource_apply_sparse_duplicate_indices(
+ g.values, self._v, g.indices, *args)
+ update_op = optimizer._resource_apply_dense(g, self._v, *args)
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+
+
+class _TensorProcessor(_OptimizableVariable):
+ """Processor for ordinary Tensors.
+
+ Even though a Tensor can't really be updated, sometimes it is useful to
+ compute the gradients with respect to a Tensor using the optimizer. Updating
+ the Tensor is, of course, unsupported.
+ """
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ raise NotImplementedError("Trying to update a Tensor ", self._v)
+
+
+def _get_processor(v):
+ """The processor of v."""
+ if context.executing_eagerly():
+ if isinstance(v, ops.Tensor):
+ return _TensorProcessor(v)
+ else:
+ return _DenseResourceVariableProcessor(v)
+ if v.op.type == "VarHandleOp":
+ return _DenseResourceVariableProcessor(v)
+ if isinstance(v, variables.Variable):
+ return _RefVariableProcessor(v)
+ if isinstance(v, ops.Tensor):
+ return _TensorProcessor(v)
+ raise NotImplementedError("Trying to optimize unsupported type ", v)
+
+
+def _var_key_v2(var):
+ """Key for representing a primary variable, for looking up slots."""
+ # pylint: disable=protected-access
+ if hasattr(var, "_distributed_container"):
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
+ if context.executing_eagerly():
+ return distributed_container._unique_id
+ return distributed_container._shared_name
+ if context.executing_eagerly():
+ return var._unique_id
+ return var.op.name
+
+
+def _resolve(value, name):
+ if callable(value):
+ value = value()
+ return ops.convert_to_tensor(value, name=name)
+
+
+def _is_dynamic(value):
+ """Returns true if __init__ arg `value` should be re-evaluated each step."""
+ if callable(value): return True
+ # Don't need to do anything special in graph mode, since dynamic values
+ # will propagate correctly automatically.
+ # TODO(josh11b): Add per-device caching across steps using variables for
+ # truly static values once we add distributed support.
+ if context.executing_eagerly() and isinstance(
+ value, resource_variable_ops.ResourceVariable):
+ return True
+ return False
+
+
+class _OptimizerV2State(object):
+ """Holds per-graph and per-step optimizer state.
+
+ Use _init_with_static_hyper() to create the state for a graph, and then
+ _copy_with_dynamic_hyper() to convert that to state for a particular step.
+ The difference between the two is that the former only has hyper
+ parameter values that are static and the latter also has values that
+ can change every step (according to _is_dynamic()).
+ """
+
+ def __init__(self, op_name):
+ self._op_name = op_name
+
+ def _init_with_static_hyper(self, hyper):
+ """Initialize a fresh state object from hyper dict."""
+ # self._hyper contains a dict from name to a dict with the Tensor values.
+ # This dict starts with a single item with key "None" with the hyper
+ # parameter value converted to a Tensor. Other items have dtype keys
+ # with that Tensor cast to that dtype.
+ with ops.init_scope():
+ self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if not dynamic}
+ self._slots = {}
+ self._non_slot_dict = {}
+ # Extra state to help Optimizers implement Checkpointable. Holds information
+ # about variables which will be restored as soon as they're created.
+ self._deferred_dependencies = {} # Non-slot variables
+ self._deferred_slot_restorations = {} # Slot variables
+
+ def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
+ """Create a new state object for a particular step."""
+ ret = _OptimizerV2State(self._op_name)
+ # pylint: disable=protected-access
+ ret._slots = self._slots
+ ret._non_slot_dict = self._non_slot_dict
+ ret._deferred_dependencies = self._deferred_dependencies
+ ret._deferred_slot_restorations = self._deferred_slot_restorations
+ ret._hyper = {name: {None: _resolve(value, name)}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if dynamic}
+ ret._hyper.update(self._hyper)
+ ret._non_slot_devices = non_slot_devices
+ ret._distribution = distribution
+ return ret
+
+ def _variables(self):
+ """Returns a list of all variables held by self."""
+ optimizer_variables = list(self._non_slot_dict.values())
+ for variable_dict in self._slots.values():
+ for slot_for_variable in variable_dict.values():
+ optimizer_variables.append(slot_for_variable)
+ # Sort variables by name so that the return is deterministic.
+ return sorted(optimizer_variables, key=lambda v: v.name)
+
+ def _slot_dict(self, slot_name):
+ """Returns a dict for caching slots created under the given name.
+
+ Args:
+ slot_name: Name for the slot.
+
+ Returns:
+ A dict that maps primary `Variable` objects to the slot created
+ for that variable, under the given slot name.
+ """
+ named_slots = self._slots.get(slot_name, None)
+ if named_slots is None:
+ named_slots = {}
+ self._slots[slot_name] = named_slots
+ return named_slots
+
+ def create_slot(self, var, val, slot_name, optional_op_name=None):
+ """Find or create a slot for a variable.
+
+ Args:
+ var: A `Variable` object.
+ val: A `Tensor`. The initial value of the slot.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_slot(
+ var, val, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def create_slot_with_initializer(self, var, initializer, shape, dtype,
+ slot_name, optional_op_name=None):
+ """Find or create a slot for a variable, using an Initializer.
+
+ Args:
+ var: A `Variable` object.
+ initializer: An `Initializer`. The initial value of the slot.
+ shape: Shape of the initial value of the slot.
+ dtype: Type of the value of the slot.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_slot_with_initializer(
+ var, initializer, shape, dtype, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def zeros_slot(self, var, slot_name, optional_op_name=None):
+ """Find or create a slot initialized with 0.0.
+
+ Args:
+ var: A `Variable` object.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_zeros_slot(
+ var, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def _create_or_restore_slot_variable(
+ self, slot_variable_position, slot_name, variable,
+ optional_op_name=None):
+ """Restore a slot variable's value, possibly creating it.
+
+ Called when a variable which has an associated slot variable is created or
+ restored. When executing eagerly, we create the slot variable with a
+ restoring initializer.
+
+ No new variables are created when graph building. Instead,
+ _restore_slot_variable catches these after normal creation and adds restore
+ ops to the graph. This method is nonetheless important when graph building
+ for the case when a slot variable has already been created but `variable`
+ has just been added to a dependency graph (causing us to realize that the
+ slot variable needs to be restored).
+
+ Args:
+ slot_variable_position: A `checkpointable._CheckpointPosition` object
+ indicating the slot variable `Checkpointable` object to be restored.
+ slot_name: The name of this `Optimizer`'s slot to restore into.
+ variable: The variable object this slot is being created for.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+ """
+ slot_variable = self.get_slot(var=variable, name=slot_name)
+ if (slot_variable is None and context.executing_eagerly() and
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
+ initializer = checkpointable.CheckpointInitialValue(
+ checkpoint_position=slot_variable_position)
+ slot_variable = self.create_slot(
+ var=variable,
+ val=initializer,
+ slot_name=slot_name,
+ optional_op_name=optional_op_name)
+ # Optimizers do not have unconditional dependencies on their slot
+ # variables (nor do any other objects). They are only saved if the
+ # variables they were created for are also saved.
+ if slot_variable is not None:
+ # If we've either made this slot variable, or if we've pulled out an
+ # existing slot variable, we should restore it.
+ slot_variable_position.restore(slot_variable)
+ else:
+ # We didn't make the slot variable. Defer restoring until it gets created
+ # normally. We keep a list rather than the one with the highest restore
+ # UID in case slot variables have their own dependencies, in which case
+ # those could differ between restores.
+ variable_key = _var_key_v2(variable)
+ self._deferred_slot_restorations.setdefault(
+ slot_name, {}).setdefault(variable_key, []).append(
+ slot_variable_position)
+
+ def get_slot(self, var, name):
+ """Return a slot named `name` created for `var` by the Optimizer.
+
+ Some `Optimizer` subclasses use additional variables. For example
+ `Momentum` and `Adagrad` use variables to accumulate updates. This method
+ gives access to these `Variable` objects if for some reason you need them.
+
+ Use `get_slot_names()` to get the list of slot names created by the
+ `Optimizer`.
+
+ Args:
+ var: A variable passed to `minimize()` or `apply_gradients()`.
+ name: A string.
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ named_slots = self._slots.get(name, None)
+ if not named_slots:
+ return None
+ return named_slots.get(_var_key_v2(var), None)
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ See `get_slot()`.
+
+ Returns:
+ A list of strings.
+ """
+ return sorted(self._slots.keys())
+
+ def create_non_slot(self, initial_value, name, colocate_with=None):
+ """Add an extra variable, not associated with a slot."""
+ v = self._non_slot_dict.get(name, None)
+ if v is None:
+ if colocate_with is None: colocate_with = self._non_slot_devices
+ with self._distribution.colocate_vars_with(colocate_with):
+ # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
+ v = variable_scope.variable(initial_value, name=name, trainable=False)
+ self._non_slot_dict[name] = v
+ deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
+ for checkpoint_position in sorted(
+ deferred_dependencies_list,
+ key=lambda restore: restore.checkpoint.restore_uid,
+ reverse=True):
+ checkpoint_position.restore(v)
+ return v
+
+ def _restore_slot_variable(self, slot_name, variable, slot_variable):
+ """Restore a newly created slot variable's value."""
+ variable_key = _var_key_v2(variable)
+ deferred_restorations = self._deferred_slot_restorations.get(
+ slot_name, {}).pop(variable_key, [])
+ # Iterate over restores, highest restore UID first to minimize the number
+ # of assignments.
+ deferred_restorations.sort(key=lambda position: position.restore_uid,
+ reverse=True)
+ for checkpoint_position in deferred_restorations:
+ checkpoint_position.restore(slot_variable)
+
+ def get_non_slot(self, name):
+ """Returns the non-slot variable identified by `name`."""
+ return self._non_slot_dict.get(name, None)
+
+ def get_hyper(self, name, dtype=None):
+ """Returns the `name` hyper parameter, optionally cast to `dtype`."""
+ dtype_dict = self._hyper[name]
+ # Do we have the value cast to dtype already cached? This should always
+ # succeed when dtype is None.
+ if dtype in dtype_dict:
+ return dtype_dict[dtype]
+ # Not cached, cast to dtype and save the result in the cache.
+ result = math_ops.cast(dtype_dict[None], dtype)
+ dtype_dict[dtype] = result
+ return result
+
+
+class OptimizerV2(optimizer_v1.Optimizer):
+ """Updated base class for optimizers.
+
+ This class defines the API to add Ops to train a model. You never use this
+ class directly, but instead instantiate one of its subclasses such as
+ `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
+
+ ### Usage
+
+ ```python
+ # Create an optimizer with the desired parameters.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+ # Add Ops to the graph to minimize a cost by updating a list of variables.
+ # "cost" is a Tensor, and the list of variables contains tf.Variable
+ # objects.
+ opt_op = opt.minimize(cost, var_list=<list of variables>)
+ ```
+
+ In the training program you will just have to run the returned Op.
+
+ ```python
+ # Execute opt_op to do one step of training:
+ opt_op.run()
+ ```
+
+ ### Processing gradients before applying them.
+
+ Calling `minimize()` takes care of both computing the gradients and
+ applying them to the variables. If you want to process the gradients
+ before applying them you can instead use the optimizer in three steps:
+
+ 1. Compute the gradients with `compute_gradients()`.
+ 2. Process the gradients as you wish.
+ 3. Apply the processed gradients with `apply_gradients()`.
+
+ Example:
+
+ ```python
+ # Create an optimizer.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Compute the gradients for a list of variables.
+ grads_and_vars = opt.compute_gradients(loss, <list of variables>)
+
+ # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
+ # need to the 'gradient' part, for example cap them, etc.
+ capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
+
+ # Ask the optimizer to apply the capped gradients.
+ opt.apply_gradients(capped_grads_and_vars)
+ ```
+
+ ### Gating Gradients
+
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
+ argument that controls the degree of parallelism during the application of
+ the gradients.
+
+ The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
+
+ <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
+ the maximum parallelism in execution, at the cost of some non-reproducibility
+ in the results. For example the two gradients of `matmul` depend on the input
+ values: With `GATE_NONE` one of the gradients could be applied to one of the
+ inputs _before_ the other gradient is computed resulting in non-reproducible
+ results.
+
+ <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
+ they are used. This prevents race conditions for Ops that generate gradients
+ for multiple inputs where the gradients depend on the inputs.
+
+ <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
+ before any one of them is used. This provides the least parallelism but can
+ be useful if you want to process all gradients before applying any of them.
+
+ ### Slots
+
+ Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
+ allocate and manage additional variables associated with the variables to
+ train. These are called <i>Slots</i>. Slots have names and you can ask the
+ optimizer for the names of the slots that it uses. Once you have a slot name
+ you can ask the optimizer for the variable it created to hold the slot value.
+
+ This can be useful if you want to log debug a training algorithm, report stats
+ about the slots, etc.
+
+ ### Non-slot variables
+
+ Some optimizer subclasses, such as `AdamOptimizer` have variables that
+ are not associated with the variables to train, just the step itself.
+
+ ### Hyper parameters
+
+ These are arguments passed to the optimizer subclass constructor
+ (the `__init__` method), and then passed to `self._set_hyper()`.
+ They can be either regular Python values (like 1.0), tensors, or
+ callables. If they are callable, the callable will be called during
+ `apply_gradients()` to get the value for the hyper parameter.
+
+ ### State
+
+ Internal methods are passed a `state` argument with the correct
+ values to use for the slot and non-slot variables, and the hyper
+ parameters.
+ """
+
+ # Values for gate_gradients.
+ GATE_NONE = 0
+ GATE_OP = 1
+ GATE_GRAPH = 2
+
+ def __init__(self, name):
+ """Create a new Optimizer.
+
+ This must be called by the constructors of subclasses.
+ Note that Optimizer instances should not bind to a single graph,
+ and so shouldn't keep Tensors as member variables. Generally
+ you should be able to use the _set_hyper()/state.get_hyper()
+ facility instead.
+
+ Args:
+ name: A non-empty string. The name to use for accumulators created
+ for the optimizer.
+
+ Raises:
+ ValueError: If name is malformed.
+ RuntimeError: If _create_slots has been overridden instead of
+ _create_vars.
+ """
+ # Note: We intentionally don't call parent __init__.
+
+ # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
+ if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
+ OptimizerV2._create_slots.__code__):
+ raise RuntimeError("Override _create_vars instead of _create_slots when "
+ "descending from OptimizerV2 (class %s)" %
+ self.__class__.__name__)
+ if not name:
+ raise ValueError("Must specify the optimizer name")
+
+ self._use_locking = False
+ self._name = name
+ # Map from graph_key to state for that graph. We use the graph_key
+ # since it works in both eager and graph mode, and gives the outer
+ # graph inside functions.
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context is None:
+ # In a cross-tower context for a DistributionStrategy, which means
+ # only one Optimizer will be created, not one per tower.
+ self._per_graph_state = {}
+ else:
+ # We use get_tower_context().merge_call() to get a single dict
+ # shared across all model replicas when running with a
+ # DistributionStrategy.
+ self._per_graph_state = tower_context.merge_call(lambda _: {})
+
+ # Hyper parameters, and whether they should be re-evaluated every step.
+ self._hyper = {}
+
+ def _set_hyper(self, name, value):
+ self._hyper[name] = (_is_dynamic(value), value)
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=GATE_OP, aggregation_method=None,
+ colocate_gradients_with_ops=False, name=None,
+ grad_loss=None, stop_gradients=None,
+ scale_loss_by_num_towers=None):
+ """Add operations to minimize `loss` by updating `var_list`.
+
+ This method simply combines calls `compute_gradients()` and
+ `apply_gradients()`. If you want to process the gradient before applying
+ them call `compute_gradients()` and `apply_gradients()` explicitly instead
+ of using this function.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+ scale_loss_by_num_towers: Optional boolean. If true, scale the loss
+ down by the number of towers. By default, auto-detects whether this
+ is needed.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
+ Raises:
+ ValueError: If some of the variables are not `Variable` objects.
+
+ @compatibility(eager)
+ When eager execution is enabled, `loss` should be a Python function that
+ takes elements of `var_list` as arguments and computes the value to be
+ minimized. If `var_list` is None, `loss` should take no arguments.
+ Minimization (and gradient computation) is done with respect to the
+ elements of `var_list` if not None, else with respect to any trainable
+ variables created during the execution of the `loss` function.
+ `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+ `grad_loss` are ignored when eager execution is enabled.
+ @end_compatibility
+ """
+ grads_and_vars = self.compute_gradients(
+ loss, var_list=var_list, gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss, stop_gradients=stop_gradients,
+ scale_loss_by_num_towers=scale_loss_by_num_towers)
+
+ vars_with_grad = [v for g, v in grads_and_vars if g is not None]
+ if not vars_with_grad:
+ raise ValueError(
+ "No gradients provided for any variable, check your graph for ops"
+ " that do not support gradients, between variables %s and loss %s." %
+ ([str(v) for _, v in grads_and_vars], loss))
+
+ return self.apply_gradients(grads_and_vars, global_step=global_step,
+ name=name)
+
+ def compute_gradients(self, loss, var_list=None,
+ gate_gradients=GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None, stop_gradients=None,
+ scale_loss_by_num_towers=None):
+ """Compute gradients of `loss` for the variables in `var_list`.
+
+ This is the first part of `minimize()`. It returns a list
+ of (gradient, variable) pairs where "gradient" is the gradient
+ for "variable". Note that "gradient" can be a `Tensor`, an
+ `IndexedSlices`, or `None` if there is no gradient for the
+ given variable.
+
+ Args:
+ loss: A Tensor containing the value to minimize or a callable taking
+ no arguments which returns the value to minimize. When eager execution
+ is enabled it must be a callable.
+ var_list: Optional list or tuple of `tf.Variable` to update to minimize
+ `loss`. Defaults to the list of variables collected in the graph
+ under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+ scale_loss_by_num_towers: Optional boolean. If true, scale the loss
+ down by the number of towers. By default, auto-detects whether this
+ is needed.
+
+ Returns:
+ A list of (gradient, variable) pairs. Variable is always present, but
+ gradient can be `None`.
+
+ Raises:
+ TypeError: If `var_list` contains anything else than `Variable` objects.
+ ValueError: If some arguments are invalid.
+ RuntimeError: If called with eager execution enabled and `loss` is
+ not callable.
+
+ @compatibility(eager)
+ When eager execution is enabled, `gate_gradients`, `aggregation_method`,
+ and `colocate_gradients_with_ops` are ignored.
+ @end_compatibility
+ """
+ # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+ if callable(loss):
+ with backprop.GradientTape() as tape:
+ if var_list is not None:
+ tape.watch(var_list)
+ loss_value = loss()
+
+ # Scale loss for number of towers (callable-loss case). In this case,
+ # we have to be careful to call distribute_lib.get_loss_reduction()
+ # *after* loss() is evaluated, so we know what loss reduction it uses.
+ if scale_loss_by_num_towers is None:
+ scale_loss_by_num_towers = (
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
+ if scale_loss_by_num_towers:
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
+ if num_towers > 1:
+ loss_value *= 1. / num_towers
+
+ if var_list is None:
+ var_list = tape.watched_variables()
+ grads = tape.gradient(loss_value, var_list, grad_loss)
+ return list(zip(grads, var_list))
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "`loss` passed to Optimizer.compute_gradients should "
+ "be a function when eager execution is enabled.")
+
+ # Scale loss for number of towers (non-callable-loss case).
+ if scale_loss_by_num_towers is None:
+ scale_loss_by_num_towers = (
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
+ if scale_loss_by_num_towers:
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
+ if num_towers > 1:
+ loss *= 1. / num_towers
+
+ if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
+ optimizer_v1.Optimizer.GATE_OP,
+ optimizer_v1.Optimizer.GATE_GRAPH]:
+ raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
+ "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
+ gate_gradients)
+ self._assert_valid_dtypes([loss])
+ if grad_loss is not None:
+ self._assert_valid_dtypes([grad_loss])
+ if var_list is None:
+ var_list = (
+ variables.trainable_variables() +
+ ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ else:
+ var_list = nest.flatten(var_list)
+ # pylint: disable=protected-access
+ var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
+ # pylint: enable=protected-access
+ processors = [_get_processor(v) for v in var_list]
+ if not var_list:
+ raise ValueError("No variables to optimize.")
+ var_refs = [p.target() for p in processors]
+ grads = gradients.gradients(
+ loss, var_refs, grad_ys=grad_loss,
+ gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ stop_gradients=stop_gradients)
+ if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
+ grads = control_flow_ops.tuple(grads)
+ grads_and_vars = list(zip(grads, var_list))
+ self._assert_valid_dtypes(
+ [v for g, v in grads_and_vars
+ if g is not None and v.dtype != dtypes.resource])
+ return grads_and_vars
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This is the second part of `minimize()`. It returns an `Operation` that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+
+ Raises:
+ TypeError: If `grads_and_vars` is malformed.
+ ValueError: If none of the variables have gradients.
+ """
+ # This is a default implementation of apply_gradients() that can be shared
+ # by most optimizers. It relies on the subclass implementing the following
+ # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
+
+ # Filter out variables with gradients of `None`.
+ grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
+ if not grads_and_vars:
+ raise ValueError("No variables provided.")
+ filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
+ if not filtered:
+ raise ValueError("No gradients provided for any variable: %s." %
+ ([str(v) for _, v in grads_and_vars],))
+ return distribution_strategy_context.get_tower_context().merge_call(
+ self._distributed_apply, filtered, global_step=global_step, name=name)
+
+ def _get_or_create_state(self, var_list=None):
+ """Either looks up or creates `_OptimizerV2State`.
+
+ If any variables are available, they should be passed via the `var_list`
+ argument, and these will be used to determine the graph to create/retrieve
+ state for. Otherwise the returned state is for the current default graph.
+
+ Args:
+ var_list: A list of variables to extract a graph from.
+
+ Returns:
+ An `_OptimizerV2State` object.
+ """
+ # Determine the graph_key from the current graph.
+ eager_execution = context.executing_eagerly()
+ if eager_execution or var_list is None:
+ graph = ops.get_default_graph()
+ else:
+ graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
+ assert graph is not None
+ graph_key = graph._graph_key # pylint: disable=protected-access
+
+ # Get the per graph state by looking up the graph_key.
+ if graph_key in self._per_graph_state:
+ per_graph_state = self._per_graph_state[graph_key]
+ else:
+ per_graph_state = _OptimizerV2State(self._name)
+ per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
+ self._per_graph_state[graph_key] = per_graph_state
+ return per_graph_state
+
+ def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
+ """`apply_gradients` for use with a `DistributionStrategy`."""
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
+ var_list = [v for _, v in grads_and_vars]
+ grads_and_vars = zip(reduced_grads, var_list)
+
+ unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
+ eager_execution = context.executing_eagerly()
+ if eager_execution:
+ # Give a clear error in this case instead of "name not supported
+ # for Eager Tensors" when we compute non_slot_devices.
+ for v in unwrapped_var_list:
+ if isinstance(v, ops.Tensor):
+ raise NotImplementedError("Trying to update a Tensor ", v)
+
+ with ops.name_scope(name, self._name) as name:
+ per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
+ # Include the current value of any dynamic hyper parameters in `state`.
+ non_slot_devices = distribution.non_slot_devices(var_list)
+ state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
+ self._hyper, distribution, non_slot_devices)
+
+ # Create any slot and non-slot variables we need in `state`.
+ with ops.init_scope():
+ self._create_vars(var_list, state)
+
+ with ops.name_scope(name): # Re-enter name_scope created above
+ # Give the child class a chance to do something before we start
+ # applying gradients.
+ self._prepare(state)
+
+ def update(v, g):
+ """Update variable `v` using gradient `g`."""
+ assert v is not None
+
+ # Convert the grad to Tensor or IndexedSlices if necessary, and
+ # look up a processor for each variable's type.
+ try:
+ g = ops.convert_to_tensor_or_indexed_slices(g)
+ except TypeError:
+ raise TypeError(
+ "Gradient must be convertible to a Tensor"
+ " or IndexedSlices, or None: %s" % g)
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ processor = _get_processor(v)
+
+ # We colocate all ops created in _apply_dense or _apply_sparse
+ # on the same device as the variable.
+ # TODO(apassos): figure out how to get the variable name here.
+ scope_name = "" if eager_execution else v.op.name
+ # device_policy is set because non-mirrored tensors will be read in
+ # `update_op`.
+ # TODO(josh11b): Make different state objects for each device to
+ # avoid needing to set the device_policy.
+ with ops.name_scope("update_" + scope_name), \
+ context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ return processor.update_op(self, g, state)
+
+ # Use the processors to update the variables.
+ update_ops = []
+ for grad, var in grads_and_vars:
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
+
+ # Give the child class a chance to do something after applying
+ # gradients
+ def finish():
+ # TODO(josh11b): Make different state objects for each device to
+ # avoid needing to set the device_policy.
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ return self._finish(state)
+
+ update_ops = control_flow_ops.group(update_ops)
+ with ops.control_dependencies([update_ops]):
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
+
+ # Update `global_step` (if any).
+ if global_step is None:
+ apply_updates = distribution.group(finish_updates, name=name)
+ else:
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(global_step, update_global_step,
+ name)
+
+ # Add the training op to the TRAIN_OP graph collection in graph mode.
+ if not eager_execution:
+ if isinstance(apply_updates, ops.Tensor):
+ apply_updates = apply_updates.op
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ if apply_updates not in train_op:
+ train_op.append(apply_updates)
+
+ return apply_updates
+
+ def get_slot(self, var, name):
+ """Return a slot named `name` created for `var` by the Optimizer.
+
+ Some `Optimizer` subclasses use additional variables. For example
+ `Momentum` and `Adagrad` use variables to accumulate updates. This method
+ gives access to these `Variable` objects if for some reason you need them.
+
+ Use `get_slot_names()` to get the list of slot names created by the
+ `Optimizer`.
+
+ Args:
+ var: A variable passed to `minimize()` or `apply_gradients()`.
+ name: A string.
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ state = self._get_state_for_var(var)
+ return state.get_slot(var, name) if state is not None else None
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ See `get_slot()`.
+
+ Returns:
+ A list of strings.
+ """
+ state = self._get_per_graph_state()
+ return state.get_slot_names() if state is not None else []
+
+ def variables(self):
+ """A list of variables which encode the current state of `Optimizer`.
+
+ Includes slot variables and additional global variables created by the
+ optimizer in the current default graph.
+
+ Returns:
+ A list of variables.
+ """
+ state = self._get_per_graph_state()
+ return state._variables() if state is not None else [] # pylint: disable=protected-access
+
+ # --------------
+ # Methods to be implemented by subclasses if they want to use the
+ # inherited implementation of apply_gradients() or compute_gradients().
+ # --------------
+ def _create_vars(self, var_list, state):
+ """Create all slots needed by the variables and any non-slot variables.
+
+ Args:
+ var_list: A list of `Variable` objects.
+ state: An object with these methods:
+ `create_slot(var, val, slot_name, optional_op_name)`,
+ `create_slot_with_initializer(`
+ `var, initializer, shape, dtype, slot_name, optional_op_name)`,
+ `zeros_slot(var, slot_name, optional_op_name)`,
+ `create_non_slot_variable(initial_value, name, colocate_with)`,
+ `get_hyper(name)`
+ """
+ # No slots needed by default
+ pass
+
+ def _prepare(self, state):
+ """Code to execute before applying gradients.
+
+ Note that most uses of _prepare() in Optimizer have been subsumed
+ by explicit support for hyper parameters in OptimizerV2
+
+ Args:
+ state: An object with a `get_hyper(name)` method.
+
+ Returns:
+ Return value will be ignored.
+ """
+ pass
+
+ def _apply_dense(self, grad, var, state):
+ """Add ops to apply dense gradients to `var`.
+
+ Args:
+ grad: A `Tensor`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_dense(self, grad, handle, state):
+ """Add ops to apply dense gradients to the variable `handle`.
+
+ Args:
+ grad: a `Tensor` representing the gradient.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_sparse_duplicate_indices(
+ self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to `handle`, with repeated indices.
+
+ Optimizers which override this method must deal with repeated indices. See
+ the docstring of `_apply_sparse_duplicate_indices` for details. By default
+ the correct behavior, to sum non-unique indices and their associated
+ gradients, is enforced by first pre-processing `grad` and `indices` and
+ passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
+ with duplicate indices may instead override this method to avoid the
+ overhead of summing.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices may be repeated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ # pylint: disable=protected-access
+ summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
+ values=grad, indices=indices)
+ # pylint: enable=protected-access
+ return self._resource_apply_sparse(
+ summed_grad, handle, unique_indices, state)
+
+ def _resource_apply_sparse(self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to the variable `handle`.
+
+ Similar to `_apply_sparse`, the `indices` argument to this method has been
+ de-duplicated. Optimizers which deal correctly with non-unique indices may
+ instead override `_resource_apply_sparse_duplicate_indices` to avoid this
+ overhead.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices are unique.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
+
+ Optimizers which override this method must deal with IndexedSlices objects
+ such as the following:
+
+ IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
+
+ The correct interpretation is:
+
+ IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
+
+ Many optimizers deal incorrectly with repeated indices when updating based
+ on sparse gradients (e.g. summing squares rather than squaring the sum, or
+ applying momentum terms multiple times). Adding first is always the correct
+ behavior, so this is enforced here by reconstructing the IndexedSlices to
+ have only unique indices, then calling _apply_sparse.
+
+ Optimizers which deal correctly with repeated indices may instead override
+ this method to avoid the overhead of summing indices.
+
+ Args:
+ grad: `IndexedSlices`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ # pylint: disable=protected-access
+ summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
+ values=grad.values, indices=grad.indices)
+ # pylint: enable=protected-access
+ gradient_no_duplicate_indices = ops.IndexedSlices(
+ indices=unique_indices,
+ values=summed_values,
+ dense_shape=grad.dense_shape)
+ return self._apply_sparse(gradient_no_duplicate_indices, var, state)
+
+ def _apply_sparse(self, grad, var, state):
+ """Add ops to apply sparse gradients to `var`.
+
+ The IndexedSlices object passed to `grad` in this function is by default
+ pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
+ indices (see its docstring for details). Optimizers which can tolerate or
+ have correct special cases for duplicate sparse indices may override
+ `_apply_sparse_duplicate_indices` instead of this function, avoiding that
+ overhead.
+
+ Args:
+ grad: `IndexedSlices`, with no repeated indices.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _finish(self, state):
+ """Do what is needed to finish the update.
+
+ This is called inside a scope colocated with any non-slot variables.
+
+ Args:
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ The operation to apply updates, or None if no updates.
+ """
+ return None
+
+ # --------------
+ # Utility methods for subclasses.
+ # --------------
+ def _get_per_graph_state(self):
+ # pylint: disable=protected-access
+ return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
+
+ def _get_state_for_var(self, var):
+ # pylint: disable=protected-access
+ return self._per_graph_state.get(var._graph_key, None)
+
+ # --------------
+ # Overridden methods from Checkpointable.
+ # --------------
+
+ def _track_checkpointable(self, *args, **kwargs):
+ """Optimizers may not track dependencies. Raises an error."""
+ raise NotImplementedError(
+ "Optimizers may not have dependencies. File a feature request if this "
+ "limitation bothers you.")
+
+ @property
+ def _checkpoint_dependencies(self):
+ """From Checkpointable. Gather graph-specific non-slot variables to save."""
+ current_graph_non_slot_variables = []
+ state = self._get_per_graph_state()
+ if state is not None:
+ for name, variable_object in sorted(
+ state._non_slot_dict.items(), # pylint: disable=protected-access
+ # Avoid comparing variables
+ key=lambda item: item[0]):
+ current_graph_non_slot_variables.append(
+ checkpointable.CheckpointableReference(
+ name=name, ref=variable_object))
+ # Note: ignores super(); Optimizers may not have any dependencies outside of
+ # state objects.
+ return current_graph_non_slot_variables
+
+ def _lookup_dependency(self, name):
+ """From Checkpointable. Find a non-slot variable in the current graph."""
+ state = self._get_per_graph_state()
+ if state is None:
+ return None
+ else:
+ return state.get_non_slot(name)
+
+ @property
+ def _deferred_dependencies(self):
+ """Lets Checkpointable know where non-slot variables are created.
+
+ If necessary, creates a new state object for the current default graph.
+ Checkpointable will then add entries to that state's deferred dependency
+ dictionary. The state object will check that dictionary when creating
+ non-slot variables, restoring their value if an entry is found.
+
+ Returns:
+ A dictionary which holds deferred dependencies for the current default
+ graph.
+ """
+ state = self._get_or_create_state()
+ return state._deferred_dependencies # pylint: disable=protected-access
+
+ def _create_or_restore_slot_variable(
+ self, slot_variable_position, slot_name, variable):
+ """Checkpointable: Restore a slot variable's value, possibly creating it.
+
+ Called when a variable which has an associated slot variable is created or
+ restored.
+
+ Args:
+ slot_variable_position: A `checkpointable._CheckpointPosition` object
+ indicating the slot variable `Checkpointable` object to be restored.
+ slot_name: The name of this `Optimizer`'s slot to restore into.
+ variable: The variable object this slot is being created for.
+ """
+ state = self._get_or_create_state(var_list=[variable])
+ state._create_or_restore_slot_variable( # pylint: disable=protected-access
+ slot_variable_position=slot_variable_position,
+ slot_name=slot_name,
+ variable=variable,
+ optional_op_name=self._name)
+
+ # --------------
+ # Unsupported parent methods
+ # --------------
+ def _slot_dict(self, slot_name):
+ raise NotImplementedError(
+ "_slot_dict() method unsupported in OptimizerV2")
+
+ def _get_or_make_slot(self, var, val, slot_name, op_name):
+ raise NotImplementedError(
+ "_get_or_make_slot() method unsupported in OptimizerV2")
+
+ def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
+ slot_name, op_name):
+ raise NotImplementedError(
+ "_get_or_make_slot_with_initializer() method unsupported in "
+ "OptimizerV2")
+
+ def _create_non_slot_variable(self, initial_value, name, colocate_with):
+ raise NotImplementedError(
+ "_create_non_slot_variable() method unsupported in OptimizerV2")
+
+ def _get_non_slot_variable(self, name, graph=None):
+ raise NotImplementedError(
+ "_get_non_slot_variable() method unsupported in OptimizerV2")
+
+ def _non_slot_variables(self):
+ raise NotImplementedError(
+ "_non_slot_variables() method unsupported in OptimizerV2")
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
new file mode 100644
index 0000000000..a6c939393e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -0,0 +1,277 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional test for OptimizerV2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class OptimizerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBasic(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ # Note that for eager execution, minimize expects a function instead of a
+ # Tensor.
+ global_step = resource_variable_ops.ResourceVariable(
+ array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
+ sgd_op = sgd.SGD(3.0)
+
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+ # Run 1 step of sgd through optimizer
+ opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
+ self.evaluate(opt_op)
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+ def testAggregationMethod(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(
+ cost,
+ global_step, [var0, var1],
+ aggregation_method=gradients_impl.AggregationMethod.
+ EXPERIMENTAL_ACCUMULATE_N)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-14., -13.], var0.eval())
+ self.assertAllClose([-6., -5.], var1.eval())
+
+ def testPrecomputedGradient(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ cost = 5 * var0 + 3 * var1
+ grad_loss = constant_op.constant([42, -42], dtype=dtype)
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(
+ cost, global_step, [var0, var1], grad_loss=grad_loss)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
+ var0.eval())
+ self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
+ var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoVariables(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, trainable=False, name='a')
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, trainable=False, name='b')
+ return 5 * var0 + var1
+ # pylint: enable=cell-var-from-loop
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No.*variables'):
+ sgd_op.minimize(loss)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradients(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ return 5 * var0
+ # pylint: enable=cell-var-from-loop
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No gradients'):
+ # var1 has no gradient
+ sgd_op.minimize(loss, var_list=[var1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradientsForAnyVariables_Minimize(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return constant_op.constant(5.0)
+
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.minimize(loss, var_list=[var0, var1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradientsForAnyVariables_ApplyGradients(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.apply_gradients([(None, var0), (None, var1)])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGradientsAsVariables(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
+ # Convert gradients to tf.Variables
+ converted_grads = [
+ resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
+ name='c_%d_%d' % (i, j))
+ for j, gv in enumerate(grads_and_vars)
+ ]
+ convert_ops = [
+ state_ops.assign(converted_grads[j], gv[0])
+ for j, gv in enumerate(grads_and_vars)
+ ]
+
+ self.evaluate(variables.global_variables_initializer())
+ # Run convert_ops to achieve the gradietns converting
+ self.evaluate(convert_ops)
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 1 step of sgd through optimizer
+ converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+ opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
+ self.evaluate(opt_op)
+
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testComputeGradientsWithTensors(self):
+ x = ops.convert_to_tensor(1.0)
+ def f():
+ return x * x
+
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(f, [x])
+ self.assertEqual(1, len(grads_and_vars))
+ grad, x_as_var = grads_and_vars[0]
+ self.assertIs(x, x_as_var)
+ self.assertEqual(2.0, self.evaluate(grad))
+
+ with self.assertRaises(NotImplementedError):
+ sgd_op.apply_gradients(grads_and_vars)
+
+ def testTrainOp(self):
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0])
+ var1 = variables.Variable([3.0, 4.0])
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+ self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))
+
+ def testConstraint(self):
+ constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
+ constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0],
+ constraint=constraint_01)
+ var1 = variables.Variable([3.0, 4.0],
+ constraint=constraint_0)
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-0.1, -0.1], var0.eval())
+ self.assertAllClose([0., 0.], var1.eval())
+
+ def testStopGradients(self):
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], name='var0')
+ var1 = variables.Variable([3.0, 4.0], name='var1')
+ var0_id = array_ops.identity(var0)
+ cost = 5 * var0_id + 3 * var1
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1],
+ stop_gradients=[var0_id])
+ grad_dict = {var.op.name: grad for grad, var in grads_and_vars}
+ self.assertIsNone(grad_dict['var0'])
+ self.assertIsNotNone(grad_dict['var1'])
+
+ def testDoNotOverrideCreateSlots(self):
+ class ShouldNotOverrideCreateSlots(optimizer_v2.OptimizerV2):
+
+ def _create_slots(self, var_list):
+ """In OptimizerV2 _create_slots was renamed _create_vars."""
+ return var_list
+
+ with self.assertRaises(RuntimeError):
+ ShouldNotOverrideCreateSlots('name')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
new file mode 100644
index 0000000000..2748d8eff7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -0,0 +1,239 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RMSprop optimizer for Tensorflow.
+
+rmsprop algorithm [tieleman2012rmsprop]
+
+A detailed description of rmsprop.
+
+- maintain a moving (discounted) average of the square of gradients
+- divide gradient by the root of this average
+
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
+delta = - mom
+
+This implementation of RMSProp uses plain momentum, not Nesterov momentum.
+
+The centered version additionally maintains a moving (discounted) average of the
+gradients, and uses that average to estimate the variance:
+
+mean_grad = rho * mean_square{t-1} + (1-rho) * gradient
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t /
+ sqrt(mean_square - mean_grad**2)
+delta = - mom
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+
+from tensorflow.python.training import training_ops
+
+
+class RMSProp(optimizer_v2.OptimizerV2):
+ """RMSProp optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values (except the learning rate, which can be freely tuned).
+
+ This optimizer is usually a good choice for recurrent neural networks.
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense implementation of this algorithm, variables and their
+ corresponding accumulators (momentum, gradient moving average, square
+ gradient moving average) will be updated even if the gradient is zero
+ (i.e. accumulators will decay, momentum will be applied). The sparse
+ implementation (used when the gradient is an `IndexedSlices` object,
+ typically because of `tf.gather` or an embedding lookup in the forward pass)
+ will not update variable slices or their accumulators unless those slices
+ were used in the forward pass (nor is there an "eventual" correction to
+ account for these omitted updates). This leads to more efficient updates for
+ large embedding lookup tables (where most of the slices are not accessed in
+ a particular graph execution), but differs from the published algorithm.
+
+ Arguments:
+ learning_rate: A float hyperparameter >= 0. The learning rate.
+ rho: A float hyperparameter >= 0. Discounting factor for the
+ history/coming gradient.
+ momentum: A float hyperparameter >= 0.
+ epsilon: A float hyperparameter >= 0 . Small value to initialize the
+ average square gradient variable and avoid zero denominator.
+ centered: If True, gradients are normalized by the estimated variance of
+ the gradient; if False, by the uncentered second moment. Setting this to
+ True may help with training, but is slightly more expensive in terms of
+ computation and memory. Defaults to False.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "RMSProp".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.9,
+ momentum=None,
+ epsilon=1e-10,
+ centered=False,
+ name="RMSProp"):
+ super(RMSProp, self).__init__(name)
+ # Momentum default is `None` for consistency with SGD
+ # but underlying implementation uses `momentum` hyperparameter here
+ # regardless unlike SGD. Since extneral Keras RMSProp does not have
+ # a `momentum` weight, for compatibility with external Keras h5 files,
+ # when `momentum` was set as `None` we should ignore the `momentum`
+ # variable in `get_weights` and not require it in `set_weights`.
+ if momentum is None:
+ momentum = 0.0
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("momentum", momentum)
+ self._set_hyper("epsilon", epsilon)
+
+ self._centered = centered
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ init_rms = state.get_hyper(
+ "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
+ state.create_slot_with_initializer(v, init_rms, v.get_shape(),
+ v.dtype.base_dtype, "rms")
+ if self._centered:
+ state.zeros_slot(v, "mg")
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.apply_centered_rms_prop(
+ var,
+ mg,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ # epsilon is now the rms initial value and is not added to the
+ # denominator anymore, hence calling the kernel op with epsilon=0.
+ 0,
+ grad,
+ use_locking=self._use_locking).op
+ else:
+ return training_ops.apply_rms_prop(
+ var,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.resource_apply_centered_rms_prop(
+ var.handle,
+ mg.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_apply_rms_prop(
+ var.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.sparse_apply_centered_rms_prop(
+ var,
+ mg,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.sparse_apply_rms_prop(
+ var,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = self.get_slot(var, "mg")
+ return training_ops.resource_sparse_apply_centered_rms_prop(
+ var.handle,
+ mg.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ indices,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_sparse_apply_rms_prop(
+ var.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
new file mode 100644
index 0000000000..2c5eccdc5b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -0,0 +1,444 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rmsprop optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+_DATA_TYPES = [dtypes.half, dtypes.float32]
+
+_TEST_PARAM_VALUES = [
+ # learning_rate, rho, momentum, epsilon, centered, use_resource
+ [0.5, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.9, 0.0, 1.0, False, False],
+ [0.5, 0.9, 0.0, 1.0, True, True],
+ [0.5, 0.9, 0.0, 1.0, False, True],
+ [0.1, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.95, 0.0, 1.0, False, False],
+ [0.5, 0.8, 0.0, 1e-3, True, False],
+ [0.5, 0.8, 0.9, 1e-3, True, False],
+]
+
+
+class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
+
+ def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, rho, momentum,
+ centered):
+ rms_t = rms * rho + (1 - rho) * g * g
+ if centered:
+ mg_t = mg * rho + (1 - rho) * g
+ denom_t = rms_t - mg_t * mg_t
+ else:
+ mg_t = mg
+ denom_t = rms_t
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
+
+ def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
+ lr, rho, momentum, centered):
+ mg_t = copy.deepcopy(mg)
+ rms_t = copy.deepcopy(rms)
+ mom_t = copy.deepcopy(mom)
+ var_t = copy.deepcopy(var)
+ for i in range(len(gindexs)):
+ gindex = gindexs[i]
+ gvalue = gvalues[i]
+ rms_t[gindex] = rms[gindex] * rho + (1 - rho) * gvalue * gvalue
+ denom_t = rms_t[gindex]
+ if centered:
+ mg_t[gindex] = mg_t[gindex] * rho + (1 - rho) * gvalue
+ denom_t -= mg_t[gindex] * mg_t[gindex]
+ mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
+ var_t[gindex] = var[gindex] - mom_t[gindex]
+ return var_t, mg_t, rms_t, mom_t
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testDense(self, dtype, param_value):
+ (learning_rate, rho, momentum, epsilon, centered,
+ use_resource) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, rho,
+ momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, rho,
+ momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=0.01, half_atol=0.01)
+ self.assertAllCloseAccordingToType(
+ var1_np, var1.eval(), half_rtol=0.01, half_atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariable(self, dtype):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.0, momentum=0.0, epsilon=0.0,
+ centered=False).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0., 1.]], var0.eval(), atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariableCentered(self, dtype):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.1, momentum=0.0, epsilon=1.0,
+ centered=True).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-7/3.0, -4/3.0]], var0.eval(), atol=0.01)
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testSparse(self, dtype, param_value):
+ (learning_rate, rho, momentum, epsilon, centered, _) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([1]))
+ grads1_np_indices = np.array([1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([1]))
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
+ var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
+ learning_rate, rho, momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
+ var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
+ learning_rate, rho, momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithoutMomentum(self, dtype):
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the rms accumulators where 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithMomentum(self, dtype):
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.5, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: rms = 1, mom = 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the momentum accumulators
+ self.assertAllCloseAccordingToType(
+ np.array([(0.1 * 2.0 / math.sqrt(0.901)),
+ (0.1 * 2.0 / math.sqrt(0.901))]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.01 * 2.0 / math.sqrt(0.90001)),
+ (0.01 * 2.0 / math.sqrt(0.90001))]), mom1.eval())
+
+ # Check that the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)))
+ ]), var0.eval())
+
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
+ ]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/sgd.py b/tensorflow/python/keras/optimizer_v2/sgd.py
new file mode 100644
index 0000000000..f5583691f7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd.py
@@ -0,0 +1,170 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Momentum for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import training_ops
+
+
+class SGD(optimizer_v2.OptimizerV2):
+ """Stochastic gradient descent optimizer.
+
+ Includes support for momentum and Nesterov momentum.
+
+ Computes (if `nesterov = False`):
+
+ ```
+ accumulation = momentum * accumulation + gradient
+ variable -= learning_rate * accumulation
+ ```
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense version of this algorithm, `accumulation` is updated
+ and applied regardless of a gradient's value, whereas the sparse version (when
+ the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
+ embedding) only updates variable slices and corresponding `accumulation` terms
+ when that part of the variable was used in the forward pass.
+
+ @compatibility(eager)
+ When eager execution is enabled, learning_rate and momentum can each be a
+ callable that takes no arguments and returns the actual value to use. This
+ can be useful for changing these values across different invocations of
+ optimizer functions.
+ @end_compatibility
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ momentum: float hyperparameter >= 0 or None. Parameter that accelerates
+ SGD in the relevant direction and dampens oscillations.
+ nesterov: boolean. Whether to apply Nesterov momentum. See [Sutskever et
+ al., 2013](http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). This
+ implementation always computes gradients at the value of the
+ variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+ variable(s) track the values called `theta_t + mu*v_t` in the paper.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'SGD'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ momentum=None,
+ nesterov=False,
+ name="SGD"):
+ super(SGD, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ # Only create momentum variables and use momentum ops if needed.
+ if momentum is not None:
+ self._set_hyper("momentum", momentum)
+ self._use_nesterov = nesterov
+ self._use_momentum = True
+ else:
+ self._use_momentum = False
+
+ def _create_vars(self, var_list, state):
+ if self._use_momentum:
+ for v in var_list:
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return training_ops.apply_gradient_descent(
+ var,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return training_ops.resource_apply_gradient_descent(
+ var.handle, lr, grad, use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return super(SGD, self)._apply_sparse(grad, var, state)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ return super(SGD, self)._resource_apply_sparse(grad, var, indices, state)
+
+ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices, state):
+ if self._use_momentum:
+ return super(SGD, self)._resource_apply_sparse_duplicate_indices(
+ grad, var, indices, state)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return resource_variable_ops.resource_scatter_add(var.handle, indices,
+ -grad * lr)
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ if self._use_momentum:
+ return super(SGD, self)._apply_sparse_duplicate_indices(grad, var, state)
+ else:
+ delta = ops.IndexedSlices(
+ grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.indices, grad.dense_shape)
+ return var.scatter_sub(delta, use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/sgd_test.py b/tensorflow/python/keras/optimizer_v2/sgd_test.py
new file mode 100644
index 0000000000..eb39aac283
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd_test.py
@@ -0,0 +1,759 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Momentum."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class GradientDescentOptimizerTest(test.TestCase):
+
+ def testBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ optimizer = sgd.SGD(3.0)
+ sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertEqual(0, len(optimizer.variables()))
+
+ def testBasicResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testMinimizeResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(var0, x) + var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ pred += var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ lrate = constant_op.constant(3.0)
+ sgd_op = sgd.SGD(lrate).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testGradWrtRef(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ opt = sgd.SGD(3.0)
+ values = [1.0, 3.0]
+ vars_ = [variables.Variable([v], dtype=dtype) for v in values]
+ grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
+ variables.global_variables_initializer().run()
+ for grad, _ in grads_and_vars:
+ self.assertAllCloseAccordingToType([1.0], grad.eval())
+
+ def testWithGlobalStep(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ global_step = variables.Variable(0, trainable=False)
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params and global_step
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertAllCloseAccordingToType(1, global_step.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
+ var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
+ var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
+
+
+class MomentumOptimizerTest(test.TestCase):
+
+ def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
+ var = var + accum * lr * momentum
+ accum = accum * momentum + g
+ var = var - lr * accum
+ var = var - accum * lr * momentum
+ return var, accum
+
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ learning_rate = lambda: 2.0
+ momentum = lambda: 0.9
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ momentum = momentum()
+ mom_opt = sgd.SGD(learning_rate=learning_rate, momentum=momentum)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ if not context.executing_eagerly():
+ self.assertFalse(slot0 in variables.trainable_variables())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ if not context.executing_eagerly():
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
+ # Step 2: the momentum accumulators contain the previous update.
+ if context.executing_eagerly():
+ mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ else:
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testVariablesAcrossGraphs(self):
+ optimizer = sgd.SGD(0.01, 0.5)
+ with ops.Graph().as_default():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var0")
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var1")
+ loss = math_ops.reduce_sum(var0 + var1)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var0")
+ self.assertStartsWith(optimizer_variables[1].name, "var1")
+ self.assertEquals(2, len(optimizer_variables))
+
+ with ops.Graph().as_default():
+ var2 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var2")
+ var3 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var3")
+ loss = math_ops.reduce_sum(var2 + var3)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var2")
+ self.assertStartsWith(optimizer_variables[1].name, "var3")
+ self.assertEquals(2, len(optimizer_variables))
+
+ def testNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ cost = 5 * var0 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name="global_step")
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ opt_op = mom_op.minimize(cost, global_step, [var0, var1])
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_op.run()
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ def testSparseNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ grads = []
+ for t in range(1, 5):
+ grads.append(var0_np * 10)
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ loss = 5 * var0 * var0 + 3 * var1
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ x_feed = array_ops.placeholder(dtype)
+ y_feed = ops.IndexedSlices(
+ x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
+ grads_and_vars = [(y_feed, var0), (constant_op.constant(
+ [3.0, 3.0], dtype=dtype), var1)]
+ opt_update = mom_op.apply_gradients(grads_and_vars)
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_update.run(feed_dict={x_feed: grads[t - 1]})
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ return pred * pred
+ # pylint: enable=cell-var-from-loop
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ # Run 1 step of sgd
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+ def loss():
+ return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(sgd_op)
+ self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
+
+ def testTensorLearningRateAndMomentum(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(
+ learning_rate=constant_op.constant(2.0),
+ momentum=constant_op.constant(0.9))
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in variables.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+ def _dbParamsMom01(self):
+ """Return dist-belief momentum values.
+
+ Return values been generated from the dist-belief momentum unittest,
+ running with a learning rate of 0.1 and a momentum of 0.1.
+
+ These values record how a parameter vector of size 10, initialized with 0.0,
+ gets updated with 10 consecutive momentum steps. It uses random gradients.
+
+ Returns:
+ db_grad: The gradients to apply
+ db_out: The parameters after the momentum update.
+ """
+ db_grad = [[]] * 10
+ db_out = [[]] * 10
+ # pylint: disable=line-too-long
+ db_grad[0] = [
+ 0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018,
+ 0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615
+ ]
+ db_out[0] = [
+ -9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018,
+ -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618
+ ]
+ db_grad[1] = [
+ 0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378,
+ 0.5513742, 0.94687688, 0.16012503, 0.22159521
+ ]
+ db_out[1] = [
+ -0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884,
+ -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544
+ ]
+ db_grad[2] = [
+ 0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965,
+ 0.31168157, 0.43203235, 0.16792089, 0.24644311
+ ]
+ db_out[2] = [
+ -0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978,
+ -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189
+ ]
+ db_grad[3] = [
+ 0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098,
+ 0.81454384, 0.03848977, 0.89759839, 0.93665648
+ ]
+ db_out[3] = [
+ -0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105,
+ -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303
+ ]
+ db_grad[4] = [
+ 0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359,
+ 0.69107032, 0.81897682, 0.5433259, 0.67860287
+ ]
+ db_out[4] = [
+ -0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165,
+ -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544
+ ]
+ db_grad[5] = [
+ 0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563,
+ 0.84163809, 0.41172323, 0.83259648, 0.44941229
+ ]
+ db_out[5] = [
+ -0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094,
+ -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717
+ ]
+ db_grad[6] = [
+ 0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221,
+ 0.73577434, 0.16014607, 0.57500273, 0.071136251
+ ]
+ db_out[6] = [
+ -0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685,
+ -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997
+ ]
+ db_grad[7] = [
+ 0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646,
+ 0.74053431, 0.16033, 0.66625422, 0.73515922
+ ]
+ db_out[7] = [
+ -0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838,
+ -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418
+ ]
+ db_grad[8] = [
+ 0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039,
+ 0.55561525, 0.22567581, 0.93331909, 0.29438227
+ ]
+ db_out[8] = [
+ -0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527,
+ -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781
+ ]
+ db_grad[9] = [
+ 0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893,
+ 0.68593478, 0.50580865, 0.12602448, 0.093537711
+ ]
+ db_out[9] = [
+ -0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302,
+ -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295
+ ]
+ # pylint: enable=line-too-long
+ return db_grad, db_out
+
+ def testLikeDistBeliefMom01(self):
+ with self.cached_session():
+ db_grad, db_out = self._dbParamsMom01()
+ num_samples = len(db_grad)
+ var0 = variables.Variable([0.0] * num_samples)
+ grads0 = constant_op.constant([0.0] * num_samples)
+ mom_opt = sgd.SGD(learning_rate=0.1, momentum=0.1)
+ mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
+ variables.global_variables_initializer().run()
+ for i in xrange(num_samples):
+ mom_update.run(feed_dict={grads0: db_grad[i]})
+ self.assertAllClose(np.array(db_out[i]), var0.eval())
+
+ def testSparse(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
+ var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.1, .1]], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([4, 2]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.01, .01], [.01, .01]], dtype=dtype),
+ constant_op.constant([2, 3]),
+ constant_op.constant([4, 2]))
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([0, 0], var0.eval()[0])
+ self.assertAllClose([0, 0], var0.eval()[1])
+ self.assertAllClose([1, 1], var1.eval()[2])
+
+ # Step 1: the momentum accumulators are 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(np.array([.1, .1]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([.01, .01]), slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([-(0.1 * 2.0), -(0.1 * 2.0)]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.01 * 2.0), 1.0 - (0.01 * 2.0)]), var1.eval()[2])
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ -(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), -(0.1 * 2.0) - (
+ (0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.98 - ((0.9 * 0.01 + 0.01) * 2.0), 0.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval()[2])
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update1 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ mom_update2 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update1.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the second momentum accumulators contain the previous update.
+ mom_update2.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 501b50ba5f..2fae094a1e 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -166,8 +166,9 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
if expected_dim is not None:
if expected_dim != actual_dim:
raise AssertionError(
- 'When testing layer %s, for input %s, found output_shape='
- '%s but expected to find %s.\nFull kwargs: %s' %
+ 'When testing layer %s **after deserialization**, '
+ 'for input %s, found output_shape='
+ '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
(layer_cls.__name__,
x,
actual_output_shape,
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 8ebca1418d..f486e631e5 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -137,26 +137,49 @@ def conv_input_length(output_length, filter_size, padding, stride):
return (output_length - 1) * stride - 2 * pad + filter_size
-def deconv_output_length(input_length, filter_size, padding, stride):
+def deconv_output_length(input_length, filter_size, padding,
+ output_padding=None, stride=0, dilation=1):
"""Determines output length of a transposed convolution given input length.
Arguments:
- input_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
+ input_length: Integer.
+ filter_size: Integer.
+ padding: one of `"same"`, `"valid"`, `"full"`.
+ output_padding: Integer, amount of padding along the output dimension.
+ Can be set to `None` in which case the output length is inferred.
+ stride: Integer.
+ dilation: Integer.
Returns:
The output length (integer).
"""
+ assert padding in {'same', 'valid', 'full'}
if input_length is None:
return None
- input_length *= stride
- if padding == 'valid':
- input_length += max(filter_size - stride, 0)
- elif padding == 'full':
- input_length -= (stride + filter_size - 2)
- return input_length
+
+ # Get the dilated kernel size
+ filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+
+ # Infer length if output padding is None, else compute the exact length
+ if output_padding is None:
+ if padding == 'valid':
+ length = input_length * stride + max(filter_size - stride, 0)
+ elif padding == 'full':
+ length = input_length * stride - (stride + filter_size - 2)
+ elif padding == 'same':
+ length = input_length * stride
+
+ else:
+ if padding == 'same':
+ pad = filter_size // 2
+ elif padding == 'valid':
+ pad = 0
+ elif padding == 'full':
+ pad = filter_size - 1
+
+ length = ((input_length - 1) * stride + filter_size - 2 * pad +
+ output_padding)
+ return length
def normalize_data_format(value):
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e1c49bc852..04b2ea8fe3 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -244,9 +244,24 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
+ # Deduplicate output names to handle Siamese networks.
+ occurrences = {}
+ for n in model.output_names:
+ if n not in occurrences:
+ occurrences[n] = 1
+ else:
+ occurrences[n] += 1
+ conflict_counter = {n: 0 for n, count in occurrences.items() if count > 1}
+ output_names = []
+ for n in model.output_names:
+ if n in conflict_counter:
+ conflict_counter[n] += 1
+ n += '_%d' % conflict_counter[n]
+ output_names.append(n)
+
# Merge outputs under expected scope.
with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
merged = []
- for name, outputs in zip(model.output_names, all_outputs):
+ for name, outputs in zip(output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 3d0351a11f..1780ab6587 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -198,5 +198,31 @@ class TestMultiGPUModel(test.TestCase):
parallel_model.compile(loss='mean_squared_error', optimizer='adam')
parallel_model.train_on_batch(x, y)
+ def test_multi_gpu_with_siamese_network(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.cached_session():
+ input_shape = (3,)
+ nested_model = keras.models.Sequential([
+ keras.layers.Dense(32, input_shape=input_shape),
+ keras.layers.Dense(1)
+ ], name='nested')
+
+ input1 = keras.Input(input_shape)
+ input2 = keras.Input(input_shape)
+ score1 = nested_model(input1)
+ score2 = nested_model(input2)
+ score_sum = keras.layers.Add(name='add')([score1, score2])
+
+ siamese = keras.models.Model(inputs=[input1, input2],
+ outputs=[score_sum, score1, score2],
+ name='siamese')
+ parallel_siamese = keras.utils.multi_gpu_model(siamese, gpus)
+ self.assertEqual(parallel_siamese.output_names,
+ ['add', 'nested_1', 'nested_2'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
index c24e87308b..3763999bff 100644
--- a/tensorflow/python/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -22,7 +22,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.utils.to_categorical')
-def to_categorical(y, num_classes=None):
+def to_categorical(y, num_classes=None, dtype='float32'):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
@@ -31,6 +31,7 @@ def to_categorical(y, num_classes=None):
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes.
+ dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
@@ -44,7 +45,7 @@ def to_categorical(y, num_classes=None):
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
- categorical = np.zeros((n, num_classes), dtype=np.float32)
+ categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 9303c70c60..4e8639dfc8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -76,6 +76,7 @@ tf_py_test(
name = "batch_gather_op_test",
srcs = ["batch_gather_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -3254,7 +3255,7 @@ tf_py_test(
tags = ["no_pip"],
)
-tf_py_test(
+cuda_py_test(
name = "cond_v2_test",
size = "medium",
srcs = ["cond_v2_test.py"],
@@ -3271,7 +3272,6 @@ tf_py_test(
"//tensorflow/python:training",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/111656070)
)
cuda_py_test(
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
index 7dd347989a..84e93b8136 100644
--- a/tensorflow/python/kernel_tests/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import constant_op
@@ -29,7 +30,7 @@ _TEST_TYPES = (dtypes.int64, dtypes.float32,
dtypes.complex64, dtypes.complex128)
-class GatherTest(test.TestCase):
+class GatherTest(test.TestCase, parameterized.TestCase):
def _buildParams(self, data, dtype):
data = data.astype(dtype.as_numpy_dtype)
@@ -39,14 +40,15 @@ class GatherTest(test.TestCase):
return data + 10j * data
return data
- def testSimpleGather(self):
+ @parameterized.parameters(dtypes.int32, dtypes.int64)
+ def testSimpleGather(self, indices_dtype):
data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
indices = [3, 4]
with self.test_session(use_gpu=True):
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
- indices_tf = constant_op.constant(indices)
+ indices_tf = constant_op.constant(indices, dtype=indices_dtype)
gather_t = array_ops.batch_gather(params, indices_tf)
expected_result = np.array([3, 7])
np_val = self._buildParams(expected_result, dtype)
@@ -54,14 +56,15 @@ class GatherTest(test.TestCase):
self.assertAllEqual(np_val, gather_val)
self.assertEqual(np_val.shape, gather_t.get_shape())
- def test2DArray(self):
+ @parameterized.parameters(dtypes.int32, dtypes.int64)
+ def test2DArray(self, indices_dtype):
data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
indices = [[3], [4]]
with self.test_session(use_gpu=True):
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
- indices_tf = constant_op.constant(indices)
+ indices_tf = constant_op.constant(indices, dtype=indices_dtype)
gather_t = array_ops.batch_gather(params, indices_tf)
expected_result = np.array([[3], [15]])
np_val = self._buildParams(expected_result, dtype)
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
index 78b6e38d94..5777a5d097 100644
--- a/tensorflow/python/kernel_tests/benchmark_test.py
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -64,7 +64,7 @@ class TestReportingBenchmark(test.Benchmark):
"other_key": "string"})
def benchmark_times_an_op(self):
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
a = constant_op.constant(0.0)
a_plus_a = a + a
return self.run_op_benchmark(
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 8a58b3f97e..8177cdd454 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
+ def test_shape_function(self):
+ # size must be scalar.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
+ # size must be positive.
+ with self.assertRaisesRegexp(ValueError, "must be non-negative"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
+ # if size is a constant then the shape is known.
+ v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
+ self.assertAllEqual(v1.get_shape().as_list(), [5])
+ # if size is a placeholder then the shape is unknown.
+ s = array_ops.placeholder(dtype=dtypes.int32)
+ v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
+ self.assertAllEqual(v2.get_shape().as_list(), [None])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index 782e6b5068..2ebf74a4d7 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@@ -327,7 +328,7 @@ class CholeskyBenchmark(test.Benchmark):
def benchmarkCholeskyOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
l = linalg_ops.cholesky(matrix)
@@ -341,7 +342,7 @@ class CholeskyBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/device:GPU:0"):
matrix = variables.Variable(self._GenerateMatrix(shape))
l = linalg_ops.cholesky(matrix)
@@ -359,7 +360,7 @@ class CholeskyBenchmark(test.Benchmark):
for shape in self.shapes:
matrix = self._GenerateMatrix(shape)
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device(device):
l = variables.Variable(np.linalg.cholesky(matrix))
grad_matrix = variables.Variable(
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 377c041675..a424a0f219 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -153,6 +153,7 @@ class CondV2Test(test.TestCase):
self.assertIn("foo_cond_1_false", ops.get_default_graph()._functions)
def testDefunInCond(self):
+ self.skipTest("b/117293122")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -172,7 +173,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -198,7 +199,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -468,7 +469,6 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
- self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -502,29 +502,29 @@ class CondV2Test(test.TestCase):
return grads, pred_outer, pred_inner
- with ops.Graph().as_default():
+ with ops.Graph().as_default(), self.session(
+ graph=ops.get_default_graph()) as sess:
grads, pred_outer, pred_inner = build_graph()
- with self.session(graph=ops.get_default_graph()) as sess:
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: True
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: False
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: True
- }), [4., 2.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: False
- }), [5., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
def testSecondDerivative(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index a1be77601c..baea5c0f6d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -661,10 +660,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
- @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- graph = ops.Graph()
- with graph.as_default():
+ with self.cached_session():
x = constant_op.constant(10.0, name="x")
pred = math_ops.less(1, 2)
fn1 = lambda: array_ops.identity(x)
@@ -672,8 +669,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
grad = gradients_impl.gradients(r, [x])[0]
- with self.cached_session():
- self.assertAllEqual(1.0, grad.eval())
+ self.assertAllEqual(1.0, grad.eval())
def testCondGrad_2(self):
with self.cached_session():
@@ -3422,10 +3418,8 @@ class EagerTest(test.TestCase):
self.assertAllEqual(r.numpy(), 10)
self.assertFalse(isinstance(r, list))
- def testCondInDefun(self):
- if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
- return unittest.skip("b/113346829 (gpu failure)")
-
+ # TODO(b/117279927): Re-enable once msan failure is fixed.
+ def DISABLED_testCondInDefun(self):
with context.eager_mode():
@eager_function.defun
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 6aee2eb0a3..737a73f97a 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -131,7 +131,7 @@ class DepthwiseConv2DTest(test.TestCase):
with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-2,
- dtypes.float32: 1e-6,
+ dtypes.float32: 1e-5,
dtypes.float64: 1e-12,
}[data_type]
diff --git a/tensorflow/python/kernel_tests/determinant_op_test.py b/tensorflow/python/kernel_tests/determinant_op_test.py
index a52b2c0dc3..fb114f9f24 100644
--- a/tensorflow/python/kernel_tests/determinant_op_test.py
+++ b/tensorflow/python/kernel_tests/determinant_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -185,8 +186,8 @@ class MatrixDeterminantBenchmark(test.Benchmark):
def benchmarkMatrixDeterminantOp(self):
for shape in self.shapes:
- with ops.Graph().as_default(), session.Session() as sess, ops.device(
- "/cpu:0"):
+ with ops.Graph().as_default(), session.Session(
+ config=benchmark.benchmark_config()) as sess, ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
d = linalg_ops.matrix_determinant(matrix)
variables.global_variables_initializer().run()
@@ -198,8 +199,8 @@ class MatrixDeterminantBenchmark(test.Benchmark):
name="matrix_determinant_cpu_{shape}".format(shape=shape))
if test.is_gpu_available(True):
- with ops.Graph().as_default(), session.Session() as sess, ops.device(
- "/gpu:0"):
+ with ops.Graph().as_default(), session.Session(
+ config=benchmark.benchmark_config()) as sess, ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
d = linalg_ops.matrix_determinant(matrix)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 4beddd00bb..2f19ecc0e6 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -306,6 +306,19 @@ class PrintV2Test(test.TestCase):
logging_ops.print_v2(tensor)
self.assertTrue((expected + "\n") in printed.contents())
+ def testPrintsOrderedInDefun(self):
+ with context.eager_mode():
+
+ @function.defun
+ def prints():
+ logging_ops.print_v2("A")
+ logging_ops.print_v2("B")
+ logging_ops.print_v2("C")
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ prints()
+ self.assertTrue(("A\nB\nC\n") in printed.contents())
+
@test_util.run_in_graph_and_eager_modes()
def testPrintInDefunWithoutExplicitEvalOfPrint(self):
@function.defun
diff --git a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
index 68d626de2c..a0ef3a607e 100644
--- a/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_band_part_op_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib
@@ -109,7 +110,7 @@ class MatrixBandPartBenchmark(test_lib.Benchmark):
for shape_ in self.shapes:
for limits in (-1, -1), (-1, 0), (0, -1), (2, 2):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = variables.Variable(array_ops.ones(shape_))
band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
@@ -123,7 +124,7 @@ class MatrixBandPartBenchmark(test_lib.Benchmark):
if test_lib.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = variables.Variable(array_ops.ones(shape_))
band = array_ops.matrix_band_part(matrix, limits[0], limits[1])
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
index 0386e91276..9630c052b8 100644
--- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -181,7 +182,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
def benchmarkMatrixExponentialOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
expm = linalg_impl.matrix_exponential(matrix)
@@ -195,7 +196,7 @@ class MatrixExponentialBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
expm = linalg_impl.matrix_exponential(matrix)
diff --git a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
index 720ba806e9..8bda04b53d 100644
--- a/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_inverse_op_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -179,7 +180,7 @@ class MatrixInverseBenchmark(test.Benchmark):
for adjoint in False, True:
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint)
@@ -193,7 +194,7 @@ class MatrixInverseBenchmark(test.Benchmark):
if test.is_gpu_available(True):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix = self._GenerateMatrix(shape)
inv = linalg_ops.matrix_inverse(matrix, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
index 723a15fbd1..3205e211d9 100644
--- a/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_logarithm_op_test.py
@@ -31,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg_impl
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -159,7 +160,7 @@ class MatrixLogarithmBenchmark(test.Benchmark):
def benchmarkMatrixLogarithmOp(self):
for shape in self.shapes:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix = self._GenerateMatrix(shape)
logm = gen_linalg_ops.matrix_logarithm(matrix)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
index de495968a7..225a10e117 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_ls_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test as test_lib
@@ -313,7 +314,7 @@ class MatrixSolveLsBenchmark(test_lib.Benchmark):
for num_rhs in 1, 2, matrix_shape[-1]:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
@@ -328,7 +329,7 @@ class MatrixSolveLsBenchmark(test_lib.Benchmark):
if run_gpu_test and (len(matrix_shape) < 3 or matrix_shape[0] < 513):
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix, rhs = _GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve_ls(matrix, rhs, regularizer)
diff --git a/tensorflow/python/kernel_tests/matrix_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
index b8f2736b7b..264df2565c 100644
--- a/tensorflow/python/kernel_tests/matrix_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_solve_op_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -167,7 +168,7 @@ class MatrixSolveBenchmark(test.Benchmark):
for num_rhs in 1, 2, matrix_shape[-1]:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/cpu:0"):
matrix, rhs = self._GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
@@ -185,7 +186,7 @@ class MatrixSolveBenchmark(test.Benchmark):
if run_gpu_test:
with ops.Graph().as_default(), \
- session.Session() as sess, \
+ session.Session(config=benchmark.benchmark_config()) as sess, \
ops.device("/gpu:0"):
matrix, rhs = self._GenerateTestData(matrix_shape, num_rhs)
x = linalg_ops.matrix_solve(matrix, rhs, adjoint=adjoint)
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index a45a325b47..672d6556f5 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.compat import compat
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -282,6 +283,125 @@ class Relu6Test(test.TestCase):
self.assertLess(err, 1e-10)
+class LeakyReluTest(test.TestCase):
+
+ def _npLeakyRelu(self, np_features, alpha=0.1):
+ return np.maximum(np_features, alpha * np_features)
+
+ def testNpLeakyRelu(self):
+ self.assertAllClose(
+ np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
+ [0.1, -0.03, 0.5, -0.07, 0.9]]),
+ self._npLeakyRelu(
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]]),
+ alpha=0.1))
+
+ def _testLeakyRelu(self, np_features, alpha, use_gpu=False):
+ np_leaky_relu = self._npLeakyRelu(np_features, alpha)
+ with self.test_session(use_gpu=use_gpu):
+ leaky_relu = nn_ops.leaky_relu(np_features, alpha)
+ tf_leaky_relu = leaky_relu.eval()
+ self.assertAllClose(np_leaky_relu, tf_leaky_relu)
+ self.assertShapeEqual(np_leaky_relu, leaky_relu)
+
+ def testNumbers(self):
+ for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.2,
+ use_gpu=False)
+ if t in [np.float16, np.float32, np.float64]:
+ self._testLeakyRelu(
+ np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
+ alpha=0.1,
+ use_gpu=True)
+
+ # The gradient test for Leaky ReLU is a bit tricky as the derivative is not
+ # well defined at around zero and we want to avoid that in terms of input
+ # values.
+ def testGradientFloat32(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradientFloat64(self):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.2, name="leaky_relu")
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], y, [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradGradFloat32(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.1, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float32,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float32) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-4)
+
+ def testGradGradFloat64(self):
+ with compat.forward_compatibility_horizon(2018, 11, 2):
+ with self.test_session():
+ x = constant_op.constant(
+ [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9],
+ shape=[2, 5],
+ dtype=dtypes.float64,
+ name="x")
+ y = nn_ops.leaky_relu(x, alpha=0.02, name="leaky_relu")
+ z = gradients_impl.gradients(y, x)
+ x_init = np.asarray(
+ [[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
+ dtype=np.float64,
+ order="F")
+ err = gradient_checker.compute_gradient_error(
+ x, [2, 5], z[0], [2, 5], x_init_value=x_init)
+ print("leaky_relu (float64) gradient of gradient err = ", err)
+ self.assertLess(err, 1e-10)
+
+ def testGradientScalar(self):
+ with self.test_session() as sess:
+ x = variables.Variable(-100.)
+ y = nn_ops.leaky_relu(x, 0.05)
+ loss = y**2
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.2)
+ train_op = optimizer.minimize(loss)
+ sess.run(variables.global_variables_initializer())
+ sess.run(train_op)
+ self.assertAllClose(x.eval(), -99.9)
+
+
class EluTest(test.TestCase):
def _npElu(self, np_features):
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index 1365d4b240..a9fd93e9f8 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -142,7 +142,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(1.0)
ops.reset_default_graph()
v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the
- # variable graph.
+ # variable graph.
def testFetchHandle(self):
with self.cached_session():
@@ -908,6 +908,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(Exception, r"shape.*2.*3"):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
+ @test_util.run_in_graph_and_eager_modes
+ def testAssignIncompatibleShape(self):
+ v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
+ self.evaluate(v.initializer)
+ with self.assertRaisesRegexp(Exception, r"hapes must be equal"):
+ self.assertAllEqual(self.evaluate(v.assign_add(1)), [1, 2, 3, 4])
+
class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
index 31e84341ae..fdfe1001b8 100644
--- a/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_tensors_map_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
# pylint: disable=protected-access
@@ -192,7 +193,7 @@ class BenchmarkSparseTensorsMapVsSerialization(test.Benchmark):
sorted(zip(indices_batch, indices_value)), dtype=np.int64)
values = ["feature_value_for_embedding_lookup"] * num_elements
shape = np.asarray([batch_size, num_elements], dtype=np.int64)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
with ops.device("/cpu:0"):
indices = variables.Variable(indices)
values = variables.Variable(values)
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index cd3fe14883..37aa624b07 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -28,270 +28,448 @@ from tensorflow.python.platform import test
class SubstrOpTest(test.TestCase, parameterized.TestCase):
- def _testScalarString(self, dtype):
- test_string = b"Hello"
- position = np.array(1, dtype)
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testScalarString(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_value = {
+ "BYTE": b"ell",
+ "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position.
- test_string = b"Hello"
- position = np.array(-4, dtype)
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testScalarString_EdgeCases(self, dtype, unit):
+ # Empty string
+ test_string = {
+ "BYTE": b"",
+ "UTF8_CHAR": u"".encode("utf-8"),
+ }[unit]
+ expected_value = b""
+ position = np.array(0, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Position is equal to the length of string.
- test_string = b""
+ # Full string
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
position = np.array(0, dtype)
- length = np.array(2, dtype)
- expected_value = b""
-
- substr_op = string_ops.substr(test_string, position, length)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position magnitude is equal to the length of string.
- test_string = b"yo"
- position = np.array(-2, dtype)
- length = np.array(1, dtype)
- expected_value = b"y"
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Full string (Negative)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-5, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- def _testVectorStrings(self, dtype):
- test_string = [b"Hello", b"World"]
- position = np.array(1, dtype)
- length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Length is larger in magnitude than a negative position
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_string = {
+ "BYTE": b"ello",
+ "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-4, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position.
- test_string = [b"Hello", b"World"]
- position = np.array(-4, dtype)
+ self.assertAllEqual(substr, expected_string)
+
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testVectorStrings(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": [b"Hello", b"World"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo",
+ u"W\U0001f604rld"]],
+ }[unit]
+ expected_value = {
+ "BYTE": [b"ell", b"orl"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testMatrixStrings(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMatrixStrings(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]]],
+ }[unit]
position = np.array(1, dtype)
length = np.array(4, dtype)
- expected_value = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
- [b"ixte", b"even", b"ight"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
+ [b"ixte", b"even", b"ight"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u053c\u025bv\u025b",
+ u"w\u0c1dlv"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"\U0001f604rld",
+ u"\xfcd\xea"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array(-2, dtype)
+ position = np.array(-3, dtype)
length = np.array(2, dtype)
- expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
- [b"en", b"en", b"en"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"],
+ [b"ee", b"ee", b"ee"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227",
+ u"v\u025b", u"lv"]],
+ [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl",
+ u"\xfcd"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testElementWisePosLen(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testElementWisePosLen(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]],
+ [x.encode("utf-8") for x in [u"sixt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
- expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
- [b"xteen", b"vente", b"hteen"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+ [b"xteen", b"vente", b"hteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u025bv",
+ u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"rld",
+ u"d\xfc"]],
+ [x.encode("utf-8") for x in [u"xt\xea\xean",
+ u"\U00010299ente",
+ u"h\x86een"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBroadcast(self, dtype, unit):
# Broadcast pos/len onto input string
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"],
- [b"nineteen", b"twenty", b"twentyone"]]
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"],
+ [b"nineteen", b"twenty", b"twentyone"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]],
+ [x.encode("utf-8") for x in [u"nineteen",
+ u"twenty",
+ u"twentyone"]]],
+ }[unit]
position = np.array([1, -4, 3], dtype)
length = np.array([1, 2, 3], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
- [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+ [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227",
+ u"\u025bv", u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]],
+ [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]],
+ [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Broadcast input string onto pos/len
- test_string = [b"thirteen", b"fourteen", b"fifteen"]
+ test_string = {
+ "BYTE": [b"thirteen", b"fourteen", b"fifteen"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ }[unit]
position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
- [b"ee", b"ee", b"ft"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+ [b"ee", b"ee", b"ft"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]],
+ [x.encode("utf-8") for x in [u"\xea", u"ur",
+ u"\xcd\ua09ct"]],
+ [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea",
+ u"\ua09ct"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Test 1D broadcast
- test_string = b"thirteen"
- position = np.array([1, -5, 7], dtype)
+ test_string = {
+ "BYTE": b"thirteen",
+ "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"),
+ }[unit]
+ position = np.array([1, -4, 7], dtype)
length = np.array([3, 2, 1], dtype)
- expected_value = [b"hir", b"rt", b"n"]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [b"hir", b"te", b"n"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBadBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBadBroadcast(self, dtype, unit):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
position = np.array([1, 2, -3, 4], dtype)
length = np.array([1, 2, 3, 4], dtype)
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- def _testOutOfRangeError(self, dtype):
+ string_ops.substr(test_string, position, length, unit=unit)
+
+ @parameterized.parameters(
+ (np.int32, 6, "BYTE"),
+ (np.int64, 6, "BYTE"),
+ (np.int32, -6, "BYTE"),
+ (np.int64, -6, "BYTE"),
+ (np.int32, 6, "UTF8_CHAR"),
+ (np.int64, 6, "UTF8_CHAR"),
+ (np.int32, -6, "UTF8_CHAR"),
+ (np.int64, -6, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Scalar(self, dtype, pos, unit):
# Scalar/Scalar
- test_string = b"Hello"
- position = np.array(7, dtype)
- length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Scalar/Scalar (with negative)
- test_string = b"Hello"
- position = np.array(-7, dtype)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, 4, "BYTE"),
+ (np.int64, 4, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 4, "UTF8_CHAR"),
+ (np.int64, 4, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_VectorScalar(self, dtype, pos, unit):
# Vector/Scalar
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(4, dtype)
- length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Vector/Scalar (with negative)
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(-4, dtype)
+ test_string = {
+ "BYTE": [b"good", b"good", b"bad", b"good"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d",
+ u"g\xc3\xc3d"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
# Matrix/Matrix
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+ [b"good", b"good", b"good"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]]],
+ }[unit]
position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Matrix/Matrix (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Broadcast(self, dtype, unit):
# Broadcast
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]]],
+ }[unit]
position = np.array([1, 2, 4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Broadcast (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
position = np.array([-1, -2, -4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- def _testMismatchPosLenShapes(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMismatchPosLenShapes(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, 2, 3]], dtype)
length = np.array([2, 3, 4], dtype)
# Should fail: position/length have different rank
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
+ string_ops.substr(test_string, position, length)
position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
length = np.array([[2, 3, 4]], dtype)
# Should fail: position/length have different dimensionality
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- # Negative position.
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([[-1, -2, -3]], dtype)
- length = np.array([1, 2, 3], dtype)
- # Should fail: position/length have different rank
- with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- @parameterized.parameters(np.int32, np.int64)
- def testAll(self, dtype):
- self._testScalarString(dtype)
- self._testVectorStrings(dtype)
- self._testMatrixStrings(dtype)
- self._testElementWisePosLen(dtype)
- self._testBroadcast(dtype)
- self._testBadBroadcast(dtype)
- self._testOutOfRangeError(dtype)
- self._testMismatchPosLenShapes(dtype)
+ string_ops.substr(test_string, position, length)
def testWrongDtype(self):
with self.cached_session():
@@ -300,6 +478,11 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(TypeError):
string_ops.substr(b"test", 3, 1.0)
+ def testInvalidUnit(self):
+ with self.cached_session():
+ with self.assertRaises(ValueError):
+ string_ops.substr(b"test", 3, 1, unit="UTF8")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 29fb002ef4..04ac589432 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -160,7 +161,7 @@ class WhereBenchmark(test.Benchmark):
x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p
v = resource_variable_ops.ResourceVariable(x)
op = array_ops.where(v)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
v.initializer.run()
r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
gb_processed_input = m * n / 1.0e9
@@ -186,7 +187,7 @@ class WhereBenchmark(test.Benchmark):
y = resource_variable_ops.ResourceVariable(y_gen)
c = resource_variable_ops.ResourceVariable(c_gen)
op = array_ops.where(c, x, y)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
x.initializer.run()
y.initializer.run()
c.initializer.run()
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9f5149d5ac..e3e4d5f910 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1407,8 +1407,13 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
gen_array_ops.conjugate_transpose
if (conjugate and a.dtype.is_complex) else gen_array_ops.transpose)
if perm is None:
- rank = gen_array_ops.rank(a)
- perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
+ a = ops.convert_to_tensor(a, name="a")
+ if not a.get_shape().ndims:
+ rank = gen_array_ops.rank(a)
+ perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
+ else:
+ rank = a.get_shape().ndims
+ perm = (rank - 1) - np.arange(rank)
ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
@@ -2716,16 +2721,22 @@ def batch_gather(params, indices, name=None):
params = ops.convert_to_tensor(params, name="params")
indices_shape = shape(indices)
params_shape = shape(params)
+
ndims = indices.shape.ndims
if ndims is None:
raise ValueError("batch_gather does not allow indices with unknown "
"shape.")
batch_indices = indices
- accum_dim_value = 1
+ indices_dtype = indices.dtype.base_dtype
+ accum_dim_value = ones((), dtype=indices_dtype)
+ # Use correct type for offset index computation
+ casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
for dim in range(ndims-1, 0, -1):
- dim_value = params_shape[dim-1]
- accum_dim_value *= params_shape[dim]
- dim_indices = gen_math_ops._range(0, dim_value, 1)
+ dim_value = casted_params_shape[dim-1]
+ accum_dim_value *= casted_params_shape[dim]
+ start = zeros((), dtype=indices_dtype)
+ step = ones((), dtype=indices_dtype)
+ dim_indices = gen_math_ops._range(start, dim_value, step)
dim_indices *= accum_dim_value
dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim),
axis=0)
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index 195ad11c71..c9aa4d4889 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -282,9 +282,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
as is.
2. Tensors in the forward pass graph. These tensors may not be "live"
when the gradient is being computed. We replace such references by their
- corresponding tensor in the least common ancestor graph of `grad_graph` and
- `cond_graph`. Since we export intermediate tensors for all branch
- functions, this is always possible.
+ corresponding tensor in `cond_graph.outer_graph`. In the case of nested
+ control flow or functions, the gradient logic handling
+ `grad_graph.outer_graph` will make sure the tensor from
+ `cond_graph.outer_graph` is also correctly captured.
Args:
cond_graph: function.FuncGraph. The forward-pass function.
@@ -296,24 +297,23 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
new_inputs = []
for t in grad_graph.external_captures:
+ # `t` must either be in `grad_graph.outer_graph` or in the forward
+ # `cond_graph`.
if t.graph != grad_graph.outer_graph:
- # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
- # tensor to the least common ancestor of the `cond_graph` and
- # `grad_graph` so that it is "in-scope" for `grad_graph`.
- # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least
- # common ancestor once and re-use.
- assert _is_ancestor(cond_graph, t.graph)
- while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function.FuncGraph)
- if t in t.graph.internal_captures:
- # TODO(srbs): Consider building a map of internal_captures ->
- # external_captures instead of searching for `t` twice.
- t = t.graph.external_captures[t.graph.internal_captures.index(t)]
- else:
- # Note: All intermediate tensors are output by the If op.
- # TODO(srbs): .index() calls may be expensive. Optimize.
- t = t.graph._if.outputs[t.graph.outputs.index(t)]
- assert _is_ancestor(grad_graph, t.graph)
+ assert t.graph == cond_graph
+ # `internal_captures` are not treated as intermediates and hence not added
+ # to If op outputs. So we get the outer tensor corresponding to those
+ # from the list of `external_captures`.
+ try:
+ t = t.graph._if.outputs[t.graph.outputs.index(t)]
+ except ValueError:
+ index = t.graph.internal_captures.index(t)
+ t = t.graph.external_captures[index]
+
+ # Note: We rely on the capturing logic of the gradient If op graph to
+ # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
+ # and while_v2 handle this while building their gradient functions.
+ assert t.graph == cond_graph.outer_graph
new_inputs.append(t)
return new_inputs
@@ -492,11 +492,3 @@ def _get_output_shapes(true_graph_outputs, false_graph_outputs):
for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
]
return output_shapes
-
-
-def _is_ancestor(graph, maybe_ancestor):
- if maybe_ancestor == graph:
- return True
- if isinstance(graph, _function.FuncGraph):
- return _is_ancestor(graph.outer_graph, maybe_ancestor)
- return False
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index f779c3d273..5bc217d355 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1333,6 +1333,9 @@ class ControlFlowState(object):
"""
if util.IsLoopSwitch(op):
return None
+ if op.graph._building_function: # pylint: disable=protected-access
+ # The optimization here is tricky to apply to functions
+ return array_ops.zeros_like(op.outputs[index])
dead_branch = util.IsSwitch(op)
forward_ctxt = _GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)
diff --git a/tensorflow/python/ops/control_flow_ops_benchmark.py b/tensorflow/python/ops/control_flow_ops_benchmark.py
new file mode 100644
index 0000000000..9ba5ff2c0f
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops_benchmark.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for control flow ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class CondWithManyIntermediatesBenchmark(test.Benchmark):
+ """Checks the runtime performance of outputting all intermediates."""
+
+ NUM_INTERMEDIATES = 1000
+ NUM_ITERS = 500
+ NUM_WARM_UP_ITERS = 50
+
+ def _create_cond(self, x):
+
+ def branch_fn():
+ # Use a random value so the adds can't be constant folded.
+ return x + sum(random_ops.random_normal([])
+ for _ in range(self.NUM_INTERMEDIATES))
+
+ # Use a dynamic predicate to make sure the cond isn't constant folded.
+ return control_flow_ops.cond(math_ops.not_equal(x, -1),
+ branch_fn, lambda: 0.0)
+
+ def _benchmark_defun(self):
+ """Benchmarks cond in a defun."""
+
+ @function.defun
+ def cond_fn(x):
+ return self._create_cond(x)
+
+ # Warm up
+ for _ in range(self.NUM_WARM_UP_ITERS):
+ cond_fn(0.0)
+
+ start_time = time.time()
+
+ for _ in range(self.NUM_ITERS):
+ cond_fn(0.0)
+
+ self.report_benchmark(
+ wall_time=time.time() - start_time,
+ iters=self.NUM_ITERS)
+
+ def _benchmark_graph(self):
+ """Benchmarks cond in legacy graph mode."""
+ with context.graph_mode():
+ with ops.Graph().as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ cond_val = self._create_cond(x)
+
+ with session.Session() as sess:
+ cond_fn = sess.make_callable(cond_val, [x])
+
+ # Warm up
+ for _ in range(self.NUM_WARM_UP_ITERS):
+ cond_fn(0.0)
+
+ start_time = time.time()
+
+ for _ in range(self.NUM_ITERS):
+ cond_fn(0.0)
+
+ self.report_benchmark(
+ wall_time=time.time() - start_time,
+ iters=self.NUM_ITERS)
+
+ def benchmark_cond_v1_defun(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = False
+ self._benchmark_defun()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v2_defun(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ self._benchmark_defun()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v1_graph(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = False
+ self._benchmark_graph()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v2_graph(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ self._benchmark_graph()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+if __name__ == "__main__":
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/ops/custom_gradient.py b/tensorflow/python/ops/custom_gradient.py
index d7834ba350..bfe23834b7 100644
--- a/tensorflow/python/ops/custom_gradient.py
+++ b/tensorflow/python/ops/custom_gradient.py
@@ -18,9 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import tape as tape_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
@@ -33,6 +35,45 @@ from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
+def copy_handle_data(source_t, target_t):
+ """Copies HandleData for variant and resource type tensors if available.
+
+ The CppShapeInferenceResult::HandleData proto contains information about the
+ shapes and types of the element tensors of resource/variant type tensors.
+ We need to copy this across function boundaries, i.e., when capturing a
+ placeholder or when returning a function tensor as output. If we don't do this
+ the element tensors will have unknown shapes, e.g., if a TensorList variant
+ tensor is captured as a placeholder, elements popped from that list would have
+ unknown shape.
+
+ Args:
+ source_t: The tensor to copy HandleData from.
+ target_t: The tensor to copy HandleData to.
+ """
+ if (target_t.dtype == dtypes.resource or
+ target_t.dtype == dtypes.variant):
+ if isinstance(source_t, ops.EagerTensor):
+ handle_data = source_t._handle_data # pylint: disable=protected-access
+ else:
+ handle_data = resource_variable_ops.get_resource_handle_data(source_t)
+ if handle_data is not None and handle_data.is_set:
+ # pylint: disable=protected-access
+ pywrap_tensorflow.SetHandleShapeAndType(target_t.graph._c_graph,
+ target_t._as_tf_output(),
+ handle_data.SerializeToString())
+ # pylint: enable=protected-access
+ # Ensure that shapes and dtypes are propagated.
+ shapes, types = zip(*[(pair.shape, pair.dtype)
+ for pair in handle_data.shape_and_type])
+ ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
+ shapes = [[d.size for d in s.dim]
+ if not s.unknown_rank else None for s in shapes]
+ pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
+ target_t._op._graph._c_graph, # pylint: disable=protected-access
+ target_t._as_tf_output(), # pylint: disable=protected-access
+ shapes, ranks, types)
+
+
@tf_export("custom_gradient")
def custom_gradient(f):
"""Decorator to define a function with a custom gradient.
@@ -180,8 +221,11 @@ def _graph_mode_decorator(f, *args, **kwargs):
input_grads = nest.flatten(input_grads)
return ([None] * len(flat_result)) + input_grads + variable_grads
+ original_tensors = all_tensors
with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
all_tensors = array_ops.identity_n(all_tensors)
+ for ot, t in zip(original_tensors, all_tensors):
+ copy_handle_data(ot, t)
return nest.pack_sequence_as(
structure=result, flat_sequence=all_tensors[:len(flat_result)])
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 1dc666e78b..794465b10e 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -25,4 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
+from tensorflow.python.ops.gradients_impl import UnconnectedGradients
# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 056015d6b6..6909fcaed5 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import contextlib
+import enum # pylint: disable=g-bad-import-order
import sys
import warnings
@@ -537,6 +538,26 @@ def _Consumers(t, func_graphs):
return consumers
+@tf_export("UnconnectedGradients")
+class UnconnectedGradients(enum.Enum):
+ """Controls how gradient computation behaves when y does not depend on x.
+
+ The gradient of y with respect to x can be zero in two different ways: there
+ could be no differentiable path in the graph connecting x to y (and so we can
+ statically prove that the gradient is zero) or it could be that runtime values
+ of tensors in a particular execution lead to a gradient of zero (say, if a
+ relu unit happens to not be activated). To allow you to distinguish between
+ these two cases you can choose what value gets returned for the gradient when
+ there is no path in the graph from x to y:
+
+ * `NONE`: Indicates that [None] will be returned if there is no path from x
+ to y
+ * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
+ """
+ NONE = "none"
+ ZERO = "zero"
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -545,7 +566,8 @@ def gradients(ys,
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
- stop_gradients=None):
+ stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE):
"""Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
`ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
@@ -596,6 +618,23 @@ def gradients(ys,
All integer tensors are considered constant with respect to all `xs`, as if
they were included in `stop_gradients`.
+ `unconnected_gradients` determines the value returned for each x in xs if it
+ is unconnected in the graph to ys. By default this is None to safeguard
+ against errors. MAthematically these gradients are zero which can be requested
+ using the `'zero'` option. `tf.UnconnectedGradients` provides the
+ following options and behaviors:
+
+ ```python
+ a = tf.ones([1, 2])
+ b = tf.ones([3, 1])
+ g1 = tf.gradients([b], [a], unnconnected_gradients='none')
+ sess.run(g1) # [None]
+
+ g2 = tf.gradients([b], [a], unconnected_gradients='zero')
+ sess.run(g2) # [array([[0., 0.]], dtype=float32)]
+ ```
+
+
Args:
ys: A `Tensor` or list of tensors to be differentiated.
xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -611,6 +650,10 @@ def gradients(ys,
Accepted values are constants defined in the class `AggregationMethod`.
stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
through.
+ unconnected_gradients: Optional. Specifies the gradient value returned when
+ the given input tensors are unconnected. Accepted values are constants
+ defined in the class `tf.UnconnectedGradients` and the default value is
+ `none`.
Returns:
A list of `sum(dy/dx)` for each x in `xs`.
@@ -627,7 +670,8 @@ def gradients(ys,
# mutating new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
- gate_gradients, aggregation_method, stop_gradients)
+ gate_gradients, aggregation_method, stop_gradients,
+ unconnected_gradients)
def _GradientsHelper(ys,
@@ -638,6 +682,7 @@ def _GradientsHelper(ys,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE,
src_graph=None):
"""Implementation of gradients()."""
if context.executing_eagerly():
@@ -645,6 +690,11 @@ def _GradientsHelper(ys,
"is enabled. Use tf.GradientTape instead.")
if src_graph is None:
src_graph = ops.get_default_graph()
+ try:
+ unconnected_gradients = UnconnectedGradients(unconnected_gradients)
+ except ValueError:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
# If src_graph is a _FuncGraph (i.e. a function body), gather it and all
# ancestor graphs. This is necessary for correctly handling captured values.
@@ -750,23 +800,21 @@ def _GradientsHelper(ys,
# pylint: enable=protected-access
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op not in stop_ops):
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- # A grad_fn must be defined, either as a function or as None
- # for ops that do not have gradients.
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
raise LookupError(
"No gradient defined for operation '%s' (op type: %s)" %
(op.name, op.type))
@@ -856,7 +904,7 @@ def _GradientsHelper(ys,
if loop_state:
loop_state.PostProcessing()
- return [_GetGrad(grads, x) for x in xs]
+ return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
def _HasAnyNotNoneGrads(grads, op):
@@ -924,12 +972,19 @@ def _SetGrad(grads, t, grad):
op_grads[t.value_index] = grad
-def _GetGrad(grads, t):
+def _GetGrad(grads, t, unconnected_gradients):
"""Gets gradient for tensor "t"."""
op = t.op
op_grads = grads.get(op)
if not op_grads:
- return None
+ if unconnected_gradients == UnconnectedGradients.ZERO:
+ return array_ops.zeros_like(t)
+ elif unconnected_gradients == UnconnectedGradients.NONE:
+ return None
+ else:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
+
t_grad = op_grads[t.value_index]
assert not isinstance(
t_grad, list), ("gradients list should have been aggregated by now.")
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 3c9b7a01c7..c93e2493ee 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -350,6 +350,40 @@ class GradientsTest(test_util.TensorFlowTestCase):
for a, b in zip(npgrad1, npgrad2):
np.testing.assert_allclose(a, b)
+ def testUnconnectedGradientsNoneUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="none")
+ self.assertIsNone(grad[0])
+
+ def testUnconnectedGradientsZerosUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grads = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], sess.run(grads)[0])
+
+ def testUnconnectedGradientsZeroConnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = x * 3.0
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertEquals(3.0, sess.run(grad)[0])
+
+ def testUnknownUnconnectedGradientsValueGiven(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = constant(1.0)
+ with self.assertRaisesRegexp(
+ ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
+ gradients.gradients([y], [x], unconnected_gradients="nonsense")
+
class FunctionGradientsTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 35fdee4fad..ff86df6346 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -602,20 +602,19 @@ class AdjustHueBenchmark(test.Benchmark):
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
- with session.Session("", graph=ops.Graph(), config=config) as sess:
- with ops.device(device):
- inputs = variables.Variable(
- random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
- trainable=False,
- dtype=dtypes.float32)
- delta = constant_op.constant(0.1, dtype=dtypes.float32)
- outputs = image_ops.adjust_hue(inputs, delta)
- run_op = control_flow_ops.group(outputs)
- sess.run(variables.global_variables_initializer())
- for i in xrange(warmup_rounds + benchmark_rounds):
- if i == warmup_rounds:
- start = time.time()
- sess.run(run_op)
+ with self.benchmark_session(config=config, device=device) as sess:
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = image_ops.adjust_hue(inputs, delta)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for i in xrange(warmup_rounds + benchmark_rounds):
+ if i == warmup_rounds:
+ start = time.time()
+ sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
@@ -646,21 +645,20 @@ class AdjustSaturationBenchmark(test.Benchmark):
if cpu_count is not None:
config.inter_op_parallelism_threads = 1
config.intra_op_parallelism_threads = cpu_count
- with session.Session("", graph=ops.Graph(), config=config) as sess:
- with ops.device(device):
- inputs = variables.Variable(
- random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
- trainable=False,
- dtype=dtypes.float32)
- delta = constant_op.constant(0.1, dtype=dtypes.float32)
- outputs = image_ops.adjust_saturation(inputs, delta)
- run_op = control_flow_ops.group(outputs)
- sess.run(variables.global_variables_initializer())
- for _ in xrange(warmup_rounds):
- sess.run(run_op)
- start = time.time()
- for _ in xrange(benchmark_rounds):
- sess.run(run_op)
+ with self.benchmark_session(config=config, device=device) as sess:
+ inputs = variables.Variable(
+ random_ops.random_uniform(image_shape, dtype=dtypes.float32) * 255,
+ trainable=False,
+ dtype=dtypes.float32)
+ delta = constant_op.constant(0.1, dtype=dtypes.float32)
+ outputs = image_ops.adjust_saturation(inputs, delta)
+ run_op = control_flow_ops.group(outputs)
+ sess.run(variables.global_variables_initializer())
+ for _ in xrange(warmup_rounds):
+ sess.run(run_op)
+ start = time.time()
+ for _ in xrange(benchmark_rounds):
+ sess.run(run_op)
end = time.time()
step_time = (end - start) / benchmark_rounds
tag = device + "_%s" % (cpu_count if cpu_count is not None else "_all")
@@ -699,7 +697,7 @@ class ResizeBilinearBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
@@ -747,7 +745,7 @@ class ResizeBicubicBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
@@ -804,7 +802,7 @@ class ResizeAreaBenchmark(test.Benchmark):
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
- with session.Session() as sess:
+ with self.benchmark_session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess,
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index e1a01ab4c3..902653befc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -389,6 +389,21 @@ def _Relu6GradGrad(op, grad):
array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+@ops.RegisterGradient("LeakyRelu")
+def _LeakyReluGrad(op, grad):
+ x = op.inputs[0]
+ alpha = op.get_attr("alpha")
+ return gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha)
+
+
+@ops.RegisterGradient("LeakyReluGrad")
+def _LeakyReluGradGrad(op, grad):
+ x = op.inputs[1]
+ alpha = op.get_attr("alpha")
+ return (gen_nn_ops.leaky_relu_grad(grad, x, alpha=alpha),
+ array_ops.zeros(shape=array_ops.shape(x), dtype=x.dtype))
+
+
@ops.RegisterGradient("Elu")
def _EluGrad(op, grad):
return gen_nn_ops.elu_grad(grad, op.outputs[0])
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 1fbe31a098..04962da7f7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -22,6 +22,7 @@ import numbers
import numpy as np
+from tensorflow.python.compat import compat
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
@@ -1602,6 +1603,8 @@ def leaky_relu(features, alpha=0.2, name=None):
features = ops.convert_to_tensor(features, name="features")
if features.dtype.is_integer:
features = math_ops.to_float(features)
+ if compat.forward_compatible(2018, 11, 1):
+ return gen_nn_ops.leaky_relu(features, alpha=alpha, name=name)
alpha = ops.convert_to_tensor(alpha, dtype=features.dtype, name="alpha")
return math_ops.maximum(alpha * features, features, name=name)
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index ff50fe0d09..a2da6412ed 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -217,21 +217,21 @@ def _features_to_raw_params(features, types):
feature = features[key]
if isinstance(feature, VarLenFeature):
if VarLenFeature not in types:
- raise ValueError("Unsupported VarLenFeature %s." % feature)
+ raise ValueError("Unsupported VarLenFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
sparse_keys.append(key)
sparse_types.append(feature.dtype)
elif isinstance(feature, SparseFeature):
if SparseFeature not in types:
- raise ValueError("Unsupported SparseFeature %s." % feature)
+ raise ValueError("Unsupported SparseFeature %s." % (feature,))
if not feature.index_key:
raise ValueError(
- "Missing index_key for SparseFeature %s." % feature)
+ "Missing index_key for SparseFeature %s." % (feature,))
if not feature.value_key:
raise ValueError(
- "Missing value_key for SparseFeature %s." % feature)
+ "Missing value_key for SparseFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
index_keys = feature.index_key
@@ -260,7 +260,7 @@ def _features_to_raw_params(features, types):
sparse_types.append(feature.dtype)
elif isinstance(feature, FixedLenFeature):
if FixedLenFeature not in types:
- raise ValueError("Unsupported FixedLenFeature %s." % feature)
+ raise ValueError("Unsupported FixedLenFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
@@ -281,7 +281,8 @@ def _features_to_raw_params(features, types):
dense_defaults[key] = feature.default_value
elif isinstance(feature, FixedLenSequenceFeature):
if FixedLenSequenceFeature not in types:
- raise ValueError("Unsupported FixedLenSequenceFeature %s." % feature)
+ raise ValueError("Unsupported FixedLenSequenceFeature %s." % (
+ feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 0812f901a2..f26388efea 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -347,6 +347,22 @@ def string_length(input, name=None, unit="BYTE"):
string_length.__doc__ = gen_string_ops.string_length.__doc__
+@tf_export("substr")
+@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
+def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
+ return substr(input, pos, len, name=name, unit=unit)
+
+substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
+
+
+@tf_export("strings.substr")
+def substr(input, pos, len, name=None, unit="BYTE"):
+ return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
+
+
+substr.__doc__ = gen_string_ops.substr.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 8e88a84d60..0419656143 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import list_ops
@@ -580,7 +581,7 @@ def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
- function._copy_handle_data(src_t, tgt_t)
+ custom_gradient.copy_handle_data(src_t, tgt_t)
# TODO(srbs): Move to common utils for cond_v2 and while_v2.
diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py
index fa17b17d10..4f7abb311a 100644
--- a/tensorflow/python/platform/benchmark.py
+++ b/tensorflow/python/platform/benchmark.py
@@ -27,6 +27,7 @@ import time
import six
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.core.util import test_log_pb2
from tensorflow.python.client import timeline
from tensorflow.python.platform import app
@@ -182,6 +183,19 @@ class Benchmark(six.with_metaclass(_BenchmarkRegistrar, object)):
throughput=throughput, extras=extras)
+@tf_export("test.benchmark_config")
+def benchmark_config():
+ """Returns a tf.ConfigProto for disabling the dependency optimizer.
+
+ Returns:
+ A TensorFlow ConfigProto object.
+ """
+ config = config_pb2.ConfigProto()
+ config.graph_options.rewrite_options.dependency_optimization = (
+ rewriter_config_pb2.RewriterConfig.OFF)
+ return config
+
+
@tf_export("test.Benchmark")
class TensorFlowBenchmark(Benchmark):
"""Abstract class that provides helpers for TensorFlow benchmarks."""
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 61e0abbfcb..adbce95c6f 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -209,6 +209,7 @@ limitations under the License.
SWIG_fail;
} else {
int num_outputs = $1->size();
+ Py_CLEAR($result);
$result = PyList_New(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
PyObject *output;
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 82f0e3be52..a479f38165 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -195,8 +195,12 @@ class Scaffold(object):
default_ready_op)
if self._ready_for_local_init_op is None:
def default_ready_for_local_init_op():
- return variables.report_uninitialized_variables(
- variables.global_variables())
+ return array_ops.concat([
+ variables.report_uninitialized_variables(
+ variables.global_variables()),
+ resources.report_uninitialized_resources(
+ resources.shared_resources())
+ ], 0)
self._ready_for_local_init_op = Scaffold.get_or_default(
'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
default_ready_for_local_init_op)
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 041266da3e..89bfcaf4ad 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
@@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The moving average of 'variable' updated with 'value' is:
variable * decay + value * (1 - decay)
- The returned Operation sets 'variable' to the newly computed moving average.
-
- The new value of 'variable' can be set with the 'AssignSub' op as:
+ The returned Operation sets 'variable' to the newly computed moving average,
+ by performing this subtraction:
variable -= (1 - decay) * (variable - value)
Since variables that are initialized to a `0` value will be `0` biased,
@@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The names of the debias shadow variables, by default, include both the scope
they were created in and the scope of the variables they debias. They are also
- given a uniqifying-suffix.
+ given a uniquifying-suffix.
E.g.:
@@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
with tf.variable_scope('scope1'):
with tf.variable_scope('scope2'):
var = tf.get_variable('foo')
- tf.assign_moving_average(var, 0.0, 1.0)
- tf.assign_moving_average(var, 0.0, 0.9)
+ update_1 = tf.assign_moving_average(var, 0.0, 1.0)
+ update_2 = tf.assign_moving_average(var, 0.0, 0.9)
# var.name: 'scope1/scope2/foo'
# shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
@@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
name: Optional name of the returned operation.
Returns:
- A reference to the input 'variable' tensor with the newly computed
- moving average.
+ A tensor which if evaluated will compute and return the new moving average.
"""
+ def update_fn(v, value, decay=decay):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != v.dtype.base_dtype:
+ decay = math_ops.cast(decay, v.dtype.base_dtype)
+ if zero_debias:
+ update_delta = _zero_debias(v, value, decay)
+ else:
+ update_delta = (v - value) * decay
+ return state_ops.assign_sub(v, update_delta, name=scope)
+
with ops.name_scope(name, "AssignMovingAvg",
[variable, value, decay]) as scope:
- with ops.colocate_with(variable):
- decay = ops.convert_to_tensor(1.0 - decay, name="decay")
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- if zero_debias:
- update_delta = _zero_debias(variable, value, decay)
- else:
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ # In a tower context, we update variable using the mean of value across
+ # towers.
+ def merge_fn(strategy, v, value):
+ value = strategy.reduce(
+ variable_scope.VariableAggregation.MEAN, value, v)
+ return strategy.update(v, update_fn, value)
+
+ return tower_context.merge_call(merge_fn, variable, value)
+ else:
+ strategy = distribution_strategy_context.get_cross_tower_context()
+ return strategy.update(variable, update_fn, value)
def weighted_moving_average(value,
@@ -379,8 +392,6 @@ class ExponentialMovingAverage(object):
Raises:
TypeError: If the arguments are not an allowed type.
- ValueError: If the moving average of one of the variables is already
- being computed.
"""
# TODO(touts): op_scope
if var_list is None:
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index a0e6bf65cf..3a3af4bffa 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -63,6 +63,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import difflib
import six
@@ -101,10 +102,19 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva
if normalize_numbers:
NormalizeNumberFields(pb)
- self.assertMultiLineEqual(
- text_format.MessageToString(a, descriptor_pool=pool),
- text_format.MessageToString(b, descriptor_pool=pool),
- msg=msg)
+ a_str = text_format.MessageToString(a, descriptor_pool=pool)
+ b_str = text_format.MessageToString(b, descriptor_pool=pool)
+
+ # Some Python versions would perform regular diff instead of multi-line
+ # diff if string is longer than 2**16. We substitute this behavior
+ # with a call to unified_diff instead to have easier-to-read diffs.
+ # For context, see: https://bugs.python.org/issue11763.
+ if len(a_str) < 2**16 and len(b_str) < 2**16:
+ self.assertMultiLineEqual(a_str, b_str, msg=msg)
+ else:
+ diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True),
+ b_str.splitlines(True)))
+ self.fail('%s : %s' % (msg, diff))
def NormalizeNumberFields(pb):
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 7b3e618e84..11eb9ce947 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -825,18 +825,16 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
- PyObject* f1 = PyObject_GetAttrString(o1, "_fields");
- PyObject* f2 = PyObject_GetAttrString(o2, "_fields");
+ Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
+ Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
if (f1 == nullptr || f2 == nullptr) {
- Py_XDECREF(f1);
- Py_XDECREF(f2);
PyErr_SetString(
PyExc_RuntimeError,
"Expected namedtuple-like objects (that have _fields attr)");
return nullptr;
}
- if (PyObject_RichCompareBool(f1, f2, Py_NE)) {
+ if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
Py_RETURN_FALSE;
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt
new file mode 100644
index 0000000000..c5eb959430
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-unconnected-gradients.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.UnconnectedGradients"
+tf_class {
+ is_instance: "<enum \'UnconnectedGradients\'>"
+ member {
+ name: "NONE"
+ mtype: "<enum \'UnconnectedGradients\'>"
+ }
+ member {
+ name: "ZERO"
+ mtype: "<enum \'UnconnectedGradients\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
index 2e9de9ebb2..eb315e356d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.activations.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
}
member_method {
+ name: "exponential"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "get"
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
index a71a59e269..9feb7c09b8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
@@ -46,7 +46,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
- argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+ argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'axis\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.001\'], "
}
member_method {
name: "batch_set_value"
@@ -98,7 +98,7 @@ tf_module {
}
member_method {
name: "conv2d_transpose"
- argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\'], "
+ argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], "
}
member_method {
name: "conv3d"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index c3dd2ad046..014f5828fa 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index c440604aae..a6e4856de9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 065bb4d35b..381839d6de 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index c7ba6056f9..2933f9f4b3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 8f4f7918ab..9c9c7461c8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 93c442bd55..44ca598724 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 5ea61d118d..a8094c0bde 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 11dca17c6d..3ebe162f57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 278429af6f..c0a53b847b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 935a69ab2f..ff6c6f3ec4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 238d96cca6..d26da270e7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 4a45bf7997..524c5fd69e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
index 81b91d2780..138d97b11f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.utils.pbtxt
@@ -70,6 +70,6 @@ tf_module {
}
member_method {
name: "to_categorical"
- argspec: "args=[\'y\', \'num_classes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index a268529c1f..247dfcc1ca 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -249,6 +249,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "UnconnectedGradients"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "VERSION"
mtype: "<type \'str\'>"
}
@@ -1234,7 +1238,7 @@ tf_module {
}
member_method {
name: "gradients"
- argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
+ argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
name: "greater"
@@ -2090,7 +2094,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "subtract"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index ebdaf57231..5ba48e7f57 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "to_hash_bucket"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt
index abe9b068ae..984c584c9e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.test.pbtxt
@@ -21,6 +21,10 @@ tf_module {
argspec: "args=[\'actual\', \'expected\', \'checkpoint_v2\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
+ name: "benchmark_config"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "compute_gradient"
argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt
new file mode 100644
index 0000000000..c5eb959430
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-unconnected-gradients.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.UnconnectedGradients"
+tf_class {
+ is_instance: "<enum \'UnconnectedGradients\'>"
+ member {
+ name: "NONE"
+ mtype: "<enum \'UnconnectedGradients\'>"
+ }
+ member {
+ name: "ZERO"
+ mtype: "<enum \'UnconnectedGradients\'>"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
index 2e9de9ebb2..eb315e356d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.activations.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
}
member_method {
+ name: "exponential"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "get"
argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
index a71a59e269..9feb7c09b8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
@@ -46,7 +46,7 @@ tf_module {
}
member_method {
name: "batch_normalization"
- argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
+ argspec: "args=[\'x\', \'mean\', \'var\', \'beta\', \'gamma\', \'axis\', \'epsilon\'], varargs=None, keywords=None, defaults=[\'-1\', \'0.001\'], "
}
member_method {
name: "batch_set_value"
@@ -98,7 +98,7 @@ tf_module {
}
member_method {
name: "conv2d_transpose"
- argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\'], "
+ argspec: "args=[\'x\', \'kernel\', \'output_shape\', \'strides\', \'padding\', \'data_format\', \'dilation_rate\'], varargs=None, keywords=None, defaults=[\'(1, 1)\', \'valid\', \'None\', \'(1, 1)\'], "
}
member_method {
name: "conv3d"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index c3dd2ad046..014f5828fa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index c440604aae..a6e4856de9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 065bb4d35b..381839d6de 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index c7ba6056f9..2933f9f4b3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index 8f4f7918ab..9c9c7461c8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'dilation_rate\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1)\', \'valid\', \'None\', \'None\', \'(1, 1)\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 93c442bd55..44ca598724 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'self\', \'filters\', \'kernel_size\', \'strides\', \'padding\', \'output_padding\', \'data_format\', \'activation\', \'use_bias\', \'kernel_initializer\', \'bias_initializer\', \'kernel_regularizer\', \'bias_regularizer\', \'activity_regularizer\', \'kernel_constraint\', \'bias_constraint\'], varargs=None, keywords=kwargs, defaults=[\'(1, 1, 1)\', \'valid\', \'None\', \'None\', \'None\', \'True\', \'glorot_uniform\', \'zeros\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 5ea61d118d..a8094c0bde 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 11dca17c6d..3ebe162f57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
@@ -111,7 +111,7 @@ tf_class {
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 278429af6f..c0a53b847b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 935a69ab2f..ff6c6f3ec4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 238d96cca6..d26da270e7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 4a45bf7997..524c5fd69e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'None\'], "
+ argspec: "args=[\'self\', \'pool_size\', \'strides\', \'padding\', \'data_format\'], varargs=None, keywords=kwargs, defaults=[\'2\', \'None\', \'valid\', \'channels_last\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
index 81b91d2780..138d97b11f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.utils.pbtxt
@@ -70,6 +70,6 @@ tf_module {
}
member_method {
name: "to_categorical"
- argspec: "args=[\'y\', \'num_classes\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'y\', \'num_classes\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'float32\'], "
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 5b3ea75bce..978afcf985 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -221,6 +221,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "UnconnectedGradients"
+ mtype: "<class \'enum.EnumMeta\'>"
+ }
+ member {
name: "VERSION"
mtype: "<type \'str\'>"
}
@@ -1134,7 +1138,7 @@ tf_module {
}
member_method {
name: "gradients"
- argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\'], "
+ argspec: "args=[\'ys\', \'xs\', \'grad_ys\', \'name\', \'colocate_gradients_with_ops\', \'gate_gradients\', \'aggregation_method\', \'stop_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'gradients\', \'False\', \'False\', \'None\', \'None\', \'UnconnectedGradients.NONE\'], "
}
member_method {
name: "greater"
@@ -1930,7 +1934,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "subtract"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index ebdaf57231..5ba48e7f57 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -34,7 +34,7 @@ tf_module {
}
member_method {
name: "substr"
- argspec: "args=[\'input\', \'pos\', \'len\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'input\', \'pos\', \'len\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
name: "to_hash_bucket"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
index abe9b068ae..984c584c9e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.test.pbtxt
@@ -21,6 +21,10 @@ tf_module {
argspec: "args=[\'actual\', \'expected\', \'checkpoint_v2\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
+ name: "benchmark_config"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "compute_gradient"
argspec: "args=[\'x\', \'x_shape\', \'y\', \'y_shape\', \'x_init_value\', \'delta\', \'init_targets\', \'extra_feed_dict\'], varargs=None, keywords=None, defaults=[\'None\', \'0.001\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/ci_build/Dockerfile.android b/tensorflow/tools/ci_build/Dockerfile.android
index dcf077791a..7e72eb0cbf 100644
--- a/tensorflow/tools/ci_build/Dockerfile.android
+++ b/tensorflow/tools/ci_build/Dockerfile.android
@@ -45,9 +45,14 @@ ENV ANDROID_NDK_FILENAME android-ndk-r14b-linux-x86_64.zip
ENV ANDROID_NDK_URL https://dl.google.com/android/repository/${ANDROID_NDK_FILENAME}
ENV ANDROID_NDK_HOME ${ANDROID_DEV_HOME}/ndk
ENV PATH ${PATH}:${ANDROID_NDK_HOME}
+# Workaround for b/117156972: inject missing #include into NDK versions of
+# futex.h.
RUN cd ${ANDROID_DEV_HOME} && \
wget -q ${ANDROID_NDK_URL} && \
unzip ${ANDROID_NDK_FILENAME} -d ${ANDROID_DEV_HOME} && \
+ sed -i 15i"#include <linux/compiler.h>" ${ANDROID_DEV_HOME}/android-ndk-r14b/platforms/android-14/arch-arm/usr/include/linux/futex.h && \
+ sed -i 15i"#include <linux/compiler.h>" ${ANDROID_DEV_HOME}/android-ndk-r14b/platforms/android-14/arch-mips/usr/include/linux/futex.h && \
+ sed -i 15i"#include <linux/compiler.h>" ${ANDROID_DEV_HOME}/android-ndk-r14b/platforms/android-14/arch-x86/usr/include/linux/futex.h && \
rm ${ANDROID_NDK_FILENAME} && \
bash -c "ln -s ${ANDROID_DEV_HOME}/android-ndk-* ${ANDROID_NDK_HOME}"
diff --git a/tensorflow/tools/ci_build/ci_parameterized_build.sh b/tensorflow/tools/ci_build/ci_parameterized_build.sh
index 99bdedf7b4..489722c0e9 100755
--- a/tensorflow/tools/ci_build/ci_parameterized_build.sh
+++ b/tensorflow/tools/ci_build/ci_parameterized_build.sh
@@ -83,6 +83,9 @@
# Use the specified configurations when building.
# When set, overrides TF_BUILD_IS_OPT and TF_BUILD_MAVX
# options, as this will replace the two.
+# TF_BUILD_TEST_TIMEOUT:
+# Sets the value of bazel --test_timeout, defaults to -1
+# which uses the bazel defaults.
# TF_SKIP_CONTRIB_TESTS:
# If set to any non-empty or non-0 value, will skip running
# contrib tests.
@@ -125,6 +128,8 @@ NO_DOCKER_OPT_FLAG="--genrule_strategy=standalone"
DO_DOCKER=1
+# Bazel uses defaults for all test sizes when given `-1`.
+TF_BUILD_TEST_TIMEOUT=${TF_BUILD_TEST_TIMEOUT:--1}
# Helpful flags:
# --test_summary=detailed: Tell us more about which targets are being built
@@ -132,7 +137,16 @@ DO_DOCKER=1
# --build_tests_only: Don't build targets depended on by tests if the test is
# disabled. Also saves some compilation time. Otherwise,
# tries to build everything.
-BAZEL_TEST_FLAGS="--test_summary=detailed --build_tests_only --keep_going"
+# --test_timeout: Test timeouts in the order short,moderate,long,eternal.
+# --test_env: Environment variables to set when running bazel tests. These are
+# especially important when using --run_under with
+# parallel_gpu_execute.
+BAZEL_TEST_FLAGS=""\
+"--test_summary=detailed --build_tests_only --keep_going "\
+"--test_timeout=${TF_BUILD_TEST_TIMEOUT} "\
+"--test_env=TF_GPU_COUNT=${TF_GPU_COUNT} "\
+"--test_env=TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU} "\
+"--test_env=TF_PER_DEVICE_MEMORY_LIMIT_MB=${TF_PER_DEVICE_MEMORY_LIMIT_MB}"
BAZEL_BUILD_FLAGS="--keep_going"
BAZEL_CMD="bazel test ${BAZEL_TEST_FLAGS}"
@@ -148,13 +162,6 @@ ANDROID_FULL_CMD="${CI_BUILD_DIR}/builds/android_full.sh"
TF_GPU_COUNT=${TF_GPU_COUNT:-4}
PARALLEL_GPU_TEST_CMD='//tensorflow/tools/ci_build/gpu_build:parallel_gpu_execute'
-# Environment variables to set when running bazel tests. These are especially
-# important when using --run_under with parallel_gpu_execute.
-BAZEL_TEST_ENV=""\
-"--test_env=TF_GPU_COUNT=${TF_GPU_COUNT} "\
-"--test_env=TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU} "\
-"--test_env=TF_PER_DEVICE_MEMORY_LIMIT_MB=${TF_PER_DEVICE_MEMORY_LIMIT_MB} "
-
BENCHMARK_CMD="${CI_BUILD_DIR}/builds/benchmark.sh"
EXTRA_PARAMS=""
@@ -415,11 +422,11 @@ if [[ ${TF_BUILD_IS_PIP} == "no_pip" ]] ||
if [[ ${CTYPE} == cpu* ]] || \
[[ ${CTYPE} == "debian.jessie.cpu" ]]; then
# CPU only command, fully parallel.
- NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${BAZEL_TEST_ENV} ${OPT_FLAG} "\
- "${EXTRA_ARGS} -- ${BAZEL_TARGET}"
+ NO_PIP_MAIN_CMD="${MAIN_CMD} ${BAZEL_CMD} ${OPT_FLAG} "\
+"${EXTRA_ARGS} -- ${BAZEL_TARGET}"
elif [[ ${CTYPE} == gpu* ]]; then
# GPU only command, run as many jobs as the GPU count only.
- NO_PIP_MAIN_CMD="${BAZEL_CMD} ${BAZEL_TEST_ENV} ${OPT_FLAG} "\
+ NO_PIP_MAIN_CMD="${BAZEL_CMD} ${OPT_FLAG} "\
"--local_test_jobs=${TF_GPU_COUNT} "\
"--run_under=${PARALLEL_GPU_TEST_CMD} "\
"${EXTRA_ARGS} -- ${BAZEL_TARGET}"
diff --git a/tensorflow/tools/ci_build/install/install_auditwheel.sh b/tensorflow/tools/ci_build/install/install_auditwheel.sh
index e6f6124d56..0e6d98c0a8 100755
--- a/tensorflow/tools/ci_build/install/install_auditwheel.sh
+++ b/tensorflow/tools/ci_build/install/install_auditwheel.sh
@@ -18,6 +18,10 @@ set -e
sudo pip3 install auditwheel==1.5.0
+# Pin wheel==0.31.1 to work around issue
+# https://github.com/pypa/auditwheel/issues/102
+sudo pip3 install wheel==0.31.1
+
set +e
patchelf_location=$(which patchelf)
if [[ -z "$patchelf_location" ]]; then
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 7f293e8604..329d05342a 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -124,6 +124,10 @@ pip3 install keras_preprocessing==1.0.5 --no-deps
pip2 install --upgrade h5py==2.8.0
pip3 install --upgrade h5py==2.8.0
+# Estimator
+pip2 install tensorflow_estimator --no-deps
+pip3 install tensorflow_estimator --no-deps
+
# Install last working version of setuptools.
pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index 54a7b7ffbe..dd1dca9ee8 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -226,15 +226,14 @@ if os.name == 'nt':
else:
EXTENSION_NAME = 'python/_pywrap_tensorflow_internal.so'
-headers = (list(find_files('*.h', 'tensorflow/core')) +
- list(find_files('*.h', 'tensorflow/stream_executor')) +
- list(find_files('*.h', 'google/protobuf_archive/src')) +
- list(find_files('*', 'third_party/eigen3')) +
- list(find_files('*.h',
- 'tensorflow/include/external/com_google_absl')) +
- list(find_files('*.inc',
- 'tensorflow/include/external/com_google_absl')) +
- list(find_files('*', 'tensorflow/include/external/eigen_archive')))
+headers = (
+ list(find_files('*.h', 'tensorflow/core')) + list(
+ find_files('*.h', 'tensorflow/stream_executor')) +
+ list(find_files('*.h', 'google/protobuf_archive/src')) + list(
+ find_files('*', 'third_party/eigen3')) + list(
+ find_files('*.h', 'tensorflow/include/external/com_google_absl')) +
+ list(find_files('*.inc', 'tensorflow/include/external/com_google_absl')) +
+ list(find_files('*', 'tensorflow/include/external/eigen_archive')))
setup(
name=project_name,
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 72f3fd0cf8..adeac62e43 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -22,10 +22,14 @@ load(
)
load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")
load("//third_party/icu:workspace.bzl", icu = "repo")
+load("//third_party/jpeg:workspace.bzl", jpeg = "repo")
+load("//third_party/nasm:workspace.bzl", nasm = "repo")
def initialize_third_party():
flatbuffers()
icu()
+ jpeg()
+ nasm()
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
@@ -110,11 +114,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
- sha256 = "507903ef9353cb25cccd0a6840048fdd348fd20e98314d694f04a990c0f277e3",
- strip_prefix = "abseil-cpp-f21d187b80e3b7f08fb279775ea9c8b48c636030",
+ sha256 = "f186bf5d9fce3037c602a21f86facbdd317adecef36e1726ec7bc7b496943a82",
+ strip_prefix = "abseil-cpp-e821380d69a549dc64900693942789d21aa4df5e",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/f21d187b80e3b7f08fb279775ea9c8b48c636030.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/f21d187b80e3b7f08fb279775ea9c8b48c636030.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e821380d69a549dc64900693942789d21aa4df5e.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/e821380d69a549dc64900693942789d21aa4df5e.tar.gz",
],
)
@@ -234,31 +238,6 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
- name = "nasm",
- build_file = clean_dep("//third_party:nasm.BUILD"),
- sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
- strip_prefix = "nasm-2.13.03",
- system_build_file = clean_dep("//third_party/systemlibs:nasm.BUILD"),
- urls = [
- "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
- "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
- "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
- ],
- )
-
- tf_http_archive(
- name = "jpeg",
- build_file = clean_dep("//third_party/jpeg:jpeg.BUILD"),
- sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
- strip_prefix = "libjpeg-turbo-2.0.0",
- system_build_file = clean_dep("//third_party/systemlibs:jpeg.BUILD"),
- urls = [
- "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
- "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
- ],
- )
-
- tf_http_archive(
name = "png_archive",
build_file = clean_dep("//third_party:png.BUILD"),
patch_file = clean_dep("//third_party:png_fix_rpi.patch"),
@@ -585,12 +564,12 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "nccl_archive",
- build_file = clean_dep("//third_party:nccl/nccl_archive.BUILD"),
- sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
- strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
+ build_file = clean_dep("//third_party:nccl/archive.BUILD"),
+ sha256 = "19132b5127fa8e02d95a09795866923f04064c8f1e0770b2b42ab551408882a4",
+ strip_prefix = "nccl-f93fe9bfd94884cec2ba711897222e0df5569a53",
urls = [
- "https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
- "https://github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
+ "https://mirror.bazel.build/github.com/nvidia/nccl/archive/f93fe9bfd94884cec2ba711897222e0df5569a53.tar.gz",
+ "https://github.com/nvidia/nccl/archive/f93fe9bfd94884cec2ba711897222e0df5569a53.tar.gz",
],
)
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 69f4599c16..831a3067b2 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -126,118 +126,141 @@ load(
)
def _get_python_bin(repository_ctx):
- """Gets the python bin path."""
- python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
- if python_bin != None:
- return python_bin
- python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
- python_bin_path = repository_ctx.which(python_bin_name)
- if python_bin_path != None:
- return str(python_bin_path)
- auto_configure_fail("Cannot find python in PATH, please make sure " +
- "python is installed and add its directory in PATH, or --define " +
- "%s='/something/else'.\nPATH=%s" % (
- _PYTHON_BIN_PATH,
- repository_ctx.os.environ.get("PATH", ""),
- ))
+ """Gets the python bin path."""
+ python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)
+ if python_bin != None:
+ return python_bin
+ python_bin_name = "python.exe" if _is_windows(repository_ctx) else "python"
+ python_bin_path = repository_ctx.which(python_bin_name)
+ if python_bin_path != None:
+ return str(python_bin_path)
+ auto_configure_fail(
+ "Cannot find python in PATH, please make sure " +
+ "python is installed and add its directory in PATH, or --define " +
+ "%s='/something/else'.\nPATH=%s" % (
+ _PYTHON_BIN_PATH,
+ repository_ctx.os.environ.get("PATH", ""),
+ ))
+
def _get_nvcc_tmp_dir_for_windows(repository_ctx):
- """Return the tmp directory for nvcc to generate intermediate source files."""
- escaped_tmp_dir = escape_string(
- get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace("\\", "\\\\"),
- )
- return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
+ """Return the tmp directory for nvcc to generate intermediate source files."""
+ escaped_tmp_dir = escape_string(
+ get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
+ "\\", "\\\\"),)
+ return escaped_tmp_dir + "\\\\nvcc_inter_files_tmp_dir"
-def _get_msvc_compiler(repository_ctx):
- vc_path = find_vc_path(repository_ctx)
- return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
-def _get_win_cuda_defines(repository_ctx):
- """Return CROSSTOOL defines for Windows"""
-
- # If we are not on Windows, return empty vaules for Windows specific fields.
- # This ensures the CROSSTOOL file parser is happy.
- if not _is_windows(repository_ctx):
- return {
- "%{msvc_env_tmp}": "",
- "%{msvc_env_path}": "",
- "%{msvc_env_include}": "",
- "%{msvc_env_lib}": "",
- "%{msvc_cl_path}": "",
- "%{msvc_ml_path}": "",
- "%{msvc_link_path}": "",
- "%{msvc_lib_path}": "",
- "%{cxx_builtin_include_directory}": "",
- }
-
- vc_path = find_vc_path(repository_ctx)
- if not vc_path:
- auto_configure_fail("Visual C++ build tools not found on your machine." +
- "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using")
- return {}
-
- env = setup_vc_env_vars(repository_ctx, vc_path)
- escaped_paths = escape_string(env["PATH"])
- escaped_include_paths = escape_string(env["INCLUDE"])
- escaped_lib_paths = escape_string(env["LIB"])
- escaped_tmp_dir = escape_string(
- get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace("\\", "\\\\"),
- )
+def _get_msvc_compiler(repository_ctx):
+ vc_path = find_vc_path(repository_ctx)
+ return find_msvc_tool(repository_ctx, vc_path, "cl.exe").replace("\\", "/")
- msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat"
- msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace("\\", "/")
- msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace("\\", "/")
- msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace("\\", "/")
- # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
- # The generated files are guranteed to have unique name, so they can share the same tmp directory
- escaped_cxx_include_directories = ["cxx_builtin_include_directory: \"%s\"" % _get_nvcc_tmp_dir_for_windows(repository_ctx)]
- for path in escaped_include_paths.split(";"):
- if path:
- escaped_cxx_include_directories.append("cxx_builtin_include_directory: \"%s\"" % path)
+def _get_win_cuda_defines(repository_ctx):
+ """Return CROSSTOOL defines for Windows"""
+ # If we are not on Windows, return empty vaules for Windows specific fields.
+ # This ensures the CROSSTOOL file parser is happy.
+ if not _is_windows(repository_ctx):
return {
- "%{msvc_env_tmp}": escaped_tmp_dir,
- "%{msvc_env_path}": escaped_paths,
- "%{msvc_env_include}": escaped_include_paths,
- "%{msvc_env_lib}": escaped_lib_paths,
- "%{msvc_cl_path}": msvc_cl_path,
- "%{msvc_ml_path}": msvc_ml_path,
- "%{msvc_link_path}": msvc_link_path,
- "%{msvc_lib_path}": msvc_lib_path,
- "%{cxx_builtin_include_directory}": "\n".join(escaped_cxx_include_directories),
+ "%{msvc_env_tmp}": "",
+ "%{msvc_env_path}": "",
+ "%{msvc_env_include}": "",
+ "%{msvc_env_lib}": "",
+ "%{msvc_cl_path}": "",
+ "%{msvc_ml_path}": "",
+ "%{msvc_link_path}": "",
+ "%{msvc_lib_path}": "",
+ "%{cxx_builtin_include_directory}": "",
}
+ vc_path = find_vc_path(repository_ctx)
+ if not vc_path:
+ auto_configure_fail(
+ "Visual C++ build tools not found on your machine." +
+ "Please check your installation following https://docs.bazel.build/versions/master/windows.html#using"
+ )
+ return {}
+
+ env = setup_vc_env_vars(repository_ctx, vc_path)
+ escaped_paths = escape_string(env["PATH"])
+ escaped_include_paths = escape_string(env["INCLUDE"])
+ escaped_lib_paths = escape_string(env["LIB"])
+ escaped_tmp_dir = escape_string(
+ get_env_var(repository_ctx, "TMP", "C:\\Windows\\Temp").replace(
+ "\\", "\\\\"),)
+
+ msvc_cl_path = "windows/msvc_wrapper_for_nvcc.bat"
+ msvc_ml_path = find_msvc_tool(repository_ctx, vc_path, "ml64.exe").replace(
+ "\\", "/")
+ msvc_link_path = find_msvc_tool(repository_ctx, vc_path, "link.exe").replace(
+ "\\", "/")
+ msvc_lib_path = find_msvc_tool(repository_ctx, vc_path, "lib.exe").replace(
+ "\\", "/")
+
+ # nvcc will generate some temporary source files under %{nvcc_tmp_dir}
+ # The generated files are guranteed to have unique name, so they can share the same tmp directory
+ escaped_cxx_include_directories = [
+ "cxx_builtin_include_directory: \"%s\"" %
+ _get_nvcc_tmp_dir_for_windows(repository_ctx)
+ ]
+ for path in escaped_include_paths.split(";"):
+ if path:
+ escaped_cxx_include_directories.append(
+ "cxx_builtin_include_directory: \"%s\"" % path)
+
+ return {
+ "%{msvc_env_tmp}":
+ escaped_tmp_dir,
+ "%{msvc_env_path}":
+ escaped_paths,
+ "%{msvc_env_include}":
+ escaped_include_paths,
+ "%{msvc_env_lib}":
+ escaped_lib_paths,
+ "%{msvc_cl_path}":
+ msvc_cl_path,
+ "%{msvc_ml_path}":
+ msvc_ml_path,
+ "%{msvc_link_path}":
+ msvc_link_path,
+ "%{msvc_lib_path}":
+ msvc_lib_path,
+ "%{cxx_builtin_include_directory}":
+ "\n".join(escaped_cxx_include_directories),
+ }
+
# TODO(dzc): Once these functions have been factored out of Bazel's
# cc_configure.bzl, load them from @bazel_tools instead.
# BEGIN cc_configure common functions.
def find_cc(repository_ctx):
- """Find the C++ compiler."""
- if _is_windows(repository_ctx):
- return _get_msvc_compiler(repository_ctx)
-
- if _use_cuda_clang(repository_ctx):
- target_cc_name = "clang"
- cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
- if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
- return "extra_tools/bin/clang"
- else:
- target_cc_name = "gcc"
- cc_path_envvar = _GCC_HOST_COMPILER_PATH
- cc_name = target_cc_name
-
- if cc_path_envvar in repository_ctx.os.environ:
- cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
- if cc_name_from_env:
- cc_name = cc_name_from_env
- if cc_name.startswith("/"):
- # Absolute path, maybe we should make this supported by our which function.
- return cc_name
- cc = repository_ctx.which(cc_name)
- if cc == None:
- fail(("Cannot find {}, either correct your path or set the {}" +
- " environment variable").format(target_cc_name, cc_path_envvar))
- return cc
+ """Find the C++ compiler."""
+ if _is_windows(repository_ctx):
+ return _get_msvc_compiler(repository_ctx)
+
+ if _use_cuda_clang(repository_ctx):
+ target_cc_name = "clang"
+ cc_path_envvar = _CLANG_CUDA_COMPILER_PATH
+ if _flag_enabled(repository_ctx, _TF_DOWNLOAD_CLANG):
+ return "extra_tools/bin/clang"
+ else:
+ target_cc_name = "gcc"
+ cc_path_envvar = _GCC_HOST_COMPILER_PATH
+ cc_name = target_cc_name
+
+ if cc_path_envvar in repository_ctx.os.environ:
+ cc_name_from_env = repository_ctx.os.environ[cc_path_envvar].strip()
+ if cc_name_from_env:
+ cc_name = cc_name_from_env
+ if cc_name.startswith("/"):
+ # Absolute path, maybe we should make this supported by our which function.
+ return cc_name
+ cc = repository_ctx.which(cc_name)
+ if cc == None:
+ fail(("Cannot find {}, either correct your path or set the {}" +
+ " environment variable").format(target_cc_name, cc_path_envvar))
+ return cc
+
_INC_DIR_MARKER_BEGIN = "#include <...>"
@@ -246,80 +269,82 @@ _OSX_FRAMEWORK_SUFFIX = " (framework directory)"
_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
def _cxx_inc_convert(path):
- """Convert path returned by cc -E xc++ in a complete path."""
- path = path.strip()
- if path.endswith(_OSX_FRAMEWORK_SUFFIX):
- path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
- return path
+ """Convert path returned by cc -E xc++ in a complete path."""
+ path = path.strip()
+ if path.endswith(_OSX_FRAMEWORK_SUFFIX):
+ path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
+ return path
+
def _normalize_include_path(repository_ctx, path):
- """Normalizes include paths before writing them to the crosstool.
+ """Normalizes include paths before writing them to the crosstool.
If path points inside the 'crosstool' folder of the repository, a relative
path is returned.
If path points outside the 'crosstool' folder, an absolute path is returned.
"""
- path = str(repository_ctx.path(path))
- crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
+ path = str(repository_ctx.path(path))
+ crosstool_folder = str(repository_ctx.path(".").get_child("crosstool"))
+
+ if path.startswith(crosstool_folder):
+ # We drop the path to "$REPO/crosstool" and a trailing path separator.
+ return path[len(crosstool_folder) + 1:]
+ return path
- if path.startswith(crosstool_folder):
- # We drop the path to "$REPO/crosstool" and a trailing path separator.
- return path[len(crosstool_folder) + 1:]
- return path
def _get_cxx_inc_directories_impl(repository_ctx, cc, lang_is_cpp):
- """Compute the list of default C or C++ include directories."""
- if lang_is_cpp:
- lang = "c++"
- else:
- lang = "c"
- result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
- index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
- if index1 == -1:
- return []
- index1 = result.stderr.find("\n", index1)
- if index1 == -1:
- return []
- index2 = result.stderr.rfind("\n ")
- if index2 == -1 or index2 < index1:
- return []
- index2 = result.stderr.find("\n", index2 + 1)
- if index2 == -1:
- inc_dirs = result.stderr[index1 + 1:]
- else:
- inc_dirs = result.stderr[index1 + 1:index2].strip()
+ """Compute the list of default C or C++ include directories."""
+ if lang_is_cpp:
+ lang = "c++"
+ else:
+ lang = "c"
+ result = repository_ctx.execute([cc, "-E", "-x" + lang, "-", "-v"])
+ index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
+ if index1 == -1:
+ return []
+ index1 = result.stderr.find("\n", index1)
+ if index1 == -1:
+ return []
+ index2 = result.stderr.rfind("\n ")
+ if index2 == -1 or index2 < index1:
+ return []
+ index2 = result.stderr.find("\n", index2 + 1)
+ if index2 == -1:
+ inc_dirs = result.stderr[index1 + 1:]
+ else:
+ inc_dirs = result.stderr[index1 + 1:index2].strip()
+
+ return [
+ _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
+ for p in inc_dirs.split("\n")
+ ]
- return [
- _normalize_include_path(repository_ctx, _cxx_inc_convert(p))
- for p in inc_dirs.split("\n")
- ]
def get_cxx_inc_directories(repository_ctx, cc):
- """Compute the list of default C and C++ include directories."""
-
- # For some reason `clang -xc` sometimes returns include paths that are
- # different from the ones from `clang -xc++`. (Symlink and a dir)
- # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
- includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
- includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
-
- includes_cpp_set = depset(includes_cpp)
- return includes_cpp + [
- inc
- for inc in includes_c
- if inc not in includes_cpp_set
- ]
+ """Compute the list of default C and C++ include directories."""
+
+ # For some reason `clang -xc` sometimes returns include paths that are
+ # different from the ones from `clang -xc++`. (Symlink and a dir)
+ # So we run the compiler with both `-xc` and `-xc++` and merge resulting lists
+ includes_cpp = _get_cxx_inc_directories_impl(repository_ctx, cc, True)
+ includes_c = _get_cxx_inc_directories_impl(repository_ctx, cc, False)
+
+ includes_cpp_set = depset(includes_cpp)
+ return includes_cpp + [
+ inc for inc in includes_c if inc not in includes_cpp_set
+ ]
+
def auto_configure_fail(msg):
- """Output failure message when cuda configuration fails."""
- red = "\033[0;31m"
- no_color = "\033[0m"
- fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
+ """Output failure message when cuda configuration fails."""
+ red = "\033[0;31m"
+ no_color = "\033[0m"
+ fail("\n%sCuda Configuration Error:%s %s\n" % (red, no_color, msg))
# END cc_configure common functions (see TODO above).
def _host_compiler_includes(repository_ctx, cc):
- """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
+ """Generates the cxx_builtin_include_directory entries for gcc inc dirs.
Args:
repository_ctx: The repository context.
@@ -330,14 +355,15 @@ def _host_compiler_includes(repository_ctx, cc):
host compiler include directories, which can be added to the CROSSTOOL
file.
"""
- inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
- inc_entries = []
- for inc_dir in inc_dirs:
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
- return "\n".join(inc_entries)
+ inc_dirs = get_cxx_inc_directories(repository_ctx, cc)
+ inc_entries = []
+ for inc_dir in inc_dirs:
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % inc_dir)
+ return "\n".join(inc_entries)
+
def _cuda_include_path(repository_ctx, cuda_config):
- """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
+ """Generates the cxx_builtin_include_directory entries for cuda inc dirs.
Args:
repository_ctx: The repository context.
@@ -348,39 +374,41 @@ def _cuda_include_path(repository_ctx, cuda_config):
host compiler include directories, which can be added to the CROSSTOOL
file.
"""
- nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
- (
- cuda_config.cuda_toolkit_path,
- ".exe" if cuda_config.cpu_value == "Windows" else "",
- ))
- result = repository_ctx.execute([
- nvcc_path,
- "-v",
- "/dev/null",
- "-o",
- "/dev/null",
- ])
- target_dir = ""
- for one_line in result.stderr.splitlines():
- if one_line.startswith("#$ _TARGET_DIR_="):
- target_dir = (cuda_config.cuda_toolkit_path + "/" +
- one_line.replace("#$ _TARGET_DIR_=", "") + "/include")
- inc_entries = []
- if target_dir != "":
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
- default_include = cuda_config.cuda_toolkit_path + "/include"
- inc_entries.append(" cxx_builtin_include_directory: \"%s\"" %
- default_include)
- return "\n".join(inc_entries)
+ nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
+ cuda_config.cuda_toolkit_path,
+ ".exe" if cuda_config.cpu_value == "Windows" else "",
+ ))
+ result = repository_ctx.execute([
+ nvcc_path,
+ "-v",
+ "/dev/null",
+ "-o",
+ "/dev/null",
+ ])
+ target_dir = ""
+ for one_line in result.stderr.splitlines():
+ if one_line.startswith("#$ _TARGET_DIR_="):
+ target_dir = (
+ cuda_config.cuda_toolkit_path + "/" + one_line.replace(
+ "#$ _TARGET_DIR_=", "") + "/include")
+ inc_entries = []
+ if target_dir != "":
+ inc_entries.append(" cxx_builtin_include_directory: \"%s\"" % target_dir)
+ default_include = cuda_config.cuda_toolkit_path + "/include"
+ inc_entries.append(
+ " cxx_builtin_include_directory: \"%s\"" % default_include)
+ return "\n".join(inc_entries)
+
def _enable_cuda(repository_ctx):
- if "TF_NEED_CUDA" in repository_ctx.os.environ:
- enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
- return enable_cuda == "1"
- return False
+ if "TF_NEED_CUDA" in repository_ctx.os.environ:
+ enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
+ return enable_cuda == "1"
+ return False
+
-def _cuda_toolkit_path(repository_ctx):
- """Finds the cuda toolkit directory.
+def cuda_toolkit_path(repository_ctx):
+ """Finds the cuda toolkit directory.
Args:
repository_ctx: The repository context.
@@ -388,27 +416,31 @@ def _cuda_toolkit_path(repository_ctx):
Returns:
A speculative real path of the cuda toolkit install directory.
"""
- cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
- if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
- cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
- if not repository_ctx.path(cuda_toolkit_path).exists:
- auto_configure_fail("Cannot find cuda toolkit path.")
- return str(repository_ctx.path(cuda_toolkit_path).realpath)
+ cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
+ if _CUDA_TOOLKIT_PATH in repository_ctx.os.environ:
+ cuda_toolkit_path = repository_ctx.os.environ[_CUDA_TOOLKIT_PATH].strip()
+ if not repository_ctx.path(cuda_toolkit_path).exists:
+ auto_configure_fail("Cannot find cuda toolkit path.")
+ return str(repository_ctx.path(cuda_toolkit_path).realpath)
+
def _cudnn_install_basedir(repository_ctx):
- """Finds the cudnn install directory."""
- cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
- if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
- cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
- if not repository_ctx.path(cudnn_install_path).exists:
- auto_configure_fail("Cannot find cudnn install path.")
- return cudnn_install_path
+ """Finds the cudnn install directory."""
+ cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
+ if _CUDNN_INSTALL_PATH in repository_ctx.os.environ:
+ cudnn_install_path = repository_ctx.os.environ[_CUDNN_INSTALL_PATH].strip()
+ if not repository_ctx.path(cudnn_install_path).exists:
+ auto_configure_fail("Cannot find cudnn install path.")
+ return cudnn_install_path
+
def matches_version(environ_version, detected_version):
- """Checks whether the user-specified version matches the detected version.
+ """Checks whether the user-specified version matches the detected version.
- This function performs a weak matching so that if the user specifies only the
- major or major and minor versions, the versions are still considered matching
+ This function performs a weak matching so that if the user specifies only
+ the
+ major or major and minor versions, the versions are still considered
+ matching
if the version parts match. To illustrate:
environ_version detected_version result
@@ -424,25 +456,25 @@ def matches_version(environ_version, detected_version):
variables.
detected_version: The version autodetected from the CUDA installation on
the system.
-
Returns: True if user-specified version matches detected version and False
otherwise.
- """
- environ_version_parts = environ_version.split(".")
- detected_version_parts = detected_version.split(".")
- if len(detected_version_parts) < len(environ_version_parts):
- return False
- for i, part in enumerate(detected_version_parts):
- if i >= len(environ_version_parts):
- break
- if part != environ_version_parts[i]:
- return False
- return True
+ """
+ environ_version_parts = environ_version.split(".")
+ detected_version_parts = detected_version.split(".")
+ if len(detected_version_parts) < len(environ_version_parts):
+ return False
+ for i, part in enumerate(detected_version_parts):
+ if i >= len(environ_version_parts):
+ break
+ if part != environ_version_parts[i]:
+ return False
+ return True
+
_NVCC_VERSION_PREFIX = "Cuda compilation tools, release "
def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
- """Detects the version of CUDA installed on the system.
+ """Detects the version of CUDA installed on the system.
Args:
repository_ctx: The repository context.
@@ -452,64 +484,61 @@ def _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value):
String containing the version of CUDA.
"""
- # Run nvcc --version and find the line containing the CUDA version.
- nvcc_path = repository_ctx.path("%s/bin/nvcc%s" %
- (
- cuda_toolkit_path,
- ".exe" if cpu_value == "Windows" else "",
- ))
- if not nvcc_path.exists:
- auto_configure_fail("Cannot find nvcc at %s" % str(nvcc_path))
- result = repository_ctx.execute([str(nvcc_path), "--version"])
- if result.stderr:
- auto_configure_fail("Error running nvcc --version: %s" % result.stderr)
- lines = result.stdout.splitlines()
- version_line = lines[len(lines) - 1]
- if version_line.find(_NVCC_VERSION_PREFIX) == -1:
- auto_configure_fail(
- "Could not parse CUDA version from nvcc --version. Got: %s" %
- result.stdout,
- )
-
- # Parse the CUDA version from the line containing the CUDA version.
- prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, "")
- parts = prefix_removed.split(",")
- if len(parts) != 2 or len(parts[0]) < 2:
- auto_configure_fail(
- "Could not parse CUDA version from nvcc --version. Got: %s" %
- result.stdout,
- )
- full_version = parts[1].strip()
- if full_version.startswith("V"):
- full_version = full_version[1:]
-
- # Check whether TF_CUDA_VERSION was set by the user and fail if it does not
- # match the detected version.
- environ_version = ""
- if _TF_CUDA_VERSION in repository_ctx.os.environ:
- environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
- if environ_version and not matches_version(environ_version, full_version):
- auto_configure_fail(
- ("CUDA version detected from nvcc (%s) does not match " +
- "TF_CUDA_VERSION (%s)") % (full_version, environ_version),
- )
-
- # We only use the version consisting of the major and minor version numbers.
- version_parts = full_version.split(".")
- if len(version_parts) < 2:
- auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
- if cpu_value == "Windows":
- version = "64_%s%s" % (version_parts[0], version_parts[1])
- else:
- version = "%s.%s" % (version_parts[0], version_parts[1])
- return version
+ # Run nvcc --version and find the line containing the CUDA version.
+ nvcc_path = repository_ctx.path("%s/bin/nvcc%s" % (
+ cuda_toolkit_path,
+ ".exe" if cpu_value == "Windows" else "",
+ ))
+ if not nvcc_path.exists:
+ auto_configure_fail("Cannot find nvcc at %s" % str(nvcc_path))
+ result = repository_ctx.execute([str(nvcc_path), "--version"])
+ if result.stderr:
+ auto_configure_fail("Error running nvcc --version: %s" % result.stderr)
+ lines = result.stdout.splitlines()
+ version_line = lines[len(lines) - 1]
+ if version_line.find(_NVCC_VERSION_PREFIX) == -1:
+ auto_configure_fail(
+ "Could not parse CUDA version from nvcc --version. Got: %s" %
+ result.stdout,)
+
+ # Parse the CUDA version from the line containing the CUDA version.
+ prefix_removed = version_line.replace(_NVCC_VERSION_PREFIX, "")
+ parts = prefix_removed.split(",")
+ if len(parts) != 2 or len(parts[0]) < 2:
+ auto_configure_fail(
+ "Could not parse CUDA version from nvcc --version. Got: %s" %
+ result.stdout,)
+ full_version = parts[1].strip()
+ if full_version.startswith("V"):
+ full_version = full_version[1:]
+
+ # Check whether TF_CUDA_VERSION was set by the user and fail if it does not
+ # match the detected version.
+ environ_version = ""
+ if _TF_CUDA_VERSION in repository_ctx.os.environ:
+ environ_version = repository_ctx.os.environ[_TF_CUDA_VERSION].strip()
+ if environ_version and not matches_version(environ_version, full_version):
+ auto_configure_fail(
+ ("CUDA version detected from nvcc (%s) does not match " +
+ "TF_CUDA_VERSION (%s)") % (full_version, environ_version),)
+
+ # We only use the version consisting of the major and minor version numbers.
+ version_parts = full_version.split(".")
+ if len(version_parts) < 2:
+ auto_configure_fail("CUDA version detected from nvcc (%s) is incomplete.")
+ if cpu_value == "Windows":
+ version = "64_%s%s" % (version_parts[0], version_parts[1])
+ else:
+ version = "%s.%s" % (version_parts[0], version_parts[1])
+ return version
+
_DEFINE_CUDNN_MAJOR = "#define CUDNN_MAJOR"
_DEFINE_CUDNN_MINOR = "#define CUDNN_MINOR"
_DEFINE_CUDNN_PATCHLEVEL = "#define CUDNN_PATCHLEVEL"
def find_cuda_define(repository_ctx, header_dir, header_file, define):
- """Returns the value of a #define in a header file.
+ """Returns the value of a #define in a header file.
Greps through a header file and returns the value of the specified #define.
If the #define is not found, then raise an error.
@@ -524,52 +553,52 @@ def find_cuda_define(repository_ctx, header_dir, header_file, define):
The value of the #define found in the header.
"""
- # Confirm location of the header and grep for the line defining the macro.
- h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
- if not h_path.exists:
- auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
- result = repository_ctx.execute(
- # Grep one more lines as some #defines are splitted into two lines.
- ["grep", "--color=never", "-A1", "-E", define, str(h_path)],
- )
- if result.stderr:
- auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
-
- # Parse the version from the line defining the macro.
- if result.stdout.find(define) == -1:
- auto_configure_fail("Cannot find line containing '%s' in %s" %
- (define, h_path))
-
- # Split results to lines
- lines = result.stdout.split("\n")
- num_lines = len(lines)
- for l in range(num_lines):
- line = lines[l]
- if define in line: # Find the line with define
- version = line
- if l != num_lines - 1 and line[-1] == "\\": # Add next line, if multiline
- version = version[:-1] + lines[l + 1]
- break
-
- # Remove any comments
- version = version.split("//")[0]
-
- # Remove define name
- version = version.replace(define, "").strip()
-
- # Remove the code after the version number.
- version_end = version.find(" ")
- if version_end != -1:
- if version_end == 0:
- auto_configure_fail(
- "Cannot extract the version from line containing '%s' in %s" %
- (define, str(h_path)),
- )
- version = version[:version_end].strip()
- return version
+ # Confirm location of the header and grep for the line defining the macro.
+ h_path = repository_ctx.path("%s/%s" % (header_dir, header_file))
+ if not h_path.exists:
+ auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
+ result = repository_ctx.execute(
+ # Grep one more lines as some #defines are splitted into two lines.
+ ["grep", "--color=never", "-A1", "-E", define,
+ str(h_path)],)
+ if result.stderr:
+ auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
+
+ # Parse the version from the line defining the macro.
+ if result.stdout.find(define) == -1:
+ auto_configure_fail(
+ "Cannot find line containing '%s' in %s" % (define, h_path))
+
+ # Split results to lines
+ lines = result.stdout.split("\n")
+ num_lines = len(lines)
+ for l in range(num_lines):
+ line = lines[l]
+ if define in line: # Find the line with define
+ version = line
+ if l != num_lines - 1 and line[-1] == "\\": # Add next line, if multiline
+ version = version[:-1] + lines[l + 1]
+ break
+
+ # Remove any comments
+ version = version.split("//")[0]
+
+ # Remove define name
+ version = version.replace(define, "").strip()
+
+ # Remove the code after the version number.
+ version_end = version.find(" ")
+ if version_end != -1:
+ if version_end == 0:
+ auto_configure_fail(
+ "Cannot extract the version from line containing '%s' in %s" %
+ (define, str(h_path)),)
+ version = version[:version_end].strip()
+ return version
+
def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
- """Detects the version of cuDNN installed on the system.
+ """Detects the version of cuDNN installed on the system.
Args:
repository_ctx: The repository context.
@@ -579,68 +608,68 @@ def _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value):
Returns:
A string containing the version of cuDNN.
"""
- cudnn_header_dir = _find_cudnn_header_dir(
- repository_ctx,
- cudnn_install_basedir,
- )
- major_version = find_cuda_define(
- repository_ctx,
- cudnn_header_dir,
- "cudnn.h",
- _DEFINE_CUDNN_MAJOR,
- )
- minor_version = find_cuda_define(
- repository_ctx,
- cudnn_header_dir,
- "cudnn.h",
- _DEFINE_CUDNN_MINOR,
- )
- patch_version = find_cuda_define(
- repository_ctx,
- cudnn_header_dir,
- "cudnn.h",
- _DEFINE_CUDNN_PATCHLEVEL,
- )
- full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
-
- # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
- # match the detected version.
- environ_version = ""
- if _TF_CUDNN_VERSION in repository_ctx.os.environ:
- environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
- if environ_version and not matches_version(environ_version, full_version):
- cudnn_h_path = repository_ctx.path("%s/include/cudnn.h" %
- cudnn_install_basedir)
- auto_configure_fail(
- ("cuDNN version detected from %s (%s) does not match " +
- "TF_CUDNN_VERSION (%s)") %
- (str(cudnn_h_path), full_version, environ_version),
- )
-
- # We only use the major version since we use the libcudnn libraries that are
- # only versioned with the major version (e.g. libcudnn.so.5).
- version = major_version
- if cpu_value == "Windows":
- version = "64_" + version
- return version
-
-def _compute_capabilities(repository_ctx):
- """Returns a list of strings representing cuda compute capabilities."""
- if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
- return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
- capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
- capabilities = capabilities_str.split(",")
- for capability in capabilities:
- # Workaround for Skylark's lack of support for regex. This check should
- # be equivalent to checking:
- # if re.match("[0-9]+.[0-9]+", capability) == None:
- parts = capability.split(".")
- if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
- auto_configure_fail("Invalid compute capability: %s" % capability)
- return capabilities
+ cudnn_header_dir = _find_cudnn_header_dir(
+ repository_ctx,
+ cudnn_install_basedir,
+ )
+ major_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_MAJOR,
+ )
+ minor_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_MINOR,
+ )
+ patch_version = find_cuda_define(
+ repository_ctx,
+ cudnn_header_dir,
+ "cudnn.h",
+ _DEFINE_CUDNN_PATCHLEVEL,
+ )
+ full_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+
+ # Check whether TF_CUDNN_VERSION was set by the user and fail if it does not
+ # match the detected version.
+ environ_version = ""
+ if _TF_CUDNN_VERSION in repository_ctx.os.environ:
+ environ_version = repository_ctx.os.environ[_TF_CUDNN_VERSION].strip()
+ if environ_version and not matches_version(environ_version, full_version):
+ cudnn_h_path = repository_ctx.path(
+ "%s/include/cudnn.h" % cudnn_install_basedir)
+ auto_configure_fail(("cuDNN version detected from %s (%s) does not match " +
+ "TF_CUDNN_VERSION (%s)") %
+ (str(cudnn_h_path), full_version, environ_version),)
+
+ # We only use the major version since we use the libcudnn libraries that are
+ # only versioned with the major version (e.g. libcudnn.so.5).
+ version = major_version
+ if cpu_value == "Windows":
+ version = "64_" + version
+ return version
+
+
+def compute_capabilities(repository_ctx):
+ """Returns a list of strings representing cuda compute capabilities."""
+ if _TF_CUDA_COMPUTE_CAPABILITIES not in repository_ctx.os.environ:
+ return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ capabilities_str = repository_ctx.os.environ[_TF_CUDA_COMPUTE_CAPABILITIES]
+ capabilities = capabilities_str.split(",")
+ for capability in capabilities:
+ # Workaround for Skylark's lack of support for regex. This check should
+ # be equivalent to checking:
+ # if re.match("[0-9]+.[0-9]+", capability) == None:
+ parts = capability.split(".")
+ if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
+ auto_configure_fail("Invalid compute capability: %s" % capability)
+ return capabilities
+
def get_cpu_value(repository_ctx):
- """Returns the name of the host operating system.
+ """Returns the name of the host operating system.
Args:
repository_ctx: The repository context.
@@ -648,20 +677,22 @@ def get_cpu_value(repository_ctx):
Returns:
A string containing the name of the host operating system.
"""
- os_name = repository_ctx.os.name.lower()
- if os_name.startswith("mac os"):
- return "Darwin"
- if os_name.find("windows") != -1:
- return "Windows"
- result = repository_ctx.execute(["uname", "-s"])
- return result.stdout.strip()
+ os_name = repository_ctx.os.name.lower()
+ if os_name.startswith("mac os"):
+ return "Darwin"
+ if os_name.find("windows") != -1:
+ return "Windows"
+ result = repository_ctx.execute(["uname", "-s"])
+ return result.stdout.strip()
+
def _is_windows(repository_ctx):
- """Returns true if the host operating system is windows."""
- return get_cpu_value(repository_ctx) == "Windows"
+ """Returns true if the host operating system is windows."""
+ return get_cpu_value(repository_ctx) == "Windows"
+
def _lib_name(lib, cpu_value, version = "", static = False):
- """Constructs the platform-specific name of a library.
+ """Constructs the platform-specific name of a library.
Args:
lib: The name of the library, such as "cudart"
@@ -672,23 +703,24 @@ def _lib_name(lib, cpu_value, version = "", static = False):
Returns:
The platform-specific name of the library.
"""
- if cpu_value in ("Linux", "FreeBSD"):
- if static:
- return "lib%s.a" % lib
- else:
- if version:
- version = ".%s" % version
- return "lib%s.so%s" % (lib, version)
- elif cpu_value == "Windows":
- return "%s.lib" % lib
- elif cpu_value == "Darwin":
- if static:
- return "lib%s.a" % lib
- elif version:
- version = ".%s" % version
- return "lib%s%s.dylib" % (lib, version)
+ if cpu_value in ("Linux", "FreeBSD"):
+ if static:
+ return "lib%s.a" % lib
else:
- auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
+ if version:
+ version = ".%s" % version
+ return "lib%s.so%s" % (lib, version)
+ elif cpu_value == "Windows":
+ return "%s.lib" % lib
+ elif cpu_value == "Darwin":
+ if static:
+ return "lib%s.a" % lib
+ elif version:
+ version = ".%s" % version
+ return "lib%s%s.dylib" % (lib, version)
+ else:
+ auto_configure_fail("Invalid cpu_value: %s" % cpu_value)
+
def _find_cuda_lib(
lib,
@@ -697,7 +729,7 @@ def _find_cuda_lib(
basedir,
version = "",
static = False):
- """Finds the given CUDA or cuDNN library on the system.
+ """Finds the given CUDA or cuDNN library on the system.
Args:
lib: The name of the library, such as "cudart"
@@ -712,15 +744,16 @@ def _find_cuda_lib(
file_name: The basename of the library found on the system.
path: The full path to the library.
"""
- file_name = _lib_name(lib, cpu_value, version, static)
- for relative_path in CUDA_LIB_PATHS:
- path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name))
- if path.exists:
- return struct(file_name = file_name, path = str(path.realpath))
- auto_configure_fail("Cannot find cuda library %s" % file_name)
+ file_name = _lib_name(lib, cpu_value, version, static)
+ for relative_path in CUDA_LIB_PATHS:
+ path = repository_ctx.path("%s/%s%s" % (basedir, relative_path, file_name))
+ if path.exists:
+ return struct(file_name=file_name, path=str(path.realpath))
+ auto_configure_fail("Cannot find cuda library %s" % file_name)
+
def _find_cupti_header_dir(repository_ctx, cuda_config):
- """Returns the path to the directory containing cupti.h
+ """Returns the path to the directory containing cupti.h
On most systems, the cupti library is not installed in the same directory as
the other CUDA libraries but rather in a special extras/CUPTI directory.
@@ -732,14 +765,17 @@ def _find_cupti_header_dir(repository_ctx, cuda_config):
Returns:
The path of the directory containing the cupti header.
"""
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUPTI_HEADER_PATHS:
- if repository_ctx.path("%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cupti.h under %s" % ", ".join([cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUPTI_HEADER_PATHS:
+ if repository_ctx.path(
+ "%s/%scupti.h" % (cuda_toolkit_path, relative_path)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find cupti.h under %s" % ", ".join(
+ [cuda_toolkit_path + "/" + s for s in CUPTI_HEADER_PATHS]))
+
def _find_cupti_lib(repository_ctx, cuda_config):
- """Finds the cupti library on the system.
+ """Finds the cupti library on the system.
On most systems, the cupti library is not installed in the same directory as
the other CUDA libraries but rather in a special extras/CUPTI directory.
@@ -753,23 +789,23 @@ def _find_cupti_lib(repository_ctx, cuda_config):
file_name: The basename of the library found on the system.
path: The full path to the library.
"""
- file_name = _lib_name(
- "cupti",
- cuda_config.cpu_value,
- cuda_config.cuda_version,
- )
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUPTI_LIB_PATHS:
- path = repository_ctx.path(
- "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name),
- )
- if path.exists:
- return struct(file_name = file_name, path = str(path.realpath))
+ file_name = _lib_name(
+ "cupti",
+ cuda_config.cpu_value,
+ cuda_config.cuda_version,
+ )
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUPTI_LIB_PATHS:
+ path = repository_ctx.path(
+ "%s/%s%s" % (cuda_toolkit_path, relative_path, file_name),)
+ if path.exists:
+ return struct(file_name=file_name, path=str(path.realpath))
+
+ auto_configure_fail("Cannot find cupti library %s" % file_name)
- auto_configure_fail("Cannot find cupti library %s" % file_name)
def _find_libs(repository_ctx, cuda_config):
- """Returns the CUDA and cuDNN libraries on the system.
+ """Returns the CUDA and cuDNN libraries on the system.
Args:
repository_ctx: The repository context.
@@ -778,64 +814,75 @@ def _find_libs(repository_ctx, cuda_config):
Returns:
Map of library names to structs of filename and path.
"""
- cpu_value = cuda_config.cpu_value
- return {
- "cuda": _find_cuda_lib("cuda", repository_ctx, cpu_value, cuda_config.cuda_toolkit_path),
- "cudart": _find_cuda_lib(
- "cudart",
- repository_ctx,
- cpu_value,
- cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version,
- ),
- "cudart_static": _find_cuda_lib(
- "cudart_static",
- repository_ctx,
- cpu_value,
- cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version,
- static = True,
- ),
- "cublas": _find_cuda_lib(
- "cublas",
- repository_ctx,
- cpu_value,
- cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version,
- ),
- "cusolver": _find_cuda_lib(
- "cusolver",
- repository_ctx,
- cpu_value,
- cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version,
- ),
- "curand": _find_cuda_lib(
- "curand",
- repository_ctx,
- cpu_value,
- cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version,
- ),
- "cufft": _find_cuda_lib(
- "cufft",
- repository_ctx,
- cpu_value,
- cuda_config.cuda_toolkit_path,
- cuda_config.cuda_version,
- ),
- "cudnn": _find_cuda_lib(
- "cudnn",
- repository_ctx,
- cpu_value,
- cuda_config.cudnn_install_basedir,
- cuda_config.cudnn_version,
- ),
- "cupti": _find_cupti_lib(repository_ctx, cuda_config),
- }
+ cpu_value = cuda_config.cpu_value
+ return {
+ "cuda":
+ _find_cuda_lib("cuda", repository_ctx, cpu_value,
+ cuda_config.cuda_toolkit_path),
+ "cudart":
+ _find_cuda_lib(
+ "cudart",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cudart_static":
+ _find_cuda_lib(
+ "cudart_static",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ static=True,
+ ),
+ "cublas":
+ _find_cuda_lib(
+ "cublas",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cusolver":
+ _find_cuda_lib(
+ "cusolver",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "curand":
+ _find_cuda_lib(
+ "curand",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cufft":
+ _find_cuda_lib(
+ "cufft",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cuda_toolkit_path,
+ cuda_config.cuda_version,
+ ),
+ "cudnn":
+ _find_cuda_lib(
+ "cudnn",
+ repository_ctx,
+ cpu_value,
+ cuda_config.cudnn_install_basedir,
+ cuda_config.cudnn_version,
+ ),
+ "cupti":
+ _find_cupti_lib(repository_ctx, cuda_config),
+ }
+
def _find_cuda_include_path(repository_ctx, cuda_config):
- """Returns the path to the directory containing cuda.h
+ """Returns the path to the directory containing cuda.h
Args:
repository_ctx: The repository context.
@@ -844,14 +891,16 @@ def _find_cuda_include_path(repository_ctx, cuda_config):
Returns:
The path of the directory containing the CUDA headers.
"""
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for relative_path in CUDA_INCLUDE_PATHS:
- if repository_ctx.path("%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for relative_path in CUDA_INCLUDE_PATHS:
+ if repository_ctx.path(
+ "%s/%scuda.h" % (cuda_toolkit_path, relative_path)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail("Cannot find cuda.h under %s" % cuda_toolkit_path)
+
def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
- """Returns the path to the directory containing cudnn.h
+ """Returns the path to the directory containing cudnn.h
Args:
repository_ctx: The repository context.
@@ -861,15 +910,17 @@ def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
Returns:
The path of the directory containing the cudnn header.
"""
- for relative_path in CUDA_INCLUDE_PATHS:
- if repository_ctx.path("%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists:
- return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
- if repository_ctx.path("/usr/include/cudnn.h").exists:
- return "/usr/include"
- auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
+ for relative_path in CUDA_INCLUDE_PATHS:
+ if repository_ctx.path(
+ "%s/%scudnn.h" % (cudnn_install_basedir, relative_path)).exists:
+ return ("%s/%s" % (cudnn_install_basedir, relative_path))[:-1]
+ if repository_ctx.path("/usr/include/cudnn.h").exists:
+ return "/usr/include"
+ auto_configure_fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
+
def _find_nvvm_libdevice_dir(repository_ctx, cuda_config):
- """Returns the path to the directory containing libdevice in bitcode format.
+ """Returns the path to the directory containing libdevice in bitcode format.
Args:
repository_ctx: The repository context.
@@ -878,19 +929,23 @@ def _find_nvvm_libdevice_dir(repository_ctx, cuda_config):
Returns:
The path of the directory containing the CUDA headers.
"""
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- for libdevice_file in NVVM_LIBDEVICE_FILES:
- for relative_path in NVVM_LIBDEVICE_PATHS:
- if repository_ctx.path("%s/%s%s" % (cuda_toolkit_path, relative_path, libdevice_file)).exists:
- return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
- auto_configure_fail("Cannot find libdevice*.bc files under %s" % cuda_toolkit_path)
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ for libdevice_file in NVVM_LIBDEVICE_FILES:
+ for relative_path in NVVM_LIBDEVICE_PATHS:
+ if repository_ctx.path("%s/%s%s" % (cuda_toolkit_path, relative_path,
+ libdevice_file)).exists:
+ return ("%s/%s" % (cuda_toolkit_path, relative_path))[:-1]
+ auto_configure_fail(
+ "Cannot find libdevice*.bc files under %s" % cuda_toolkit_path)
+
def _cudart_static_linkopt(cpu_value):
- """Returns additional platform-specific linkopts for cudart."""
- return "" if cpu_value == "Darwin" else "\"-lrt\","
+ """Returns additional platform-specific linkopts for cudart."""
+ return "" if cpu_value == "Darwin" else "\"-lrt\","
+
def _get_cuda_config(repository_ctx):
- """Detects and returns information about the CUDA installation on the system.
+ """Detects and returns information about the CUDA installation on the system.
Args:
repository_ctx: The repository context.
@@ -904,35 +959,39 @@ def _get_cuda_config(repository_ctx):
compute_capabilities: A list of the system's CUDA compute capabilities.
cpu_value: The name of the host operating system.
"""
- cpu_value = get_cpu_value(repository_ctx)
- cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
- cuda_version = _cuda_version(repository_ctx, cuda_toolkit_path, cpu_value)
- cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
- cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir, cpu_value)
- return struct(
- cuda_toolkit_path = cuda_toolkit_path,
- cudnn_install_basedir = cudnn_install_basedir,
- cuda_version = cuda_version,
- cudnn_version = cudnn_version,
- compute_capabilities = _compute_capabilities(repository_ctx),
- cpu_value = cpu_value,
- )
+ cpu_value = get_cpu_value(repository_ctx)
+ toolkit_path = cuda_toolkit_path(repository_ctx)
+ cuda_version = _cuda_version(repository_ctx, toolkit_path, cpu_value)
+ cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
+ cudnn_version = _cudnn_version(repository_ctx, cudnn_install_basedir,
+ cpu_value)
+ return struct(
+ cuda_toolkit_path=toolkit_path,
+ cudnn_install_basedir=cudnn_install_basedir,
+ cuda_version=cuda_version,
+ cudnn_version=cudnn_version,
+ compute_capabilities=compute_capabilities(repository_ctx),
+ cpu_value=cpu_value,
+ )
+
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
- if not out:
- out = tpl.replace(":", "/")
- repository_ctx.template(
- out,
- Label("//third_party/gpus/%s.tpl" % tpl),
- substitutions,
- )
+ if not out:
+ out = tpl.replace(":", "/")
+ repository_ctx.template(
+ out,
+ Label("//third_party/gpus/%s.tpl" % tpl),
+ substitutions,
+ )
+
def _file(repository_ctx, label):
- repository_ctx.template(
- label.replace(":", "/"),
- Label("//third_party/gpus/%s.tpl" % label),
- {},
- )
+ repository_ctx.template(
+ label.replace(":", "/"),
+ Label("//third_party/gpus/%s.tpl" % label),
+ {},
+ )
+
_DUMMY_CROSSTOOL_BZL_FILE = """
def error_gpu_disabled():
@@ -960,81 +1019,99 @@ error_gpu_disabled()
"""
def _create_dummy_repository(repository_ctx):
- cpu_value = get_cpu_value(repository_ctx)
+ cpu_value = get_cpu_value(repository_ctx)
+
+ # Set up BUILD file for cuda/.
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}": "False",
+ "%{cuda_extra_copts}": "[]",
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:BUILD",
+ {
+ "%{cuda_driver_lib}":
+ _lib_name("cuda", cpu_value),
+ "%{cudart_static_lib}":
+ _lib_name(
+ "cudart_static",
+ cpu_value,
+ static=True,
+ ),
+ "%{cudart_static_linkopt}":
+ _cudart_static_linkopt(cpu_value),
+ "%{cudart_lib}":
+ _lib_name("cudart", cpu_value),
+ "%{cublas_lib}":
+ _lib_name("cublas", cpu_value),
+ "%{cusolver_lib}":
+ _lib_name("cusolver", cpu_value),
+ "%{cudnn_lib}":
+ _lib_name("cudnn", cpu_value),
+ "%{cufft_lib}":
+ _lib_name("cufft", cpu_value),
+ "%{curand_lib}":
+ _lib_name("curand", cpu_value),
+ "%{cupti_lib}":
+ _lib_name("cupti", cpu_value),
+ "%{cuda_include_genrules}":
+ "",
+ "%{cuda_headers}":
+ "",
+ },
+ )
- # Set up BUILD file for cuda/.
- _tpl(
- repository_ctx,
- "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "False",
- "%{cuda_extra_copts}": "[]",
- },
- )
- _tpl(
- repository_ctx,
- "cuda:BUILD",
- {
- "%{cuda_driver_lib}": _lib_name("cuda", cpu_value),
- "%{cudart_static_lib}": _lib_name(
- "cudart_static",
- cpu_value,
- static = True,
- ),
- "%{cudart_static_linkopt}": _cudart_static_linkopt(cpu_value),
- "%{cudart_lib}": _lib_name("cudart", cpu_value),
- "%{cublas_lib}": _lib_name("cublas", cpu_value),
- "%{cusolver_lib}": _lib_name("cusolver", cpu_value),
- "%{cudnn_lib}": _lib_name("cudnn", cpu_value),
- "%{cufft_lib}": _lib_name("cufft", cpu_value),
- "%{curand_lib}": _lib_name("curand", cpu_value),
- "%{cupti_lib}": _lib_name("cupti", cpu_value),
- "%{cuda_include_genrules}": "",
- "%{cuda_headers}": "",
- },
- )
+ # Create dummy files for the CUDA toolkit since they are still required by
+ # tensorflow/core/platform/default/build_config:cuda.
+ repository_ctx.file("cuda/cuda/include/cuda.h", "")
+ repository_ctx.file("cuda/cuda/include/cublas.h", "")
+ repository_ctx.file("cuda/cuda/include/cudnn.h", "")
+ repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "")
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value))
+ repository_ctx.file(
+ "cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value))
+ repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value))
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}":
+ _DEFAULT_CUDA_VERSION,
+ "%{cudnn_version}":
+ _DEFAULT_CUDNN_VERSION,
+ "%{cuda_compute_capabilities}":
+ ",".join([
+ "CudaVersion(\"%s\")" % c
+ for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ ]),
+ "%{cuda_toolkit_path}":
+ _DEFAULT_CUDA_TOOLKIT_PATH,
+ },
+ "cuda/cuda/cuda_config.h",
+ )
- # Create dummy files for the CUDA toolkit since they are still required by
- # tensorflow/core/platform/default/build_config:cuda.
- repository_ctx.file("cuda/cuda/include/cuda.h", "")
- repository_ctx.file("cuda/cuda/include/cublas.h", "")
- repository_ctx.file("cuda/cuda/include/cudnn.h", "")
- repository_ctx.file("cuda/cuda/extras/CUPTI/include/cupti.h", "")
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cuda", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudart_static", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cublas", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cusolver", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cudnn", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("curand", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cufft", cpu_value))
- repository_ctx.file("cuda/cuda/lib/%s" % _lib_name("cupti", cpu_value))
-
- # Set up cuda_config.h, which is used by
- # tensorflow/stream_executor/dso_loader.cc.
- _tpl(
- repository_ctx,
- "cuda:cuda_config.h",
- {
- "%{cuda_version}": _DEFAULT_CUDA_VERSION,
- "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
- "%{cuda_compute_capabilities}": ",".join([
- "CudaVersion(\"%s\")" % c
- for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES
- ]),
- "%{cuda_toolkit_path}": _DEFAULT_CUDA_TOOLKIT_PATH,
- },
- "cuda/cuda/cuda_config.h",
- )
+ # If cuda_configure is not configured to build with GPU support, and the user
+ # attempts to build with --config=cuda, add a dummy build rule to intercept
+ # this and fail with an actionable error message.
+ repository_ctx.file(
+ "crosstool/error_gpu_disabled.bzl",
+ _DUMMY_CROSSTOOL_BZL_FILE,
+ )
+ repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
- # If cuda_configure is not configured to build with GPU support, and the user
- # attempts to build with --config=cuda, add a dummy build rule to intercept
- # this and fail with an actionable error message.
- repository_ctx.file(
- "crosstool/error_gpu_disabled.bzl",
- _DUMMY_CROSSTOOL_BZL_FILE,
- )
- repository_ctx.file("crosstool/BUILD", _DUMMY_CROSSTOOL_BUILD_FILE)
def _execute(
repository_ctx,
@@ -1042,35 +1119,35 @@ def _execute(
error_msg = None,
error_details = None,
empty_stdout_fine = False):
- """Executes an arbitrary shell command.
+ """Executes an arbitrary shell command.
Args:
repository_ctx: the repository_ctx object
cmdline: list of strings, the command to execute
error_msg: string, a summary of the error if the command fails
error_details: string, details about the error or steps to fix it
- empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
- it's an error
- Return:
- the result of repository_ctx.execute(cmdline)
- """
- result = repository_ctx.execute(cmdline)
- if result.stderr or not (empty_stdout_fine or result.stdout):
- auto_configure_fail(
- "\n".join([
- error_msg.strip() if error_msg else "Repository command failed",
- result.stderr.strip(),
- error_details if error_details else "",
- ]),
- )
- return result
+ empty_stdout_fine: bool, if True, an empty stdout result is fine,
+ otherwise it's an error
+ Return: the result of repository_ctx.execute(cmdline)
+ """
+ result = repository_ctx.execute(cmdline)
+ if result.stderr or not (empty_stdout_fine or result.stdout):
+ auto_configure_fail(
+ "\n".join([
+ error_msg.strip() if error_msg else "Repository command failed",
+ result.stderr.strip(),
+ error_details if error_details else "",
+ ]),)
+ return result
+
def _norm_path(path):
- """Returns a path with '/' and remove the trailing slash."""
- path = path.replace("\\", "/")
- if path[-1] == "/":
- path = path[:-1]
- return path
+ """Returns a path with '/' and remove the trailing slash."""
+ path = path.replace("\\", "/")
+ if path[-1] == "/":
+ path = path[:-1]
+ return path
+
def symlink_genrule_for_dir(
repository_ctx,
@@ -1079,167 +1156,174 @@ def symlink_genrule_for_dir(
genrule_name,
src_files = [],
dest_files = []):
- """Returns a genrule to symlink(or copy if on Windows) a set of files.
+ """Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise
we assume files are in src_files and dest_files
"""
- if src_dir != None:
- src_dir = _norm_path(src_dir)
- dest_dir = _norm_path(dest_dir)
- files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
-
- # Create a list with the src_dir stripped to use for outputs.
- dest_files = files.replace(src_dir, "").splitlines()
- src_files = files.splitlines()
- command = []
- if not _is_windows(repository_ctx):
- # We clear folders that might have been generated previously to avoid
- # undesired inclusions
- command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi')
- command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
- command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
- command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi')
- outs = []
- for i in range(len(dest_files)):
- if dest_files[i] != "":
- # If we have only one file to link we do not want to use the dest_dir, as
- # $(@D) will include the full path to the file.
- dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
-
- # Copy the headers to create a sandboxable setup.
- cmd = "cp -f"
- command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
- outs.append(' "' + dest_dir + dest_files[i] + '",')
- genrule = _genrule(
- src_dir,
- genrule_name,
- " && ".join(command),
- "\n".join(outs),
- )
- return genrule
+ if src_dir != None:
+ src_dir = _norm_path(src_dir)
+ dest_dir = _norm_path(dest_dir)
+ files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
+
+ # Create a list with the src_dir stripped to use for outputs.
+ dest_files = files.replace(src_dir, "").splitlines()
+ src_files = files.splitlines()
+ command = []
+ if not _is_windows(repository_ctx):
+ # We clear folders that might have been generated previously to avoid
+ # undesired inclusions
+ command.append('if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi')
+ command.append('if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi')
+ command.append('if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi')
+ command.append('if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi')
+ outs = []
+ for i in range(len(dest_files)):
+ if dest_files[i] != "":
+ # If we have only one file to link we do not want to use the dest_dir, as
+ # $(@D) will include the full path to the file.
+ dest = "$(@D)/" + dest_dir + dest_files[i] if len(
+ dest_files) != 1 else "$(@D)/" + dest_files[i]
+
+ # Copy the headers to create a sandboxable setup.
+ cmd = "cp -f"
+ command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
+ outs.append(' "' + dest_dir + dest_files[i] + '",')
+ genrule = _genrule(
+ src_dir,
+ genrule_name,
+ " && ".join(command),
+ "\n".join(outs),
+ )
+ return genrule
+
def _genrule(src_dir, genrule_name, command, outs):
- """Returns a string with a genrule.
+ """Returns a string with a genrule.
Genrule executes the given command and produces the given outputs.
"""
- return (
- "genrule(\n" +
- ' name = "' +
- genrule_name + '",\n' +
- " outs = [\n" +
- outs +
- "\n ],\n" +
- ' cmd = """\n' +
- command +
- '\n """,\n' +
- ")\n"
- )
+ return (
+ "genrule(\n" + ' name = "' + genrule_name + '",\n' + " outs = [\n" +
+ outs + "\n ],\n" + ' cmd = """\n' + command + '\n """,\n' + ")\n")
+
def _read_dir(repository_ctx, src_dir):
- """Returns a string with all files in a directory.
+ """Returns a string with all files in a directory.
Finds all files inside a directory, traversing subfolders and following
symlinks. The returned string contains the full path of all files
separated by line breaks.
"""
- if _is_windows(repository_ctx):
- src_dir = src_dir.replace("/", "\\")
- find_result = _execute(
- repository_ctx,
- ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
- empty_stdout_fine = True,
- )
+ if _is_windows(repository_ctx):
+ src_dir = src_dir.replace("/", "\\")
+ find_result = _execute(
+ repository_ctx,
+ ["cmd.exe", "/c", "dir", src_dir, "/b", "/s", "/a-d"],
+ empty_stdout_fine=True,
+ )
+
+ # src_files will be used in genrule.outs where the paths must
+ # use forward slashes.
+ result = find_result.stdout.replace("\\", "/")
+ else:
+ find_result = _execute(
+ repository_ctx,
+ ["find", src_dir, "-follow", "-type", "f"],
+ empty_stdout_fine=True,
+ )
+ result = find_result.stdout
+ return result
- # src_files will be used in genrule.outs where the paths must
- # use forward slashes.
- result = find_result.stdout.replace("\\", "/")
- else:
- find_result = _execute(
- repository_ctx,
- ["find", src_dir, "-follow", "-type", "f"],
- empty_stdout_fine = True,
- )
- result = find_result.stdout
- return result
def _flag_enabled(repository_ctx, flag_name):
- if flag_name in repository_ctx.os.environ:
- value = repository_ctx.os.environ[flag_name].strip()
- return value == "1"
- return False
+ if flag_name in repository_ctx.os.environ:
+ value = repository_ctx.os.environ[flag_name].strip()
+ return value == "1"
+ return False
+
def _use_cuda_clang(repository_ctx):
- return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
+ return _flag_enabled(repository_ctx, "TF_CUDA_CLANG")
+
def _compute_cuda_extra_copts(repository_ctx, compute_capabilities):
- if _use_cuda_clang(repository_ctx):
- capability_flags = ["--cuda-gpu-arch=sm_" +
- cap.replace(".", "") for cap in compute_capabilities]
- else:
- # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
- capability_flags = []
- return str(capability_flags)
+ if _use_cuda_clang(repository_ctx):
+ capability_flags = [
+ "--cuda-gpu-arch=sm_" + cap.replace(".", "")
+ for cap in compute_capabilities
+ ]
+ else:
+ # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc
+ # TODO(csigg): Make this consistent with cuda clang and pass to crosstool.
+ capability_flags = []
+ return str(capability_flags)
+
def _create_local_cuda_repository(repository_ctx):
- """Creates the repository containing files set up to build with CUDA."""
- cuda_config = _get_cuda_config(repository_ctx)
+ """Creates the repository containing files set up to build with CUDA."""
+ cuda_config = _get_cuda_config(repository_ctx)
- cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config)
- cudnn_header_dir = _find_cudnn_header_dir(
- repository_ctx,
- cuda_config.cudnn_install_basedir,
- )
- cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config)
- nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config)
-
- # Set up symbolic links for the cuda toolkit by creating genrules to do
- # symlinking. We create one genrule for each directory we want to track under
- # cuda_toolkit_path
- cuda_toolkit_path = cuda_config.cuda_toolkit_path
- genrules = [symlink_genrule_for_dir(
- repository_ctx,
- cuda_include_path,
- "cuda/include",
- "cuda-include",
- )]
- genrules.append(symlink_genrule_for_dir(
- repository_ctx,
- nvvm_libdevice_dir,
- "cuda/nvvm/libdevice",
- "cuda-nvvm",
- ))
- genrules.append(symlink_genrule_for_dir(
- repository_ctx,
- cupti_header_dir,
- "cuda/extras/CUPTI/include",
- "cuda-extras",
- ))
-
- cuda_libs = _find_libs(repository_ctx, cuda_config)
- cuda_lib_src = []
- cuda_lib_dest = []
- for lib in cuda_libs.values():
- cuda_lib_src.append(lib.path)
- cuda_lib_dest.append("cuda/lib/" + lib.file_name)
- genrules.append(symlink_genrule_for_dir(
- repository_ctx,
- None,
- "",
- "cuda-lib",
- cuda_lib_src,
- cuda_lib_dest,
- ))
-
- # Set up the symbolic links for cudnn if cndnn was not installed to
- # CUDA_TOOLKIT_PATH.
- included_files = _read_dir(repository_ctx, cuda_include_path).replace(
- cuda_include_path,
- "",
- ).splitlines()
- if "/cudnn.h" not in included_files:
- genrules.append(symlink_genrule_for_dir(
+ cuda_include_path = _find_cuda_include_path(repository_ctx, cuda_config)
+ cudnn_header_dir = _find_cudnn_header_dir(
+ repository_ctx,
+ cuda_config.cudnn_install_basedir,
+ )
+ cupti_header_dir = _find_cupti_header_dir(repository_ctx, cuda_config)
+ nvvm_libdevice_dir = _find_nvvm_libdevice_dir(repository_ctx, cuda_config)
+
+ # Set up symbolic links for the cuda toolkit by creating genrules to do
+ # symlinking. We create one genrule for each directory we want to track under
+ # cuda_toolkit_path
+ cuda_toolkit_path = cuda_config.cuda_toolkit_path
+ genrules = [
+ symlink_genrule_for_dir(
+ repository_ctx,
+ cuda_include_path,
+ "cuda/include",
+ "cuda-include",
+ )
+ ]
+ genrules.append(
+ symlink_genrule_for_dir(
+ repository_ctx,
+ nvvm_libdevice_dir,
+ "cuda/nvvm/libdevice",
+ "cuda-nvvm",
+ ))
+ genrules.append(
+ symlink_genrule_for_dir(
+ repository_ctx,
+ cupti_header_dir,
+ "cuda/extras/CUPTI/include",
+ "cuda-extras",
+ ))
+
+ cuda_libs = _find_libs(repository_ctx, cuda_config)
+ cuda_lib_src = []
+ cuda_lib_dest = []
+ for lib in cuda_libs.values():
+ cuda_lib_src.append(lib.path)
+ cuda_lib_dest.append("cuda/lib/" + lib.file_name)
+ genrules.append(
+ symlink_genrule_for_dir(
+ repository_ctx,
+ None,
+ "",
+ "cuda-lib",
+ cuda_lib_src,
+ cuda_lib_dest,
+ ))
+
+ # Set up the symbolic links for cudnn if cndnn was not installed to
+ # CUDA_TOOLKIT_PATH.
+ included_files = _read_dir(repository_ctx, cuda_include_path).replace(
+ cuda_include_path,
+ "",
+ ).splitlines()
+ if "/cudnn.h" not in included_files:
+ genrules.append(
+ symlink_genrule_for_dir(
repository_ctx,
None,
"cuda/include/",
@@ -1247,204 +1331,229 @@ def _create_local_cuda_repository(repository_ctx):
[cudnn_header_dir + "/cudnn.h"],
["cudnn.h"],
))
- else:
- genrules.append(
- "filegroup(\n" +
- ' name = "cudnn-include",\n' +
- " srcs = [],\n" +
- ")\n",
- )
-
- # Set up BUILD file for cuda/
- _tpl(
- repository_ctx,
- "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "True",
- "%{cuda_extra_copts}": _compute_cuda_extra_copts(
- repository_ctx,
- cuda_config.compute_capabilities,
- ),
- },
- )
- _tpl(
- repository_ctx,
- "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD",
- {
- "%{cuda_driver_lib}": cuda_libs["cuda"].file_name,
- "%{cudart_static_lib}": cuda_libs["cudart_static"].file_name,
- "%{cudart_static_linkopt}": _cudart_static_linkopt(
- cuda_config.cpu_value,
- ),
- "%{cudart_lib}": cuda_libs["cudart"].file_name,
- "%{cublas_lib}": cuda_libs["cublas"].file_name,
- "%{cusolver_lib}": cuda_libs["cusolver"].file_name,
- "%{cudnn_lib}": cuda_libs["cudnn"].file_name,
- "%{cufft_lib}": cuda_libs["cufft"].file_name,
- "%{curand_lib}": cuda_libs["curand"].file_name,
- "%{cupti_lib}": cuda_libs["cupti"].file_name,
- "%{cuda_include_genrules}": "\n".join(genrules),
- "%{cuda_headers}": ('":cuda-include",\n' +
- ' ":cudnn-include",'),
- },
- "cuda/BUILD",
- )
-
- is_cuda_clang = _use_cuda_clang(repository_ctx)
+ else:
+ genrules.append(
+ "filegroup(\n" + ' name = "cudnn-include",\n' + " srcs = [],\n" +
+ ")\n",)
+
+ # Set up BUILD file for cuda/
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}":
+ "True",
+ "%{cuda_extra_copts}":
+ _compute_cuda_extra_copts(
+ repository_ctx,
+ cuda_config.compute_capabilities,
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:BUILD.windows" if _is_windows(repository_ctx) else "cuda:BUILD",
+ {
+ "%{cuda_driver_lib}":
+ cuda_libs["cuda"].file_name,
+ "%{cudart_static_lib}":
+ cuda_libs["cudart_static"].file_name,
+ "%{cudart_static_linkopt}":
+ _cudart_static_linkopt(cuda_config.cpu_value,),
+ "%{cudart_lib}":
+ cuda_libs["cudart"].file_name,
+ "%{cublas_lib}":
+ cuda_libs["cublas"].file_name,
+ "%{cusolver_lib}":
+ cuda_libs["cusolver"].file_name,
+ "%{cudnn_lib}":
+ cuda_libs["cudnn"].file_name,
+ "%{cufft_lib}":
+ cuda_libs["cufft"].file_name,
+ "%{curand_lib}":
+ cuda_libs["curand"].file_name,
+ "%{cupti_lib}":
+ cuda_libs["cupti"].file_name,
+ "%{cuda_include_genrules}":
+ "\n".join(genrules),
+ "%{cuda_headers}": ('":cuda-include",\n' + ' ":cudnn-include",'
+ ),
+ },
+ "cuda/BUILD",
+ )
- should_download_clang = is_cuda_clang and _flag_enabled(
- repository_ctx,
- _TF_DOWNLOAD_CLANG,
- )
- if should_download_clang:
- download_clang(repository_ctx, "crosstool/extra_tools")
-
- # Set up crosstool/
- cc = find_cc(repository_ctx)
- cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
-
- host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
- cuda_defines = {}
- # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
- # https://github.com/bazelbuild/bazel/issues/760).
- # However, this stops our custom clang toolchain from picking the provided
- # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
- # toolchain.
- # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
- # flag from the CROSSTOOL completely (see
- # https://github.com/bazelbuild/bazel/issues/5634)
- if should_download_clang:
- cuda_defines["%{linker_bin_path_flag}"] = ""
- else:
- cuda_defines["%{linker_bin_path_flag}"] = 'flag: "-B/usr/bin"'
+ is_cuda_clang = _use_cuda_clang(repository_ctx)
- if is_cuda_clang:
- cuda_defines["%{host_compiler_path}"] = str(cc)
- cuda_defines["%{host_compiler_warnings}"] = """
+ should_download_clang = is_cuda_clang and _flag_enabled(
+ repository_ctx,
+ _TF_DOWNLOAD_CLANG,
+ )
+ if should_download_clang:
+ download_clang(repository_ctx, "crosstool/extra_tools")
+
+ # Set up crosstool/
+ cc = find_cc(repository_ctx)
+ cc_fullpath = cc if not should_download_clang else "crosstool/" + cc
+
+ host_compiler_includes = _host_compiler_includes(repository_ctx, cc_fullpath)
+ cuda_defines = {}
+ # Bazel sets '-B/usr/bin' flag to workaround build errors on RHEL (see
+ # https://github.com/bazelbuild/bazel/issues/760).
+ # However, this stops our custom clang toolchain from picking the provided
+ # LLD linker, so we're only adding '-B/usr/bin' when using non-downloaded
+ # toolchain.
+ # TODO: when bazel stops adding '-B/usr/bin' by default, remove this
+ # flag from the CROSSTOOL completely (see
+ # https://github.com/bazelbuild/bazel/issues/5634)
+ if should_download_clang:
+ cuda_defines["%{linker_bin_path_flag}"] = ""
+ else:
+ cuda_defines["%{linker_bin_path_flag}"] = 'flag: "-B/usr/bin"'
+
+ if is_cuda_clang:
+ cuda_defines["%{host_compiler_path}"] = str(cc)
+ cuda_defines["%{host_compiler_warnings}"] = """
# Some parts of the codebase set -Werror and hit this warning, so
# switch it off for now.
flag: "-Wno-invalid-partial-specialization"
"""
- cuda_defines["%{host_compiler_includes}"] = host_compiler_includes
- _tpl(repository_ctx, "crosstool:BUILD", {"%{linker_files}": ":empty", "%{win_linker_files}": ":empty"})
- repository_ctx.file("crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "")
- repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
- repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.bat", "")
- else:
- cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
- cuda_defines["%{host_compiler_warnings}"] = ""
-
- # nvcc has the system include paths built in and will automatically
- # search them; we cannot work around that, so we add the relevant cuda
- # system paths to the allowed compiler specific include paths.
- cuda_defines["%{host_compiler_includes}"] = (
- host_compiler_includes + "\n" +
- _cuda_include_path(repository_ctx, cuda_config) +
- "\n cxx_builtin_include_directory: \"%s\"" % cupti_header_dir +
- "\n cxx_builtin_include_directory: \"%s\"" % cudnn_header_dir)
- nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
- (
- cuda_config.cuda_toolkit_path,
- ".exe" if _is_windows(repository_ctx) else "",
- )))
- _tpl(
- repository_ctx,
- "crosstool:BUILD",
- {
- "%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc",
- "%{win_linker_files}": ":windows_msvc_wrapper_files",
- },
- )
- wrapper_defines = {
- "%{cpu_compiler}": str(cc),
- "%{cuda_version}": cuda_config.cuda_version,
- "%{nvcc_path}": nvcc_path,
- "%{gcc_host_compiler_path}": str(cc),
- "%{cuda_compute_capabilities}": ", ".join(
- ["\"%s\"" % c for c in cuda_config.compute_capabilities],
- ),
- "%{nvcc_tmp_dir}": _get_nvcc_tmp_dir_for_windows(repository_ctx),
- }
- _tpl(
- repository_ctx,
- "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
- wrapper_defines,
- )
- _tpl(
- repository_ctx,
- "crosstool:windows/msvc_wrapper_for_nvcc.py",
- wrapper_defines,
- )
- _tpl(
- repository_ctx,
- "crosstool:windows/msvc_wrapper_for_nvcc.bat",
- {
- "%{python_binary}": _get_python_bin(repository_ctx),
- },
- )
-
+ cuda_defines["%{host_compiler_includes}"] = host_compiler_includes
+ _tpl(repository_ctx, "crosstool:BUILD", {
+ "%{linker_files}": ":empty",
+ "%{win_linker_files}": ":empty"
+ })
+ repository_ctx.file(
+ "crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc", "")
+ repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.py", "")
+ repository_ctx.file("crosstool/windows/msvc_wrapper_for_nvcc.bat", "")
+ else:
+ cuda_defines[
+ "%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
+ cuda_defines["%{host_compiler_warnings}"] = ""
+
+ # nvcc has the system include paths built in and will automatically
+ # search them; we cannot work around that, so we add the relevant cuda
+ # system paths to the allowed compiler specific include paths.
+ cuda_defines["%{host_compiler_includes}"] = (
+ host_compiler_includes + "\n" + _cuda_include_path(
+ repository_ctx, cuda_config) +
+ "\n cxx_builtin_include_directory: \"%s\"" % cupti_header_dir +
+ "\n cxx_builtin_include_directory: \"%s\"" % cudnn_header_dir)
+ nvcc_path = str(
+ repository_ctx.path("%s/bin/nvcc%s" % (
+ cuda_config.cuda_toolkit_path,
+ ".exe" if _is_windows(repository_ctx) else "",
+ )))
_tpl(
repository_ctx,
- "crosstool:CROSSTOOL",
- cuda_defines + _get_win_cuda_defines(repository_ctx),
- out = "crosstool/CROSSTOOL",
+ "crosstool:BUILD",
+ {
+ "%{linker_files}": ":crosstool_wrapper_driver_is_not_gcc",
+ "%{win_linker_files}": ":windows_msvc_wrapper_files",
+ },
)
-
- # Set up cuda_config.h, which is used by
- # tensorflow/stream_executor/dso_loader.cc.
+ wrapper_defines = {
+ "%{cpu_compiler}":
+ str(cc),
+ "%{cuda_version}":
+ cuda_config.cuda_version,
+ "%{nvcc_path}":
+ nvcc_path,
+ "%{gcc_host_compiler_path}":
+ str(cc),
+ "%{cuda_compute_capabilities}":
+ ", ".join(
+ ["\"%s\"" % c for c in cuda_config.compute_capabilities],),
+ "%{nvcc_tmp_dir}":
+ _get_nvcc_tmp_dir_for_windows(repository_ctx),
+ }
_tpl(
repository_ctx,
- "cuda:cuda_config.h",
- {
- "%{cuda_version}": cuda_config.cuda_version,
- "%{cudnn_version}": cuda_config.cudnn_version,
- "%{cuda_compute_capabilities}": ",".join(
- [
- "CudaVersion(\"%s\")" % c
- for c in cuda_config.compute_capabilities
- ],
- ),
- "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path,
- },
- "cuda/cuda/cuda_config.h",
+ "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
+ wrapper_defines,
)
-
-def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
- """Creates pointers to a remotely configured repo set up to build with CUDA."""
_tpl(
repository_ctx,
- "cuda:build_defs.bzl",
- {
- "%{cuda_is_configured}": "True",
- "%{cuda_extra_copts}": _compute_cuda_extra_copts(
- repository_ctx,
- _compute_capabilities(repository_ctx),
- ),
- },
+ "crosstool:windows/msvc_wrapper_for_nvcc.py",
+ wrapper_defines,
)
_tpl(
repository_ctx,
- "cuda:remote.BUILD",
+ "crosstool:windows/msvc_wrapper_for_nvcc.bat",
{
- "%{remote_cuda_repo}": remote_config_repo,
+ "%{python_binary}": _get_python_bin(repository_ctx),
},
- "cuda/BUILD",
)
- _tpl(repository_ctx, "crosstool:remote.BUILD", {
- "%{remote_cuda_repo}": remote_config_repo,
- }, "crosstool/BUILD")
+
+ _tpl(
+ repository_ctx,
+ "crosstool:CROSSTOOL",
+ cuda_defines + _get_win_cuda_defines(repository_ctx),
+ out="crosstool/CROSSTOOL",
+ )
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(
+ repository_ctx,
+ "cuda:cuda_config.h",
+ {
+ "%{cuda_version}":
+ cuda_config.cuda_version,
+ "%{cudnn_version}":
+ cuda_config.cudnn_version,
+ "%{cuda_compute_capabilities}":
+ ",".join([
+ "CudaVersion(\"%s\")" % c
+ for c in cuda_config.compute_capabilities
+ ],),
+ "%{cuda_toolkit_path}":
+ cuda_config.cuda_toolkit_path,
+ },
+ "cuda/cuda/cuda_config.h",
+ )
+
+
+def _create_remote_cuda_repository(repository_ctx, remote_config_repo):
+ """Creates pointers to a remotely configured repo set up to build with CUDA."""
+ _tpl(
+ repository_ctx,
+ "cuda:build_defs.bzl",
+ {
+ "%{cuda_is_configured}":
+ "True",
+ "%{cuda_extra_copts}":
+ _compute_cuda_extra_copts(
+ repository_ctx,
+ compute_capabilities(repository_ctx),
+ ),
+ },
+ )
+ _tpl(
+ repository_ctx,
+ "cuda:remote.BUILD",
+ {
+ "%{remote_cuda_repo}": remote_config_repo,
+ },
+ "cuda/BUILD",
+ )
+ _tpl(repository_ctx, "crosstool:remote.BUILD", {
+ "%{remote_cuda_repo}": remote_config_repo,
+ }, "crosstool/BUILD")
+
def _cuda_autoconf_impl(repository_ctx):
- """Implementation of the cuda_autoconf repository rule."""
- if not _enable_cuda(repository_ctx):
- _create_dummy_repository(repository_ctx)
- elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
- _create_remote_cuda_repository(
- repository_ctx,
- repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
- )
- else:
- _create_local_cuda_repository(repository_ctx)
+ """Implementation of the cuda_autoconf repository rule."""
+ if not _enable_cuda(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ elif _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ:
+ _create_remote_cuda_repository(
+ repository_ctx,
+ repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO],
+ )
+ else:
+ _create_local_cuda_repository(repository_ctx)
+
cuda_configure = repository_rule(
implementation = _cuda_autoconf_impl,
diff --git a/third_party/jpeg/BUILD b/third_party/jpeg/BUILD
index 5b01f6e3e4..e3aec1fce9 100644
--- a/third_party/jpeg/BUILD
+++ b/third_party/jpeg/BUILD
@@ -1 +1 @@
-licenses(["notice"])
+# Needed to make this a package.
diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/BUILD.bazel
index 1b9b9bf2f5..5243e995a3 100644
--- a/third_party/jpeg/jpeg.BUILD
+++ b/third_party/jpeg/BUILD.bazel
@@ -162,9 +162,9 @@ cc_library(
hdrs = [
"simd/powerpc/jccolext-altivec.c",
"simd/powerpc/jcgryext-altivec.c",
+ "simd/powerpc/jcsample.h",
"simd/powerpc/jdcolext-altivec.c",
"simd/powerpc/jdmrgext-altivec.c",
- "simd/powerpc/jcsample.h",
"simd/powerpc/jsimd_altivec.h",
],
copts = libjpegturbo_copts,
@@ -186,7 +186,6 @@ cc_library(
"jsimd.h",
"jsimddct.h",
"simd/jsimd.h",
- "simd/x86_64/jsimd.c",
"simd/x86_64/jccolor-avx2.o",
"simd/x86_64/jccolor-sse2.o",
"simd/x86_64/jcgray-avx2.o",
@@ -213,6 +212,7 @@ cc_library(
"simd/x86_64/jquantf-sse2.o",
"simd/x86_64/jquanti-avx2.o",
"simd/x86_64/jquanti-sse2.o",
+ "simd/x86_64/jsimd.c",
"simd/x86_64/jsimdcpu.o",
],
copts = libjpegturbo_copts,
@@ -322,9 +322,9 @@ cc_library(
"jpeglib.h",
"jsimd.h",
"jsimddct.h",
- "simd/jsimd.h",
"simd/arm/jsimd.c",
"simd/arm/jsimd_neon.S",
+ "simd/jsimd.h",
],
copts = libjpegturbo_copts,
nocopts = libjpegturbo_nocopts,
@@ -343,9 +343,9 @@ cc_library(
"jpeglib.h",
"jsimd.h",
"jsimddct.h",
- "simd/jsimd.h",
"simd/arm64/jsimd.c",
"simd/arm64/jsimd_neon.S",
+ "simd/jsimd.h",
],
copts = libjpegturbo_copts,
nocopts = libjpegturbo_nocopts,
@@ -366,7 +366,6 @@ cc_library(
"jsimd.h",
"jsimddct.h",
"simd/jsimd.h",
- "simd/x86_64/jsimd.c",
"simd/x86_64/jccolor-avx2.obj",
"simd/x86_64/jccolor-sse2.obj",
"simd/x86_64/jcgray-avx2.obj",
@@ -393,6 +392,7 @@ cc_library(
"simd/x86_64/jquantf-sse2.obj",
"simd/x86_64/jquanti-avx2.obj",
"simd/x86_64/jquanti-sse2.obj",
+ "simd/x86_64/jsimd.c",
"simd/x86_64/jsimdcpu.obj",
],
copts = libjpegturbo_copts,
@@ -603,6 +603,7 @@ JCONFIGINT_WIN_SUBSTITUTIONS = {
}
JCONFIGINT_NOWIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS)
+
JCONFIGINT_WIN_SUBSTITUTIONS.update(JCONFIGINT_COMMON_SUBSTITUTIONS)
template_rule(
diff --git a/third_party/systemlibs/jpeg.BUILD b/third_party/jpeg/BUILD.system
index f4f52da9bd..f4f52da9bd 100644
--- a/third_party/systemlibs/jpeg.BUILD
+++ b/third_party/jpeg/BUILD.system
diff --git a/third_party/jpeg/jpeg_helpers.BUILD.bazel b/third_party/jpeg/jpeg_helpers.BUILD.bazel
new file mode 100644
index 0000000000..5b01f6e3e4
--- /dev/null
+++ b/third_party/jpeg/jpeg_helpers.BUILD.bazel
@@ -0,0 +1 @@
+licenses(["notice"])
diff --git a/third_party/jpeg/workspace.bzl b/third_party/jpeg/workspace.bzl
new file mode 100644
index 0000000000..2bb7dacd32
--- /dev/null
+++ b/third_party/jpeg/workspace.bzl
@@ -0,0 +1,16 @@
+"""loads the jpeg library, used by TF."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "jpeg",
+ urls = [
+ "https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
+ "https://github.com/libjpeg-turbo/libjpeg-turbo/archive/2.0.0.tar.gz",
+ ],
+ sha256 = "f892fff427ab3adffc289363eac26d197ce3ccacefe5f5822377348a8166069b",
+ strip_prefix = "libjpeg-turbo-2.0.0",
+ build_file = "//third_party/jpeg:BUILD.bazel",
+ system_build_file = "//third_party/jpeg:BUILD.system",
+ )
diff --git a/third_party/nasm/BUILD b/third_party/nasm/BUILD
new file mode 100644
index 0000000000..e3aec1fce9
--- /dev/null
+++ b/third_party/nasm/BUILD
@@ -0,0 +1 @@
+# Needed to make this a package.
diff --git a/third_party/nasm.BUILD b/third_party/nasm/BUILD.bazel
index d746a65e7e..c68d713946 100644
--- a/third_party/nasm.BUILD
+++ b/third_party/nasm/BUILD.bazel
@@ -137,12 +137,6 @@ cc_binary(
":windows": ["config/msvc.h"],
"//conditions:default": [],
}),
- includes = [
- "asm",
- "include",
- "output",
- "x86",
- ],
copts = select({
":windows": [],
"//conditions:default": [
@@ -157,6 +151,12 @@ cc_binary(
"HAVE_SYS_TYPES_H",
],
}),
+ includes = [
+ "asm",
+ "include",
+ "output",
+ "x86",
+ ],
visibility = ["@jpeg//:__pkg__"],
)
diff --git a/third_party/systemlibs/nasm.BUILD b/third_party/nasm/BUILD.system
index 10ef8d8832..10ef8d8832 100644
--- a/third_party/systemlibs/nasm.BUILD
+++ b/third_party/nasm/BUILD.system
diff --git a/third_party/nasm/workspace.bzl b/third_party/nasm/workspace.bzl
new file mode 100644
index 0000000000..6d50f6fcad
--- /dev/null
+++ b/third_party/nasm/workspace.bzl
@@ -0,0 +1,17 @@
+"""loads the nasm library, used by TF."""
+
+load("//third_party:repo.bzl", "third_party_http_archive")
+
+def repo():
+ third_party_http_archive(
+ name = "nasm",
+ urls = [
+ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.13.03.tar.bz2/sha512/d7a6b4cee8dfd603d8d4c976e5287b5cc542fa0b466ff989b743276a6e28114e64289bf02a7819eca63142a5278aa6eed57773007e5f589e15768e6456a8919d/nasm-2.13.03.tar.bz2",
+ "http://www.nasm.us/pub/nasm/releasebuilds/2.13.03/nasm-2.13.03.tar.bz2",
+ ],
+ sha256 = "63ec86477ad3f0f6292325fd89e1d93aea2e2fd490070863f17d48f7cd387011",
+ strip_prefix = "nasm-2.13.03",
+ build_file = "//third_party/nasm:BUILD.bazel",
+ system_build_file = "//third_party/nasm:BUILD.system",
+ )
diff --git a/third_party/nccl/LICENSE b/third_party/nccl/LICENSE
index 146d9b765c..b958518186 100644
--- a/third_party/nccl/LICENSE
+++ b/third_party/nccl/LICENSE
@@ -1,203 +1,30 @@
-Copyright 2018 The TensorFlow Authors. All rights reserved.
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- APPENDIX: How to apply the Apache License to your work.
-
- To apply the Apache License to your work, attach the following
- boilerplate notice, with the fields enclosed by brackets "[]"
- replaced with your own identifying information. (Don't include
- the brackets!) The text should be enclosed in the appropriate
- comment syntax for the file format. We also recommend that a
- file or class name and description of purpose be included on the
- same "printed page" as the copyright notice for easier
- identification within third-party archives.
-
- Copyright 2018, The TensorFlow Authors.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
+ Copyright (c) 2015-2018, NVIDIA CORPORATION. All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+ * Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+ * Neither the name of NVIDIA CORPORATION, Lawrence Berkeley National
+ Laboratory, the U.S. Department of Energy, nor the names of their
+ contributors may be used to endorse or promote products derived
+ from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+ EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+ CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ The U.S. Department of Energy funded the development of this software
+ under subcontract 7078610 with Lawrence Berkeley National Laboratory.
diff --git a/third_party/nccl/archive.BUILD b/third_party/nccl/archive.BUILD
new file mode 100644
index 0000000000..f57f04c75e
--- /dev/null
+++ b/third_party/nccl/archive.BUILD
@@ -0,0 +1,179 @@
+# NVIDIA NCCL 2
+# A package of optimized primitives for collective multi-GPU communication.
+
+licenses(["restricted"])
+
+exports_files(["LICENSE.txt"])
+
+load(
+ "@local_config_nccl//:build_defs.bzl",
+ "device_link",
+ "gen_nccl_h",
+ "nccl_library",
+ "rdc_copts",
+)
+load(
+ "@local_config_cuda//cuda:build_defs.bzl",
+ "cuda_default_copts",
+)
+
+# Generate the nccl.h header file.
+gen_nccl_h(
+ name = "nccl_h",
+ output = "src/nccl.h",
+ template = "src/nccl.h.in",
+)
+
+nccl_library(
+ name = "src_hdrs",
+ hdrs = [
+ "src/nccl.h",
+ # src/include/common_coll.h #includes "collectives/collectives.h".
+ # All other #includes of collectives.h are patched in process_srcs.
+ "src/collectives/collectives.h",
+ ],
+ strip_include_prefix = "src",
+)
+
+nccl_library(
+ name = "include_hdrs",
+ hdrs = glob(["src/include/*.h"]),
+ strip_include_prefix = "src/include",
+)
+
+filegroup(
+ name = "device_hdrs",
+ srcs = glob(["src/collectives/device/*.h"]),
+)
+
+filegroup(
+ name = "device_srcs",
+ srcs = [
+ "src/collectives/device/all_gather.cu",
+ "src/collectives/device/all_reduce.cu",
+ "src/collectives/device/broadcast.cu",
+ "src/collectives/device/reduce.cu",
+ "src/collectives/device/reduce_scatter.cu",
+ ],
+)
+
+nccl_library(
+ name = "sum",
+ srcs = [
+ ":device_hdrs",
+ ":device_srcs",
+ ],
+ copts = ["-DNCCL_OP=0"] + rdc_copts(),
+ prefix = "sum_",
+ deps = [
+ ":src_hdrs",
+ ":include_hdrs",
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ linkstatic = True,
+)
+
+nccl_library(
+ name = "prod",
+ srcs = [
+ ":device_hdrs",
+ ":device_srcs",
+ ],
+ copts = ["-DNCCL_OP=1"] + rdc_copts(),
+ prefix = "_prod",
+ deps = [
+ ":src_hdrs",
+ ":include_hdrs",
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ linkstatic = True,
+)
+
+nccl_library(
+ name = "min",
+ srcs = [
+ ":device_hdrs",
+ ":device_srcs",
+ ],
+ copts = ["-DNCCL_OP=2"] + rdc_copts(),
+ prefix = "min_",
+ deps = [
+ ":src_hdrs",
+ ":include_hdrs",
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ linkstatic = True,
+)
+
+nccl_library(
+ name = "max",
+ srcs = [
+ ":device_hdrs",
+ ":device_srcs",
+ ],
+ copts = ["-DNCCL_OP=3"] + rdc_copts(),
+ prefix = "max_",
+ deps = [
+ ":src_hdrs",
+ ":include_hdrs",
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ linkstatic = True,
+)
+
+nccl_library(
+ name = "functions",
+ srcs = [
+ ":device_hdrs",
+ "src/collectives/device/functions.cu",
+ ],
+ copts = rdc_copts(),
+ deps = [
+ ":src_hdrs",
+ ":include_hdrs",
+ "@local_config_cuda//cuda:cuda_headers",
+ ],
+ linkstatic = True,
+)
+
+device_link(
+ name = "device_code",
+ srcs = [
+ ":functions",
+ ":max",
+ ":min",
+ ":prod",
+ ":sum",
+ ],
+)
+
+# Primary NCCL target.
+nccl_library(
+ name = "nccl",
+ srcs = glob(
+ include = ["src/**/*.cu"],
+ # Exclude device-library code.
+ exclude = ["src/collectives/device/**"],
+ ) + [
+ # Required for header inclusion checking (see
+ # http://docs.bazel.build/versions/master/be/c-cpp.html#hdrs).
+ # Files in src/ which #include "nccl.h" load it from there rather than
+ # from the virtual includes directory.
+ "src/nccl.h",
+ ],
+ hdrs = ["src/nccl.h"],
+ include_prefix = "third_party/nccl",
+ strip_include_prefix = "src",
+ copts = cuda_default_copts(),
+ deps = [
+ ":device_code",
+ ":functions",
+ ":include_hdrs",
+ ":max",
+ ":min",
+ ":prod",
+ ":src_hdrs",
+ ":sum",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/third_party/nccl/build_defs.bzl.tpl b/third_party/nccl/build_defs.bzl.tpl
new file mode 100644
index 0000000000..ede1d3dad5
--- /dev/null
+++ b/third_party/nccl/build_defs.bzl.tpl
@@ -0,0 +1,210 @@
+"""Repository rule for NCCL."""
+
+load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts")
+
+def _gen_nccl_h_impl(ctx):
+ """Creates nccl.h from a template."""
+ ctx.actions.expand_template(
+ output = ctx.outputs.output,
+ template = ctx.file.template,
+ substitutions = {
+ "${nccl:Major}": "2",
+ "${nccl:Minor}": "3",
+ "${nccl:Patch}": "5",
+ "${nccl:Suffix}": "",
+ "${nccl:Version}": "2305",
+ },
+ )
+gen_nccl_h = rule(
+ implementation = _gen_nccl_h_impl,
+ attrs = {
+ "template": attr.label(allow_single_file = True),
+ "output": attr.output(),
+ },
+)
+"""Creates the NCCL header file."""
+
+
+def _process_srcs_impl(ctx):
+ """Appends .cc to .cu files, patches include directives."""
+ files = []
+ for src in ctx.files.srcs:
+ if not src.is_source:
+ # Process only once, specifically "src/nccl.h".
+ files.append(src)
+ continue
+ name = src.basename
+ if src.extension == "cu":
+ name = ctx.attr.prefix + name + ".cc"
+ file = ctx.actions.declare_file(name, sibling = src)
+ ctx.actions.expand_template(
+ output = file,
+ template = src,
+ substitutions = {
+ "\"collectives.h": "\"collectives/collectives.h",
+ "\"../collectives.h": "\"collectives/collectives.h",
+ "#if __CUDACC_VER_MAJOR__":
+ "#if defined __CUDACC_VER_MAJOR__ && __CUDACC_VER_MAJOR__",
+ # Substitutions are applied in order.
+ "std::nullptr_t": "nullptr_t",
+ "nullptr_t": "std::nullptr_t",
+ },
+ )
+ files.append(file)
+ return [DefaultInfo(files = depset(files))]
+_process_srcs = rule(
+ implementation = _process_srcs_impl,
+ attrs = {
+ "srcs": attr.label_list(allow_files = True),
+ "prefix": attr.string(default = ""),
+ },
+)
+"""Processes the NCCL srcs so they can be compiled with bazel and clang."""
+
+
+def nccl_library(name, srcs=None, hdrs=None, prefix=None, **kwargs):
+ """Processes the srcs and hdrs and creates a cc_library."""
+
+ _process_srcs(
+ name = name + "_srcs",
+ srcs = srcs,
+ prefix = prefix,
+ )
+ _process_srcs(
+ name = name + "_hdrs",
+ srcs = hdrs,
+ )
+
+ native.cc_library(
+ name = name,
+ srcs = [name + "_srcs"] if srcs else [],
+ hdrs = [name + "_hdrs"] if hdrs else [],
+ **kwargs
+ )
+
+
+def rdc_copts():
+ """Returns copts for compiling relocatable device code."""
+
+ # The global functions can not have a lower register count than the
+ # device functions. This is enforced by setting a fixed register count.
+ # https://github.com/NVIDIA/nccl/blob/f93fe9bfd94884cec2ba711897222e0df5569a53/makefiles/common.mk#L48
+ maxrregcount = "-maxrregcount=96"
+
+ return cuda_default_copts() + select({
+ "@local_config_cuda//cuda:using_nvcc": [
+ "-nvcc_options",
+ "relocatable-device-code=true",
+ "-nvcc_options",
+ "ptxas-options=" + maxrregcount,
+ ],
+ "@local_config_cuda//cuda:using_clang": [
+ "-fcuda-rdc",
+ "-Xcuda-ptxas",
+ maxrregcount,
+ ],
+ "//conditions:default": [],
+ }) + ["-fvisibility=hidden"]
+
+
+def _filter_impl(ctx):
+ suffix = ctx.attr.suffix
+ files = [src for src in ctx.files.srcs if src.path.endswith(suffix)]
+ return [DefaultInfo(files = depset(files))]
+_filter = rule(
+ implementation = _filter_impl,
+ attrs = {
+ "srcs": attr.label_list(allow_files = True),
+ "suffix": attr.string(),
+ },
+)
+"""Filters the srcs to the ones ending with suffix."""
+
+
+def _gen_link_src_impl(ctx):
+ ctx.actions.expand_template(
+ output = ctx.outputs.output,
+ template = ctx.file.template,
+ substitutions = {
+ "REGISTERLINKBINARYFILE": '"%s"' % ctx.file.register_hdr.short_path,
+ "FATBINFILE": '"%s"' % ctx.file.fatbin_hdr.short_path,
+ },
+ )
+_gen_link_src = rule(
+ implementation = _gen_link_src_impl,
+ attrs = {
+ "register_hdr": attr.label(allow_single_file = True),
+ "fatbin_hdr": attr.label(allow_single_file = True),
+ "template": attr.label(allow_single_file = True),
+ "output": attr.output(),
+ },
+)
+"""Patches the include directives for the link.stub file."""
+
+
+def device_link(name, srcs):
+ """Links seperately compiled relocatable device code into a cc_library."""
+
+ # From .a and .pic.a archives, just use the latter.
+ _filter(
+ name = name + "_pic_a",
+ srcs = srcs,
+ suffix = ".pic.a",
+ )
+
+ # Device-link to cubins for each architecture.
+ images = []
+ cubins = []
+ for arch in %{gpu_architectures}:
+ cubin = "%s_%s.cubin" % (name, arch)
+ register_hdr = "%s_%s.h" % (name, arch)
+ nvlink = "@local_config_nccl//:nvlink"
+ cmd = ("$(location %s) --cpu-arch=X86_64 " % nvlink +
+ "--arch=%s $(SRCS) " % arch +
+ "--register-link-binaries=$(location %s) " % register_hdr +
+ "--output-file=$(location %s)" % cubin)
+ native.genrule(
+ name = "%s_%s" % (name, arch),
+ outs = [register_hdr, cubin],
+ srcs = [name + "_pic_a"],
+ cmd = cmd,
+ tools = [nvlink],
+ )
+ images.append("--image=profile=%s,file=$(location %s)" % (arch, cubin))
+ cubins.append(cubin)
+
+ # Generate fatbin header from all cubins.
+ fatbin_hdr = name + ".fatbin.h"
+ fatbinary = "@local_config_nccl//:cuda/bin/fatbinary"
+ cmd = ("PATH=$$CUDA_TOOLKIT_PATH/bin:$$PATH " + # for bin2c
+ "$(location %s) -64 --cmdline=--compile-only --link " % fatbinary +
+ "--compress-all %s --create=%%{name}.fatbin " % " ".join(images) +
+ "--embedded-fatbin=$@")
+ native.genrule(
+ name = name + "_fatbin_h",
+ outs = [fatbin_hdr],
+ srcs = cubins,
+ cmd = cmd,
+ tools = [fatbinary],
+ )
+
+ # Generate the source file #including the headers generated above.
+ _gen_link_src(
+ name = name + "_cc",
+ # Include just the last one, they are equivalent.
+ register_hdr = register_hdr,
+ fatbin_hdr = fatbin_hdr,
+ template = "@local_config_nccl//:cuda/bin/crt/link.stub",
+ output = name + ".cc",
+ )
+
+ # Compile the source file into the cc_library.
+ native.cc_library(
+ name = name,
+ srcs = [name + "_cc"],
+ textual_hdrs = [register_hdr, fatbin_hdr],
+ deps = [
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cudart_static",
+ ],
+ )
diff --git a/third_party/nccl/nccl_archive.BUILD b/third_party/nccl/nccl_archive.BUILD
deleted file mode 100644
index a05899e38d..0000000000
--- a/third_party/nccl/nccl_archive.BUILD
+++ /dev/null
@@ -1,68 +0,0 @@
-# NVIDIA nccl
-# A package of optimized primitives for collective multi-GPU communication.
-
-licenses(["notice"]) # BSD
-
-exports_files(["LICENSE.txt"])
-
-load("@local_config_cuda//cuda:build_defs.bzl", "cuda_default_copts", "if_cuda")
-
-SRCS = [
- "src/all_gather.cu",
- "src/all_reduce.cu",
- "src/broadcast.cu",
- "src/core.cu",
- "src/libwrap.cu",
- "src/reduce.cu",
- "src/reduce_scatter.cu",
-]
-
-# Copy .cu to .cu.cc so they can be in srcs of cc_library.
-[
- genrule(
- name = "gen_" + src,
- srcs = [src],
- outs = [src + ".cc"],
- cmd = "cp $(location " + src + ") $(location " + src + ".cc)",
- )
- for src in SRCS
-]
-
-SRCS_CU_CC = [src + ".cc" for src in SRCS]
-
-cc_library(
- name = "nccl",
- srcs = if_cuda(SRCS_CU_CC + glob(["src/*.h"])),
- hdrs = if_cuda(["src/nccl.h"]),
- copts = [
- "-DCUDA_MAJOR=0",
- "-DCUDA_MINOR=0",
- "-DNCCL_MAJOR=0",
- "-DNCCL_MINOR=0",
- "-DNCCL_PATCH=0",
- "-Iexternal/nccl_archive/src",
- "-O3",
- ] + cuda_default_copts(),
- include_prefix = "third_party/nccl",
- linkopts = select({
- "@org_tensorflow//tensorflow:android": [
- "-pie",
- ],
- "@org_tensorflow//tensorflow:darwin": [
- "-Wl,-framework",
- "-Wl,CoreFoundation",
- "-Wl,-framework",
- "-Wl,Security",
- ],
- "@org_tensorflow//tensorflow:ios": [],
- "@org_tensorflow//tensorflow:windows": [
- "-DEFAULTLIB:ws2_32.lib",
- ],
- "//conditions:default": [
- "-lrt",
- ],
- }),
- strip_include_prefix = "src",
- visibility = ["//visibility:public"],
- deps = ["@local_config_cuda//cuda:cuda_headers"],
-)
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
index d78fe8f3aa..7f00df0962 100644
--- a/third_party/nccl/nccl_configure.bzl
+++ b/third_party/nccl/nccl_configure.bzl
@@ -11,12 +11,16 @@
load(
"//third_party/gpus:cuda_configure.bzl",
"auto_configure_fail",
+ "compute_capabilities",
+ "cuda_toolkit_path",
"find_cuda_define",
"matches_version",
)
-_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_CUDA_TOOLKIT_PATH = "CUDA_TOOLKIT_PATH"
_NCCL_HDR_PATH = "NCCL_HDR_PATH"
+_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES"
_TF_NCCL_VERSION = "TF_NCCL_VERSION"
_TF_NCCL_CONFIG_REPO = "TF_NCCL_CONFIG_REPO"
@@ -37,6 +41,12 @@ cc_library(
"""
_NCCL_ARCHIVE_BUILD_CONTENT = """
+exports_files([
+ "cuda/bin/crt/link.stub",
+ "cuda/bin/fatbinary",
+ "nvlink",
+])
+
filegroup(
name = "LICENSE",
data = ["@nccl_archive//:LICENSE.txt"],
@@ -50,113 +60,125 @@ alias(
)
"""
-# Local build results in dynamic link and the license should not be included.
-_NCCL_REMOTE_BUILD_TEMPLATE = Label("//third_party/nccl:remote.BUILD.tpl")
-_NCCL_LOCAL_BUILD_TEMPLATE = Label("//third_party/nccl:system.BUILD.tpl")
+def _label(file):
+ return Label("//third_party/nccl:{}".format(file))
def _find_nccl_header(repository_ctx, nccl_install_path):
- """Finds the NCCL header on the system.
-
- Args:
- repository_ctx: The repository context.
- nccl_install_path: The NCCL library install directory.
+ """Finds the NCCL header on the system.
- Returns:
- The path to the NCCL header.
- """
- header_path = repository_ctx.path("%s/include/nccl.h" % nccl_install_path)
- if not header_path.exists:
- auto_configure_fail("Cannot find %s" % str(header_path))
- return header_path
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library install directory.
+ Returns:
+ The path to the NCCL header.
+ """
+ header_path = repository_ctx.path("%s/include/nccl.h" % nccl_install_path)
+ if not header_path.exists:
+ auto_configure_fail("Cannot find %s" % str(header_path))
+ return header_path
def _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version):
- """Checks whether the header file matches the specified version of NCCL.
-
- Args:
- repository_ctx: The repository context.
- nccl_install_path: The NCCL library install directory.
- nccl_version: The expected NCCL version.
-
- Returns:
- A string containing the library version of NCCL.
- """
- header_path = repository_ctx.path("%s/nccl.h" % nccl_hdr_path)
- if not header_path.exists:
- header_path = _find_nccl_header(repository_ctx, nccl_install_path)
- header_dir = str(header_path.realpath.dirname)
- major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
- _DEFINE_NCCL_MAJOR)
- minor_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
- _DEFINE_NCCL_MINOR)
- patch_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
- _DEFINE_NCCL_PATCH)
- header_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
- if not matches_version(nccl_version, header_version):
- auto_configure_fail(
- ("NCCL library version detected from %s/nccl.h (%s) does not match " +
- "TF_NCCL_VERSION (%s). To fix this rerun configure again.") %
- (header_dir, header_version, nccl_version))
-
-
-def _find_nccl_lib(repository_ctx, nccl_install_path, nccl_version):
- """Finds the given NCCL library on the system.
-
- Args:
- repository_ctx: The repository context.
- nccl_install_path: The NCCL library installation directory.
- nccl_version: The version of NCCL library files as returned
- by _nccl_version.
-
- Returns:
- The path to the NCCL library.
- """
- lib_path = repository_ctx.path("%s/lib/libnccl.so.%s" % (nccl_install_path,
- nccl_version))
- if not lib_path.exists:
- auto_configure_fail("Cannot find NCCL library %s" % str(lib_path))
- return lib_path
-
+ """Checks whether the header file matches the specified version of NCCL.
+
+ Args:
+ repository_ctx: The repository context.
+ nccl_install_path: The NCCL library install directory.
+ nccl_hdr_path: The NCCL header path.
+ nccl_version: The expected NCCL version.
+
+ Returns:
+ A string containing the library version of NCCL.
+ """
+ header_path = repository_ctx.path("%s/nccl.h" % nccl_hdr_path)
+ if not header_path.exists:
+ header_path = _find_nccl_header(repository_ctx, nccl_install_path)
+ header_dir = str(header_path.realpath.dirname)
+ major_version = find_cuda_define(
+ repository_ctx,
+ header_dir,
+ "nccl.h",
+ _DEFINE_NCCL_MAJOR,
+ )
+ minor_version = find_cuda_define(
+ repository_ctx,
+ header_dir,
+ "nccl.h",
+ _DEFINE_NCCL_MINOR,
+ )
+ patch_version = find_cuda_define(
+ repository_ctx,
+ header_dir,
+ "nccl.h",
+ _DEFINE_NCCL_PATCH,
+ )
+ header_version = "%s.%s.%s" % (major_version, minor_version, patch_version)
+ if not matches_version(nccl_version, header_version):
+ auto_configure_fail(
+ ("NCCL library version detected from %s/nccl.h (%s) does not match " +
+ "TF_NCCL_VERSION (%s). To fix this rerun configure again.") %
+ (header_dir, header_version, nccl_version),
+ )
def _nccl_configure_impl(repository_ctx):
- """Implementation of the nccl_configure repository rule."""
- if _TF_NCCL_VERSION not in repository_ctx.os.environ:
- # Add a dummy build file to make bazel query happy.
- repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
- return
-
- if _TF_NCCL_CONFIG_REPO in repository_ctx.os.environ:
- # Forward to the pre-configured remote repository.
- repository_ctx.template("BUILD", _NCCL_REMOTE_BUILD_TEMPLATE, {
- "%{target}": repository_ctx.os.environ[_TF_NCCL_CONFIG_REPO],
- })
- return
-
- nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
- if matches_version("1", nccl_version):
- # Alias to GitHub target from @nccl_archive.
- if not matches_version(nccl_version, "1.3"):
- auto_configure_fail(
- "NCCL from GitHub must use version 1.3 (got %s)" % nccl_version)
- repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
- else:
- # Create target for locally installed NCCL.
- nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
- nccl_hdr_path = repository_ctx.os.environ[_NCCL_HDR_PATH].strip()
- _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version)
- repository_ctx.template("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE, {
- "%{version}": nccl_version,
- "%{install_path}": nccl_install_path,
- "%{hdr_path}": nccl_hdr_path,
- })
-
+ """Implementation of the nccl_configure repository rule."""
+ if _TF_NCCL_VERSION not in repository_ctx.os.environ:
+ # Add a dummy build file to make bazel query happy.
+ repository_ctx.file("BUILD", _NCCL_DUMMY_BUILD_CONTENT)
+ return
+
+ if _TF_NCCL_CONFIG_REPO in repository_ctx.os.environ:
+ # Forward to the pre-configured remote repository.
+ repository_ctx.template("BUILD", _label("remote.BUILD.tpl"), {
+ "%{target}": repository_ctx.os.environ[_TF_NCCL_CONFIG_REPO],
+ })
+ return
+
+ nccl_version = repository_ctx.os.environ[_TF_NCCL_VERSION].strip()
+ if nccl_version == "":
+ # Alias to open source build from @nccl_archive.
+ repository_ctx.file("BUILD", _NCCL_ARCHIVE_BUILD_CONTENT)
+
+ # TODO(csigg): implement and reuse in cuda_configure.bzl.
+ gpu_architectures = [
+ "sm_" + capability.replace(".", "")
+ for capability in compute_capabilities(repository_ctx)
+ ]
+
+ # Round-about way to make the list unique.
+ gpu_architectures = dict(zip(gpu_architectures, gpu_architectures)).keys()
+ repository_ctx.template("build_defs.bzl", _label("build_defs.bzl.tpl"), {
+ "%{gpu_architectures}": str(gpu_architectures),
+ })
+
+ repository_ctx.symlink(cuda_toolkit_path(repository_ctx), "cuda")
+
+ # Temporary work-around for setups which symlink ptxas to a newer
+ # version. The versions of nvlink and ptxas need to agree, so we find
+ # nvlink next to the real location of ptxas. This is only temporary and
+ # will be removed again soon.
+ nvlink_dir = repository_ctx.path("cuda/bin/ptxas").realpath.dirname
+ repository_ctx.symlink(nvlink_dir.get_child("nvlink"), "nvlink")
+ else:
+ # Create target for locally installed NCCL.
+ nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
+ nccl_hdr_path = repository_ctx.os.environ[_NCCL_HDR_PATH].strip()
+ _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version)
+ repository_ctx.template("BUILD", _label("system.BUILD.tpl"), {
+ "%{version}": nccl_version,
+ "%{install_path}": nccl_install_path,
+ "%{hdr_path}": nccl_hdr_path,
+ })
nccl_configure = repository_rule(
- implementation=_nccl_configure_impl,
- environ=[
- _NCCL_INSTALL_PATH,
+ implementation = _nccl_configure_impl,
+ environ = [
+ _CUDA_TOOLKIT_PATH,
_NCCL_HDR_PATH,
+ _NCCL_INSTALL_PATH,
_TF_NCCL_VERSION,
+ _TF_CUDA_COMPUTE_CAPABILITIES,
+ _TF_NCCL_CONFIG_REPO,
],
)
"""Detects and configures the NCCL configuration.